diff --git a/.github/actions/run-tests/action.yml b/.github/actions/run-tests/action.yml index 20cceb922b..f2331a92a5 100644 --- a/.github/actions/run-tests/action.yml +++ b/.github/actions/run-tests/action.yml @@ -46,7 +46,7 @@ runs: CLIENT_LIBS_TEST_IMAGE_TAG: ${{ inputs.redis-version }} run: | set -e - + echo "::group::Installing dependencies" pip install -r dev_requirements.txt pip uninstall -y redis # uninstall Redis package installed via redis-entraid @@ -57,27 +57,26 @@ runs: pip install -e ./hiredis-py else pip install "hiredis${{inputs.hiredis-version}}" - fi + fi echo "PARSER_BACKEND=$(echo "${{inputs.parser-backend}}_${{inputs.hiredis-version}}" | sed 's/[^a-zA-Z0-9]/_/g')" >> $GITHUB_ENV else echo "PARSER_BACKEND=${{inputs.parser-backend}}" >> $GITHUB_ENV fi echo "::endgroup::" - + echo "::group::Starting Redis servers" redis_major_version=$(echo "$REDIS_VERSION" | grep -oP '^\d+') echo "REDIS_MAJOR_VERSION=${redis_major_version}" >> $GITHUB_ENV - + if (( redis_major_version < 8 )); then echo "Using redis-stack for module tests" - - # Mapping of redis version to stack version + + # Mapping of redis version to stack version declare -A redis_stack_version_mapping=( - ["7.4.2"]="rs-7.4.0-v2" - ["7.2.7"]="rs-7.2.0-v14" - ["6.2.17"]="rs-6.2.6-v18" + ["7.4.4"]="rs-7.4.0-v5" + ["7.2.9"]="rs-7.2.0-v17" ) - + if [[ -v redis_stack_version_mapping[$REDIS_VERSION] ]]; then export CLIENT_LIBS_TEST_STACK_IMAGE_TAG=${redis_stack_version_mapping[$REDIS_VERSION]} echo "REDIS_MOD_URL=redis://127.0.0.1:6479/0" >> $GITHUB_ENV @@ -85,19 +84,21 @@ runs: echo "Version not found in the mapping." exit 1 fi - + if (( redis_major_version < 7 )); then export REDIS_STACK_EXTRA_ARGS="--tls-auth-clients optional --save ''" - export REDIS_EXTRA_ARGS="--tls-auth-clients optional --save ''" + export REDIS_EXTRA_ARGS="--tls-auth-clients optional --save ''" fi - + invoke devenv --endpoints=all-stack + else echo "Using redis CE for module tests" + export CLIENT_LIBS_TEST_STACK_IMAGE_TAG=$REDIS_VERSION echo "REDIS_MOD_URL=redis://127.0.0.1:6379" >> $GITHUB_ENV invoke devenv --endpoints all - fi - + fi + sleep 10 # time to settle echo "::endgroup::" shell: bash @@ -105,32 +106,32 @@ runs: - name: Run tests run: | set -e - + run_tests() { local protocol=$1 local eventloop="" - + if [ "${{inputs.event-loop}}" == "uvloop" ]; then eventloop="--uvloop" fi - + echo "::group::RESP${protocol} standalone tests" echo "REDIS_MOD_URL=${REDIS_MOD_URL}" - + if (( $REDIS_MAJOR_VERSION < 7 )) && [ "$protocol" == "3" ]; then echo "Skipping module tests: Modules doesn't support RESP3 for Redis versions < 7" invoke standalone-tests --redis-mod-url=${REDIS_MOD_URL} $eventloop --protocol="${protocol}" --extra-markers="not redismod and not cp_integration" - else + else invoke standalone-tests --redis-mod-url=${REDIS_MOD_URL} $eventloop --protocol="${protocol}" fi - + echo "::endgroup::" - + echo "::group::RESP${protocol} cluster tests" invoke cluster-tests $eventloop --protocol=${protocol} - echo "::endgroup::" + echo "::endgroup::" } - + run_tests 2 "${{inputs.event-loop}}" run_tests 3 "${{inputs.event-loop}}" shell: bash diff --git a/.github/workflows/hiredis-py-integration.yaml b/.github/workflows/hiredis-py-integration.yaml index d81b9977b1..855e9aa8f2 100644 --- a/.github/workflows/hiredis-py-integration.yaml +++ b/.github/workflows/hiredis-py-integration.yaml @@ -23,8 +23,8 @@ env: CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }} # this speeds up coverage with Python 3.12: https://github.com/nedbat/coveragepy/issues/1665 COVERAGE_CORE: 
sysmon - CURRENT_CLIENT_LIBS_TEST_STACK_IMAGE_TAG: 'rs-7.4.0-v2' - CURRENT_REDIS_VERSION: '7.4.2' + CURRENT_CLIENT_LIBS_TEST_STACK_IMAGE_TAG: '8.0.2' + CURRENT_REDIS_VERSION: '8.0.2' jobs: redis_version: diff --git a/.github/workflows/integration.yaml b/.github/workflows/integration.yaml index ba746189ad..5f1922ff07 100644 --- a/.github/workflows/integration.yaml +++ b/.github/workflows/integration.yaml @@ -27,8 +27,8 @@ env: CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }} # this speeds up coverage with Python 3.12: https://github.com/nedbat/coveragepy/issues/1665 COVERAGE_CORE: sysmon - CURRENT_CLIENT_LIBS_TEST_STACK_IMAGE_TAG: 'rs-7.4.0-v2' - CURRENT_REDIS_VERSION: '7.4.2' + CURRENT_CLIENT_LIBS_TEST_STACK_IMAGE_TAG: '8.0.2' + CURRENT_REDIS_VERSION: '8.0.2' jobs: dependency-audit: @@ -74,7 +74,7 @@ jobs: max-parallel: 15 fail-fast: false matrix: - redis-version: ['8.0.1-pre', '${{ needs.redis_version.outputs.CURRENT }}', '7.2.7', '6.2.17'] + redis-version: ['8.2-rc2-pre', '${{ needs.redis_version.outputs.CURRENT }}', '7.4.4', '7.2.9'] python-version: ['3.9', '3.13'] parser-backend: ['plain'] event-loop: ['asyncio'] diff --git a/.github/workflows/spellcheck.yml b/.github/workflows/spellcheck.yml index 4d0fc338d6..81e73cd4ba 100644 --- a/.github/workflows/spellcheck.yml +++ b/.github/workflows/spellcheck.yml @@ -8,7 +8,7 @@ jobs: - name: Checkout uses: actions/checkout@v4 - name: Check Spelling - uses: rojopolis/spellcheck-github-actions@0.48.0 + uses: rojopolis/spellcheck-github-actions@0.51.0 with: config_path: .github/spellcheck-settings.yml task_name: Markdown diff --git a/README.md b/README.md index 414a0cf79e..97afa2f9bc 100644 --- a/README.md +++ b/README.md @@ -3,7 +3,7 @@ The Python interface to the Redis key-value store. [![CI](https://github.com/redis/redis-py/workflows/CI/badge.svg?branch=master)](https://github.com/redis/redis-py/actions?query=workflow%3ACI+branch%3Amaster) -[![docs](https://readthedocs.org/projects/redis/badge/?version=stable&style=flat)](https://redis-py.readthedocs.io/en/stable/) +[![docs](https://readthedocs.org/projects/redis/badge/?version=stable&style=flat)](https://redis.readthedocs.io/en/stable/) [![MIT licensed](https://img.shields.io/badge/license-MIT-blue.svg)](./LICENSE) [![pypi](https://badge.fury.io/py/redis.svg)](https://pypi.org/project/redis/) [![pre-release](https://img.shields.io/github/v/release/redis/redis-py?include_prereleases&label=latest-prerelease)](https://github.com/redis/redis-py/releases) @@ -41,7 +41,7 @@ Start a redis via docker (for Redis versions < 8.0): ``` bash docker run -p 6379:6379 -it redis/redis-stack:latest - +``` To install redis-py, simply: ``` bash @@ -209,4 +209,4 @@ Special thanks to: system. - Paul Hubbard for initial packaging support. 
-[![Redis](./docs/_static/logo-redis.svg)](https://redis.io) \ No newline at end of file +[![Redis](./docs/_static/logo-redis.svg)](https://redis.io) diff --git a/docker-compose.yml b/docker-compose.yml index bcf85df1a7..1428e6f96b 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -1,10 +1,10 @@ --- # image tag 8.0-RC2-pre is the one matching the 8.0 GA release x-client-libs-stack-image: &client-libs-stack-image - image: "redislabs/client-libs-test:${CLIENT_LIBS_TEST_STACK_IMAGE_TAG:-8.0-RC2-pre}" + image: "redislabs/client-libs-test:${CLIENT_LIBS_TEST_STACK_IMAGE_TAG:-8.0.2}" x-client-libs-image: &client-libs-image - image: "redislabs/client-libs-test:${CLIENT_LIBS_TEST_IMAGE_TAG:-8.0-RC2-pre}" + image: "redislabs/client-libs-test:${CLIENT_LIBS_TEST_IMAGE_TAG:-8.0.2}" services: diff --git a/docs/examples/README.md b/docs/examples/README.md index ca6d5dcfa3..89fd85d712 100644 --- a/docs/examples/README.md +++ b/docs/examples/README.md @@ -1,3 +1,3 @@ # Examples -Examples of redis-py usage go here. They're being linked to the [generated documentation](https://redis-py.readthedocs.org). +Examples of redis-py usage go here. They're being linked to the [generated documentation](https://redis.readthedocs.org). diff --git a/docs/examples/opentelemetry/README.md b/docs/examples/opentelemetry/README.md index 58085c9637..2bd1c0cafa 100644 --- a/docs/examples/opentelemetry/README.md +++ b/docs/examples/opentelemetry/README.md @@ -4,7 +4,7 @@ This example demonstrates how to monitor Redis using [OpenTelemetry](https://ope [Uptrace](https://github.com/uptrace/uptrace). It requires Docker to start Redis Server and Uptrace. See -[Monitoring redis-py performance with OpenTelemetry](https://redis-py.readthedocs.io/en/latest/opentelemetry.html) +[Monitoring redis-py performance with OpenTelemetry](https://redis.readthedocs.io/en/latest/opentelemetry.html) for details. **Step 1**. Download the example using Git: diff --git a/doctests/dt_time_series.py b/doctests/dt_time_series.py new file mode 100644 index 0000000000..98a2a923cf --- /dev/null +++ b/doctests/dt_time_series.py @@ -0,0 +1,517 @@ +# EXAMPLE: time_series_tutorial +# HIDE_START +""" +Code samples for time series page: + https://redis.io/docs/latest/develop/data-types/timeseries/ +""" + +import redis + +r = redis.Redis(decode_responses=True) +# HIDE_END + +# REMOVE_START +r.delete( + "thermometer:1", "thermometer:2", "thermometer:3", + "rg:1", "rg:2", "rg:3", "rg:4", + "sensor3", + "wind:1", "wind:2", "wind:3", "wind:4", + "hyg:1", "hyg:compacted" +) +# REMOVE_END + +# STEP_START create +res1 = r.ts().create("thermometer:1") +print(res1) # >>> True + +res2 = r.type("thermometer:1") +print(res2) # >>> TSDB-TYPE + +res3 = r.ts().info("thermometer:1") +print(res3) +# >>> {'rules': [], ... 'total_samples': 0, ... +# STEP_END +# REMOVE_START +assert res1 is True +assert res2 == "TSDB-TYPE" +assert res3["total_samples"] == 0 +# REMOVE_END + +# STEP_START create_retention +res4 = r.ts().add("thermometer:2", 1, 10.8, retention_msecs=100) +print(res4) # >>> 1 + +res5 = r.ts().info("thermometer:2") +print(res5) +# >>> {'rules': [], ... 'retention_msecs': 100, ... +# STEP_END +# REMOVE_START +assert res4 == 1 +assert res5["retention_msecs"] == 100 +# REMOVE_END + +# STEP_START create_labels +res6 = r.ts().create( + "thermometer:3", 1, 10.4, + labels={"location": "UK", "type": "Mercury"} +) +print(res6) # >>> 1 + +res7 = r.ts().info("thermometer:3") +print(res7) +# >>> {'rules': [], ... 'labels': {'location': 'UK', 'type': 'Mercury'}, ... 
+# STEP_END +# REMOVE_START +assert res6 == 1 +assert res7["labels"] == {"location": "UK", "type": "Mercury"} +# REMOVE_END + +# STEP_START madd +res8 = r.ts().madd([ + ("thermometer:1", 1, 9.2), + ("thermometer:1", 2, 9.9), + ("thermometer:2", 2, 10.3) +]) +print(res8) # >>> [1, 2, 2] +# STEP_END +# REMOVE_START +assert res8 == [1, 2, 2] +# REMOVE_END + +# STEP_START get +# The last recorded temperature for thermometer:2 +# was 10.3 at time 2. +res9 = r.ts().get("thermometer:2") +print(res9) # >>> (2, 10.3) +# STEP_END +# REMOVE_START +assert res9 == (2, 10.3) +# REMOVE_END + +# STEP_START range +# Add 5 data points to a time series named "rg:1". +res10 = r.ts().create("rg:1") +print(res10) # >>> True + +res11 = r.ts().madd([ + ("rg:1", 0, 18), + ("rg:1", 1, 14), + ("rg:1", 2, 22), + ("rg:1", 3, 18), + ("rg:1", 4, 24), +]) +print(res11) # >>> [0, 1, 2, 3, 4] + +# Retrieve all the data points in ascending order. +res12 = r.ts().range("rg:1", "-", "+") +print(res12) # >>> [(0, 18.0), (1, 14.0), (2, 22.0), (3, 18.0), (4, 24.0)] + +# Retrieve data points up to time 1 (inclusive). +res13 = r.ts().range("rg:1", "-", 1) +print(res13) # >>> [(0, 18.0), (1, 14.0)] + +# Retrieve data points from time 3 onwards. +res14 = r.ts().range("rg:1", 3, "+") +print(res14) # >>> [(3, 18.0), (4, 24.0)] + +# Retrieve all the data points in descending order. +res15 = r.ts().revrange("rg:1", "-", "+") +print(res15) # >>> [(4, 24.0), (3, 18.0), (2, 22.0), (1, 14.0), (0, 18.0)] + +# Retrieve data points up to time 1 (inclusive), but return them +# in descending order. +res16 = r.ts().revrange("rg:1", "-", 1) +print(res16) # >>> [(1, 14.0), (0, 18.0)] +# STEP_END +# REMOVE_START +assert res10 is True +assert res11 == [0, 1, 2, 3, 4] +assert res12 == [(0, 18.0), (1, 14.0), (2, 22.0), (3, 18.0), (4, 24.0)] +assert res13 == [(0, 18.0), (1, 14.0)] +assert res14 == [(3, 18.0), (4, 24.0)] +assert res15 == [(4, 24.0), (3, 18.0), (2, 22.0), (1, 14.0), (0, 18.0)] +assert res16 == [(1, 14.0), (0, 18.0)] +# REMOVE_END + +# STEP_START range_filter +res17 = r.ts().range("rg:1", "-", "+", filter_by_ts=[0, 2, 4]) +print(res17) # >>> [(0, 18.0), (2, 22.0), (4, 24.0)] + +res18 = r.ts().revrange( + "rg:1", "-", "+", + filter_by_ts=[0, 2, 4], + filter_by_min_value=20, + filter_by_max_value=25, +) +print(res18) # >>> [(4, 24.0), (2, 22.0)] + +res19 = r.ts().revrange( + "rg:1", "-", "+", + filter_by_ts=[0, 2, 4], + filter_by_min_value=22, + filter_by_max_value=22, + count=1, +) +print(res19) # >>> [(2, 22.0)] +# STEP_END +# REMOVE_START +assert res17 == [(0, 18.0), (2, 22.0), (4, 24.0)] +assert res18 == [(4, 24.0), (2, 22.0)] +assert res19 == [(2, 22.0)] +# REMOVE_END + +# STEP_START query_multi +# Create three new "rg:" time series (two in the US +# and one in the UK, with different units) and add some +# data points. 
+res20 = r.ts().create( + "rg:2", + labels={"location": "us", "unit": "cm"}, +) +print(res20) # >>> True + +res21 = r.ts().create( + "rg:3", + labels={"location": "us", "unit": "in"}, +) +print(res21) # >>> True + +res22 = r.ts().create( + "rg:4", + labels={"location": "uk", "unit": "mm"}, +) +print(res22) # >>> True + +res23 = r.ts().madd([ + ("rg:2", 0, 1.8), + ("rg:3", 0, 0.9), + ("rg:4", 0, 25), +]) +print(res23) # >>> [0, 0, 0] + +res24 = r.ts().madd([ + ("rg:2", 1, 2.1), + ("rg:3", 1, 0.77), + ("rg:4", 1, 18), +]) +print(res24) # >>> [1, 1, 1] + +res25 = r.ts().madd([ + ("rg:2", 2, 2.3), + ("rg:3", 2, 1.1), + ("rg:4", 2, 21), +]) +print(res25) # >>> [2, 2, 2] + +res26 = r.ts().madd([ + ("rg:2", 3, 1.9), + ("rg:3", 3, 0.81), + ("rg:4", 3, 19), +]) +print(res26) # >>> [3, 3, 3] + +res27 = r.ts().madd([ + ("rg:2", 4, 1.78), + ("rg:3", 4, 0.74), + ("rg:4", 4, 23), +]) +print(res27) # >>> [4, 4, 4] + +# Retrieve the last data point from each US time series. If +# you don't specify any labels, an empty array is returned +# for the labels. +res28 = r.ts().mget(["location=us"]) +print(res28) # >>> [{'rg:2': [{}, 4, 1.78]}, {'rg:3': [{}, 4, 0.74]}] + +# Retrieve the same data points, but include the `unit` +# label in the results. +res29 = r.ts().mget(["location=us"], select_labels=["unit"]) +print(res29) # >>> [{'unit': 'cm'}, (4, 1.78), {'unit': 'in'}, (4, 0.74)] + +# Retrieve data points up to time 2 (inclusive) from all +# time series that use millimeters as the unit. Include all +# labels in the results. +res30 = r.ts().mrange( + "-", 2, filters=["unit=mm"], with_labels=True +) +print(res30) +# >>> [{'rg:4': [{'location': 'uk', 'unit': 'mm'}, [(0, 25.4),... + +# Retrieve data points from time 1 to time 3 (inclusive) from +# all time series that use centimeters or millimeters as the unit, +# but only return the `location` label. Return the results +# in descending order of timestamp. +res31 = r.ts().mrevrange( + 1, 3, filters=["unit=(cm,mm)"], select_labels=["location"] +) +print(res31) +# >>> [[{'location': 'uk'}, (3, 19.0), (2, 21.0), (1, 18.0)],... 
+# STEP_END +# REMOVE_START +assert res20 is True +assert res21 is True +assert res22 is True +assert res23 == [0, 0, 0] +assert res24 == [1, 1, 1] +assert res25 == [2, 2, 2] +assert res26 == [3, 3, 3] +assert res27 == [4, 4, 4] +assert res28 == [{'rg:2': [{}, 4, 1.78]}, {'rg:3': [{}, 4, 0.74]}] +assert res29 == [ + {'rg:2': [{'unit': 'cm'}, 4, 1.78]}, + {'rg:3': [{'unit': 'in'}, 4, 0.74]} +] +assert res30 == [ + { + 'rg:4': [ + {'location': 'uk', 'unit': 'mm'}, + [(0, 25), (1, 18.0), (2, 21.0)] + ] + } +] +assert res31 == [ + {'rg:2': [{'location': 'us'}, [(3, 1.9), (2, 2.3), (1, 2.1)]]}, + {'rg:4': [{'location': 'uk'}, [(3, 19.0), (2, 21.0), (1, 18.0)]]} +] +# REMOVE_END + +# STEP_START agg +res32 = r.ts().range( + "rg:2", "-", "+", + aggregation_type="avg", + bucket_size_msec=2 +) +print(res32) +# >>> [(0, 1.9500000000000002), (2, 2.0999999999999996), (4, 1.78)] +# STEP_END +# REMOVE_START +assert res32 == [ + (0, 1.9500000000000002), (2, 2.0999999999999996), + (4, 1.78) +] +# REMOVE_END + +# STEP_START agg_bucket +res33 = r.ts().create("sensor3") +print(res33) # >>> True + +res34 = r.ts().madd([ + ("sensor3", 10, 1000), + ("sensor3", 20, 2000), + ("sensor3", 30, 3000), + ("sensor3", 40, 4000), + ("sensor3", 50, 5000), + ("sensor3", 60, 6000), + ("sensor3", 70, 7000), +]) +print(res34) # >>> [10, 20, 30, 40, 50, 60, 70] + +res35 = r.ts().range( + "sensor3", 10, 70, + aggregation_type="min", + bucket_size_msec=25 +) +print(res35) +# >>> [(0, 1000.0), (25, 3000.0), (50, 5000.0)] +# STEP_END +# REMOVE_START +assert res33 is True +assert res34 == [10, 20, 30, 40, 50, 60, 70] +assert res35 == [(0, 1000.0), (25, 3000.0), (50, 5000.0)] +# REMOVE_END + +# STEP_START agg_align +res36 = r.ts().range( + "sensor3", 10, 70, + aggregation_type="min", + bucket_size_msec=25, + align="START" +) +print(res36) +# >>> [(10, 1000.0), (35, 4000.0), (60, 6000.0)] +# STEP_END +# REMOVE_START +assert res36 == [(10, 1000.0), (35, 4000.0), (60, 6000.0)] +# REMOVE_END + +# STEP_START agg_multi +res37 = r.ts().create( + "wind:1", + labels={"country": "uk"} +) +print(res37) # >>> True + +res38 = r.ts().create( + "wind:2", + labels={"country": "uk"} +) +print(res38) # >>> True + +res39 = r.ts().create( + "wind:3", + labels={"country": "us"} +) +print(res39) # >>> True + +res40 = r.ts().create( + "wind:4", + labels={"country": "us"} +) +print(res40) # >>> True + +res41 = r.ts().madd([ + ("wind:1", 1, 12), + ("wind:2", 1, 18), + ("wind:3", 1, 5), + ("wind:4", 1, 20), +]) +print(res41) # >>> [1, 1, 1, 1] + +res42 = r.ts().madd([ + ("wind:1", 2, 14), + ("wind:2", 2, 21), + ("wind:3", 2, 4), + ("wind:4", 2, 25), +]) +print(res42) # >>> [2, 2, 2, 2] + +res43 = r.ts().madd([ + ("wind:1", 3, 10), + ("wind:2", 3, 24), + ("wind:3", 3, 8), + ("wind:4", 3, 18), +]) +print(res43) # >>> [3, 3, 3, 3] + +# The result pairs contain the timestamp and the maximum sample value +# for the country at that timestamp. +res44 = r.ts().mrange( + "-", "+", + filters=["country=(us,uk)"], + groupby="country", + reduce="max" +) +print(res44) +# >>> [{'country=uk': [{}, [(1, 18.0), (2, 21.0), (3, 24.0)]]}, ... + +# The result pairs contain the timestamp and the average sample value +# for the country at that timestamp. +res45 = r.ts().mrange( + "-", "+", + filters=["country=(us,uk)"], + groupby="country", + reduce="avg" +) +print(res45) +# >>> [{'country=uk': [{}, [(1, 15.0), (2, 17.5), (3, 17.0)]]}, ... 
+# STEP_END +# REMOVE_START +assert res37 is True +assert res38 is True +assert res39 is True +assert res40 is True +assert res41 == [1, 1, 1, 1] +assert res42 == [2, 2, 2, 2] +assert res43 == [3, 3, 3, 3] +assert res44 == [ + {'country=uk': [{}, [(1, 18.0), (2, 21.0), (3, 24.0)]]}, + {'country=us': [{}, [(1, 20.0), (2, 25.0), (3, 18.0)]]} +] +assert res45 == [ + {'country=uk': [{}, [(1, 15.0), (2, 17.5), (3, 17.0)]]}, + {'country=us': [{}, [(1, 12.5), (2, 14.5), (3, 13.0)]]} +] +# REMOVE_END + +# STEP_START create_compaction +res45 = r.ts().create("hyg:1") +print(res45) # >>> True + +res46 = r.ts().create("hyg:compacted") +print(res46) # >>> True + +res47 = r.ts().createrule("hyg:1", "hyg:compacted", "min", 3) +print(res47) # >>> True + +res48 = r.ts().info("hyg:1") +print(res48.rules) +# >>> [['hyg:compacted', 3, 'MIN', 0]] + +res49 = r.ts().info("hyg:compacted") +print(res49.source_key) # >>> 'hyg:1' +# STEP_END +# REMOVE_START +assert res45 is True +assert res46 is True +assert res47 is True +assert res48.rules == [['hyg:compacted', 3, 'MIN', 0]] +assert res49.source_key == 'hyg:1' +# REMOVE_END + +# STEP_START comp_add +res50 = r.ts().madd([ + ("hyg:1", 0, 75), + ("hyg:1", 1, 77), + ("hyg:1", 2, 78), +]) +print(res50) # >>> [0, 1, 2] + +res51 = r.ts().range("hyg:compacted", "-", "+") +print(res51) # >>> [] + +res52 = r.ts().add("hyg:1", 3, 79) +print(res52) # >>> 3 + +res53 = r.ts().range("hyg:compacted", "-", "+") +print(res53) # >>> [(0, 75.0)] +# STEP_END +# REMOVE_START +assert res50 == [0, 1, 2] +assert res51 == [] +assert res52 == 3 +assert res53 == [(0, 75.0)] +# REMOVE_END + +# STEP_START del +res54 = r.ts().info("thermometer:1") +print(res54.total_samples) # >>> 2 +print(res54.first_timestamp) # >>> 1 +print(res54.last_timestamp) # >>> 2 + +res55 = r.ts().add("thermometer:1", 3, 9.7) +print(res55) # >>> 3 + +res56 = r.ts().info("thermometer:1") +print(res56.total_samples) # >>> 3 +print(res56.first_timestamp) # >>> 1 +print(res56.last_timestamp) # >>> 3 + +res57 = r.ts().delete("thermometer:1", 1, 2) +print(res57) # >>> 2 + +res58 = r.ts().info("thermometer:1") +print(res58.total_samples) # >>> 1 +print(res58.first_timestamp) # >>> 3 +print(res58.last_timestamp) # >>> 3 + +res59 = r.ts().delete("thermometer:1", 3, 3) +print(res59) # >>> 1 + +res60 = r.ts().info("thermometer:1") +print(res60.total_samples) # >>> 0 +# STEP_END +# REMOVE_START +assert res54.total_samples == 2 +assert res54.first_timestamp == 1 +assert res54.last_timestamp == 2 +assert res55 == 3 +assert res56.total_samples == 3 +assert res56.first_timestamp == 1 +assert res56.last_timestamp == 3 +assert res57 == 2 +assert res58.total_samples == 1 +assert res58.first_timestamp == 3 +assert res58.last_timestamp == 3 +assert res59 == 1 +assert res60.total_samples == 0 +# REMOVE_END diff --git a/doctests/home_prob_dts.py b/doctests/home_prob_dts.py new file mode 100644 index 0000000000..39d516242f --- /dev/null +++ b/doctests/home_prob_dts.py @@ -0,0 +1,232 @@ +# EXAMPLE: home_prob_dts +""" +Probabilistic data type examples: + https://redis.io/docs/latest/develop/connect/clients/python/redis-py/prob +""" + +# HIDE_START +import redis +r = redis.Redis(decode_responses=True) +# HIDE_END +# REMOVE_START +r.delete( + "recorded_users", "other_users", + "group:1", "group:2", "both_groups", + "items_sold", + "male_heights", "female_heights", "all_heights", + "top_3_songs" +) +# REMOVE_END + +# STEP_START bloom +res1 = r.bf().madd("recorded_users", "andy", "cameron", "david", "michelle") +print(res1) # >>> [1, 1, 1, 1] + 
+res2 = r.bf().exists("recorded_users", "cameron") +print(res2) # >>> 1 + +res3 = r.bf().exists("recorded_users", "kaitlyn") +print(res3) # >>> 0 +# STEP_END +# REMOVE_START +assert res1 == [1, 1, 1, 1] +assert res2 == 1 +assert res3 == 0 +# REMOVE_END + +# STEP_START cuckoo +res4 = r.cf().add("other_users", "paolo") +print(res4) # >>> 1 + +res5 = r.cf().add("other_users", "kaitlyn") +print(res5) # >>> 1 + +res6 = r.cf().add("other_users", "rachel") +print(res6) # >>> 1 + +res7 = r.cf().mexists("other_users", "paolo", "rachel", "andy") +print(res7) # >>> [1, 1, 0] + +res8 = r.cf().delete("other_users", "paolo") +print(res8) # >>> 1 + +res9 = r.cf().exists("other_users", "paolo") +print(res9) # >>> 0 +# STEP_END +# REMOVE_START +assert res4 == 1 +assert res5 == 1 +assert res6 == 1 +assert res7 == [1, 1, 0] +assert res8 == 1 +assert res9 == 0 +# REMOVE_END + +# STEP_START hyperloglog +res10 = r.pfadd("group:1", "andy", "cameron", "david") +print(res10) # >>> 1 + +res11 = r.pfcount("group:1") +print(res11) # >>> 3 + +res12 = r.pfadd("group:2", "kaitlyn", "michelle", "paolo", "rachel") +print(res12) # >>> 1 + +res13 = r.pfcount("group:2") +print(res13) # >>> 4 + +res14 = r.pfmerge("both_groups", "group:1", "group:2") +print(res14) # >>> True + +res15 = r.pfcount("both_groups") +print(res15) # >>> 7 +# STEP_END +# REMOVE_START +assert res10 == 1 +assert res11 == 3 +assert res12 == 1 +assert res13 == 4 +assert res14 +assert res15 == 7 +# REMOVE_END + +# STEP_START cms +# Specify that you want to keep the counts within 0.01 +# (1%) of the true value with a 0.005 (0.5%) chance +# of going outside this limit. +res16 = r.cms().initbyprob("items_sold", 0.01, 0.005) +print(res16) # >>> True + +# The parameters for `incrby()` are two lists. The count +# for each item in the first list is incremented by the +# value at the same index in the second list. +res17 = r.cms().incrby( + "items_sold", + ["bread", "tea", "coffee", "beer"], # Items sold + [300, 200, 200, 100] +) +print(res17) # >>> [300, 200, 200, 100] + +res18 = r.cms().incrby( + "items_sold", + ["bread", "coffee"], + [100, 150] +) +print(res18) # >>> [400, 350] + +res19 = r.cms().query("items_sold", "bread", "tea", "coffee", "beer") +print(res19) # >>> [400, 200, 350, 100] +# STEP_END +# REMOVE_START +assert res16 +assert res17 == [300, 200, 200, 100] +assert res18 == [400, 350] +assert res19 == [400, 200, 350, 100] +# REMOVE_END + +# STEP_START tdigest +res20 = r.tdigest().create("male_heights") +print(res20) # >>> True + +res21 = r.tdigest().add( + "male_heights", + [175.5, 181, 160.8, 152, 177, 196, 164] +) +print(res21) # >>> OK + +res22 = r.tdigest().min("male_heights") +print(res22) # >>> 152.0 + +res23 = r.tdigest().max("male_heights") +print(res23) # >>> 196.0 + +res24 = r.tdigest().quantile("male_heights", 0.75) +print(res24) # >>> 181 + +# Note that the CDF value for 181 is not exactly +# 0.75. Both values are estimates. 
+res25 = r.tdigest().cdf("male_heights", 181) +print(res25) # >>> [0.7857142857142857] + +res26 = r.tdigest().create("female_heights") +print(res26) # >>> True + +res27 = r.tdigest().add( + "female_heights", + [155.5, 161, 168.5, 170, 157.5, 163, 171] +) +print(res27) # >>> OK + +res28 = r.tdigest().quantile("female_heights", 0.75) +print(res28) # >>> [170] + +res29 = r.tdigest().merge( + "all_heights", 2, "male_heights", "female_heights" +) +print(res29) # >>> OK + +res30 = r.tdigest().quantile("all_heights", 0.75) +print(res30) # >>> [175.5] +# STEP_END +# REMOVE_START +assert res20 +assert res21 == "OK" +assert res22 == 152.0 +assert res23 == 196.0 +assert res24 == [181] +assert res25 == [0.7857142857142857] +assert res26 +assert res27 == "OK" +assert res28 == [170] +assert res29 == "OK" +assert res30 == [175.5] +# REMOVE_END + +# STEP_START topk +# The `reserve()` method creates the Top-K object with +# the given key. The parameters are the number of items +# in the ranking and values for `width`, `depth`, and +# `decay`, described in the Top-K reference page. +res31 = r.topk().reserve("top_3_songs", 3, 7, 8, 0.9) +print(res31) # >>> True + +# The parameters for `incrby()` are two lists. The count +# for each item in the first list is incremented by the +# value at the same index in the second list. +res32 = r.topk().incrby( + "top_3_songs", + [ + "Starfish Trooper", + "Only one more time", + "Rock me, Handel", + "How will anyone know?", + "Average lover", + "Road to everywhere" + ], + [ + 3000, + 1850, + 1325, + 3890, + 4098, + 770 + ] +) +print(res32) +# >>> [None, None, None, 'Rock me, Handel', 'Only one more time', None] + +res33 = r.topk().list("top_3_songs") +print(res33) +# >>> ['Average lover', 'How will anyone know?', 'Starfish Trooper'] + +res34 = r.topk().query( + "top_3_songs", "Starfish Trooper", "Road to everywhere" +) +print(res34) # >>> [1, 0] +# STEP_END +# REMOVE_START +assert res31 +assert res32 == [None, None, None, 'Rock me, Handel', 'Only one more time', None] +assert res33 == ['Average lover', 'How will anyone know?', 'Starfish Trooper'] +assert res34 == [1, 0] +# REMOVE_END diff --git a/redis/__init__.py b/redis/__init__.py index 2782e68fc2..3ac3168a3a 100644 --- a/redis/__init__.py +++ b/redis/__init__.py @@ -20,6 +20,7 @@ DataError, InvalidPipelineStack, InvalidResponse, + MaxConnectionsError, OutOfMemoryError, PubSubError, ReadOnlyError, @@ -46,7 +47,7 @@ def int_or_str(value): # This version is used when building the package for publishing -__version__ = "6.2.0" +__version__ = "6.3.0" VERSION = tuple(map(int_or_str, __version__.split("."))) @@ -66,6 +67,7 @@ def int_or_str(value): "default_backoff", "InvalidPipelineStack", "InvalidResponse", + "MaxConnectionsError", "OutOfMemoryError", "PubSubError", "ReadOnlyError", diff --git a/redis/_parsers/helpers.py b/redis/_parsers/helpers.py index 5468addf62..154dc66dfb 100644 --- a/redis/_parsers/helpers.py +++ b/redis/_parsers/helpers.py @@ -676,7 +676,8 @@ def parse_client_info(value): "omem", "tot-mem", }: - client_info[int_key] = int(client_info[int_key]) + if int_key in client_info: + client_info[int_key] = int(client_info[int_key]) return client_info diff --git a/redis/asyncio/cluster.py b/redis/asyncio/cluster.py index 1f34e03011..7e5404f6c8 100644 --- a/redis/asyncio/cluster.py +++ b/redis/asyncio/cluster.py @@ -814,7 +814,13 @@ async def _execute_command( moved = False return await target_node.execute_command(*args, **kwargs) - except (BusyLoadingError, MaxConnectionsError): + except BusyLoadingError: + 
raise + except MaxConnectionsError: + # MaxConnectionsError indicates client-side resource exhaustion + # (too many connections in the pool), not a node failure. + # Don't treat this as a node failure - just re-raise the error + # without reinitializing the cluster. raise except (ConnectionError, TimeoutError): # Connection retries are being handled in the node's @@ -2350,10 +2356,11 @@ async def reset(self): # watching something if self._transaction_connection: try: - # call this manually since our unwatch or - # immediate_execute_command methods can call reset() - await self._transaction_connection.send_command("UNWATCH") - await self._transaction_connection.read_response() + if self._watching: + # call this manually since our unwatch or + # immediate_execute_command methods can call reset() + await self._transaction_connection.send_command("UNWATCH") + await self._transaction_connection.read_response() # we can safely return the connection to the pool here since we're # sure we're no longer WATCHing anything self._transaction_node.release(self._transaction_connection) diff --git a/redis/asyncio/connection.py b/redis/asyncio/connection.py index 326daaa8f8..4efd868f6f 100644 --- a/redis/asyncio/connection.py +++ b/redis/asyncio/connection.py @@ -295,13 +295,18 @@ async def connect(self): """Connects to the Redis server if not already connected""" await self.connect_check_health(check_health=True) - async def connect_check_health(self, check_health: bool = True): + async def connect_check_health( + self, check_health: bool = True, retry_socket_connect: bool = True + ): if self.is_connected: return try: - await self.retry.call_with_retry( - lambda: self._connect(), lambda error: self.disconnect() - ) + if retry_socket_connect: + await self.retry.call_with_retry( + lambda: self._connect(), lambda error: self.disconnect() + ) + else: + await self._connect() except asyncio.CancelledError: raise # in 3.7 and earlier, this is an Exception, not BaseException except (socket.timeout, asyncio.TimeoutError): @@ -1037,6 +1042,7 @@ class ConnectionPool: By default, TCP connections are created unless ``connection_class`` is specified. Use :py:class:`~redis.UnixDomainSocketConnection` for unix sockets. + :py:class:`~redis.SSLConnection` can be used for SSL enabled connections. Any additional keyword arguments are passed to the constructor of ``connection_class``. 
@@ -1112,9 +1118,11 @@ def __init__( self._event_dispatcher = EventDispatcher() def __repr__(self): + conn_kwargs = ",".join([f"{k}={v}" for k, v in self.connection_kwargs.items()]) return ( f"<{self.__class__.__module__}.{self.__class__.__name__}" - f"({self.connection_class(**self.connection_kwargs)!r})>" + f"(<{self.connection_class.__module__}.{self.connection_class.__name__}" + f"({conn_kwargs})>)>" ) def reset(self): diff --git a/redis/asyncio/retry.py b/redis/asyncio/retry.py index a20f8b4849..98b2d9c6f8 100644 --- a/redis/asyncio/retry.py +++ b/redis/asyncio/retry.py @@ -2,18 +2,16 @@ from typing import TYPE_CHECKING, Any, Awaitable, Callable, Tuple, Type, TypeVar from redis.exceptions import ConnectionError, RedisError, TimeoutError - -if TYPE_CHECKING: - from redis.backoff import AbstractBackoff - +from redis.retry import AbstractRetry T = TypeVar("T") +if TYPE_CHECKING: + from redis.backoff import AbstractBackoff -class Retry: - """Retry a specific number of times after a failure""" - __slots__ = "_backoff", "_retries", "_supported_errors" +class Retry(AbstractRetry[RedisError]): + __hash__ = AbstractRetry.__hash__ def __init__( self, @@ -24,36 +22,17 @@ def __init__( TimeoutError, ), ): - """ - Initialize a `Retry` object with a `Backoff` object - that retries a maximum of `retries` times. - `retries` can be negative to retry forever. - You can specify the types of supported errors which trigger - a retry with the `supported_errors` parameter. - """ - self._backoff = backoff - self._retries = retries - self._supported_errors = supported_errors + super().__init__(backoff, retries, supported_errors) - def update_supported_errors(self, specified_errors: list): - """ - Updates the supported errors with the specified error types - """ - self._supported_errors = tuple( - set(self._supported_errors + tuple(specified_errors)) - ) - - def get_retries(self) -> int: - """ - Get the number of retries. - """ - return self._retries + def __eq__(self, other: Any) -> bool: + if not isinstance(other, Retry): + return NotImplemented - def update_retries(self, value: int) -> None: - """ - Set the number of retries. 
- """ - self._retries = value + return ( + self._backoff == other._backoff + and self._retries == other._retries + and set(self._supported_errors) == set(other._supported_errors) + ) async def call_with_retry( self, do: Callable[[], Awaitable[T]], fail: Callable[[RedisError], Any] diff --git a/redis/asyncio/sentinel.py b/redis/asyncio/sentinel.py index fae6875d82..d0455ab6eb 100644 --- a/redis/asyncio/sentinel.py +++ b/redis/asyncio/sentinel.py @@ -11,8 +11,12 @@ SSLConnection, ) from redis.commands import AsyncSentinelCommands -from redis.exceptions import ConnectionError, ReadOnlyError, ResponseError, TimeoutError -from redis.utils import str_if_bytes +from redis.exceptions import ( + ConnectionError, + ReadOnlyError, + ResponseError, + TimeoutError, +) class MasterNotFoundError(ConnectionError): @@ -37,11 +41,10 @@ def __repr__(self): async def connect_to(self, address): self.host, self.port = address - await super().connect() - if self.connection_pool.check_connection: - await self.send_command("PING") - if str_if_bytes(await self.read_response()) != "PONG": - raise ConnectionError("PING failed") + await self.connect_check_health( + check_health=self.connection_pool.check_connection, + retry_socket_connect=False, + ) async def _connect_retry(self): if self._reader: @@ -223,19 +226,31 @@ async def execute_command(self, *args, **kwargs): once - If set to True, then execute the resulting command on a single node at random, rather than across the entire sentinel cluster. """ - once = bool(kwargs.get("once", False)) - if "once" in kwargs.keys(): - kwargs.pop("once") + once = bool(kwargs.pop("once", False)) + + # Check if command is supposed to return the original + # responses instead of boolean value. + return_responses = bool(kwargs.pop("return_responses", False)) if once: - await random.choice(self.sentinels).execute_command(*args, **kwargs) - else: - tasks = [ - asyncio.Task(sentinel.execute_command(*args, **kwargs)) - for sentinel in self.sentinels - ] - await asyncio.gather(*tasks) - return True + response = await random.choice(self.sentinels).execute_command( + *args, **kwargs + ) + if return_responses: + return [response] + else: + return True if response else False + + tasks = [ + asyncio.Task(sentinel.execute_command(*args, **kwargs)) + for sentinel in self.sentinels + ] + responses = await asyncio.gather(*tasks) + + if return_responses: + return responses + + return all(responses) def __repr__(self): sentinel_addresses = [] diff --git a/redis/backoff.py b/redis/backoff.py index 22a3ed0abb..6e1f68a7ba 100644 --- a/redis/backoff.py +++ b/redis/backoff.py @@ -170,7 +170,7 @@ def __hash__(self) -> int: return hash((self._base, self._cap)) def __eq__(self, other) -> bool: - if not isinstance(other, EqualJitterBackoff): + if not isinstance(other, ExponentialWithJitterBackoff): return NotImplemented return self._base == other._base and self._cap == other._cap diff --git a/redis/client.py b/redis/client.py index ea7f3d84de..28e9a82f76 100755 --- a/redis/client.py +++ b/redis/client.py @@ -450,7 +450,7 @@ def pipeline(self, transaction=True, shard_hint=None) -> "Pipeline": def transaction( self, func: Callable[["Pipeline"], None], *watches, **kwargs - ) -> None: + ) -> Union[List[Any], Any, None]: """ Convenience method for executing the callable `func` as a transaction while watching all keys specified in `watches`. 
The 'func' callable diff --git a/redis/cluster.py b/redis/cluster.py index baa85ae122..4b971cf86d 100644 --- a/redis/cluster.py +++ b/redis/cluster.py @@ -39,6 +39,7 @@ DataError, ExecAbortError, InvalidPipelineStack, + MaxConnectionsError, MovedError, RedisClusterException, RedisError, @@ -856,7 +857,6 @@ def pipeline(self, transaction=None, shard_hint=None): startup_nodes=self.nodes_manager.startup_nodes, result_callbacks=self.result_callbacks, cluster_response_callbacks=self.cluster_response_callbacks, - cluster_error_retry_attempts=self.retry.get_retries(), read_from_replicas=self.read_from_replicas, load_balancing_strategy=self.load_balancing_strategy, reinitialize_steps=self.reinitialize_steps, @@ -1236,6 +1236,12 @@ def _execute_command(self, target_node, *args, **kwargs): return response except AuthenticationError: raise + except MaxConnectionsError: + # MaxConnectionsError indicates client-side resource exhaustion + # (too many connections in the pool), not a node failure. + # Don't treat this as a node failure - just re-raise the error + # without reinitializing the cluster. + raise except (ConnectionError, TimeoutError) as e: # ConnectionError can also be raised if we couldn't get a # connection from the pool before timing out, so check that @@ -3290,10 +3296,11 @@ def reset(self): # watching something if self._transaction_connection: try: - # call this manually since our unwatch or - # immediate_execute_command methods can call reset() - self._transaction_connection.send_command("UNWATCH") - self._transaction_connection.read_response() + if self._watching: + # call this manually since our unwatch or + # immediate_execute_command methods can call reset() + self._transaction_connection.send_command("UNWATCH") + self._transaction_connection.read_response() # we can safely return the connection to the pool here since we're # sure we're no longer WATCHing anything node = self._nodes_manager.find_connection_owner( diff --git a/redis/commands/core.py b/redis/commands/core.py index 378898272f..d6fb550724 100644 --- a/redis/commands/core.py +++ b/redis/commands/core.py @@ -3290,7 +3290,7 @@ class SetCommands(CommandsProtocol): see: https://redis.io/topics/data-types#sets """ - def sadd(self, name: str, *values: FieldT) -> Union[Awaitable[int], int]: + def sadd(self, name: KeyT, *values: FieldT) -> Union[Awaitable[int], int]: """ Add ``value(s)`` to set ``name`` @@ -3298,7 +3298,7 @@ def sadd(self, name: str, *values: FieldT) -> Union[Awaitable[int], int]: """ return self.execute_command("SADD", name, *values) - def scard(self, name: str) -> Union[Awaitable[int], int]: + def scard(self, name: KeyT) -> Union[Awaitable[int], int]: """ Return the number of elements in set ``name`` @@ -3337,7 +3337,7 @@ def sinter(self, keys: List, *args: List) -> Union[Awaitable[list], list]: return self.execute_command("SINTER", *args, keys=args) def sintercard( - self, numkeys: int, keys: List[str], limit: int = 0 + self, numkeys: int, keys: List[KeyT], limit: int = 0 ) -> Union[Awaitable[int], int]: """ Return the cardinality of the intersect of multiple sets specified by ``keys``. 
@@ -3352,7 +3352,7 @@ def sintercard( return self.execute_command("SINTERCARD", *args, keys=keys) def sinterstore( - self, dest: str, keys: List, *args: List + self, dest: KeyT, keys: List, *args: List ) -> Union[Awaitable[int], int]: """ Store the intersection of sets specified by ``keys`` into a new @@ -3364,7 +3364,7 @@ def sinterstore( return self.execute_command("SINTERSTORE", dest, *args) def sismember( - self, name: str, value: str + self, name: KeyT, value: str ) -> Union[Awaitable[Union[Literal[0], Literal[1]]], Union[Literal[0], Literal[1]]]: """ Return whether ``value`` is a member of set ``name``: @@ -3375,7 +3375,7 @@ def sismember( """ return self.execute_command("SISMEMBER", name, value, keys=[name]) - def smembers(self, name: str) -> Union[Awaitable[Set], Set]: + def smembers(self, name: KeyT) -> Union[Awaitable[Set], Set]: """ Return all members of the set ``name`` @@ -3384,7 +3384,7 @@ def smembers(self, name: str) -> Union[Awaitable[Set], Set]: return self.execute_command("SMEMBERS", name, keys=[name]) def smismember( - self, name: str, values: List, *args: List + self, name: KeyT, values: List, *args: List ) -> Union[ Awaitable[List[Union[Literal[0], Literal[1]]]], List[Union[Literal[0], Literal[1]]], @@ -3400,7 +3400,7 @@ def smismember( args = list_or_args(values, args) return self.execute_command("SMISMEMBER", name, *args, keys=[name]) - def smove(self, src: str, dst: str, value: str) -> Union[Awaitable[bool], bool]: + def smove(self, src: KeyT, dst: KeyT, value: str) -> Union[Awaitable[bool], bool]: """ Move ``value`` from set ``src`` to set ``dst`` atomically @@ -3408,7 +3408,7 @@ def smove(self, src: str, dst: str, value: str) -> Union[Awaitable[bool], bool]: """ return self.execute_command("SMOVE", src, dst, value) - def spop(self, name: str, count: Optional[int] = None) -> Union[str, List, None]: + def spop(self, name: KeyT, count: Optional[int] = None) -> Union[str, List, None]: """ Remove and return a random member of set ``name`` @@ -3418,7 +3418,7 @@ def spop(self, name: str, count: Optional[int] = None) -> Union[str, List, None] return self.execute_command("SPOP", name, *args) def srandmember( - self, name: str, number: Optional[int] = None + self, name: KeyT, number: Optional[int] = None ) -> Union[str, List, None]: """ If ``number`` is None, returns a random member of set ``name``. @@ -3432,7 +3432,7 @@ def srandmember( args = (number is not None) and [number] or [] return self.execute_command("SRANDMEMBER", name, *args) - def srem(self, name: str, *values: FieldT) -> Union[Awaitable[int], int]: + def srem(self, name: KeyT, *values: FieldT) -> Union[Awaitable[int], int]: """ Remove ``values`` from set ``name`` @@ -3450,7 +3450,7 @@ def sunion(self, keys: List, *args: List) -> Union[Awaitable[List], List]: return self.execute_command("SUNION", *args, keys=args) def sunionstore( - self, dest: str, keys: List, *args: List + self, dest: KeyT, keys: List, *args: List ) -> Union[Awaitable[int], int]: """ Store the union of sets specified by ``keys`` into a new @@ -3484,6 +3484,28 @@ def xack(self, name: KeyT, groupname: GroupT, *ids: StreamIdT) -> ResponseT: """ return self.execute_command("XACK", name, groupname, *ids) + def xackdel( + self, + name: KeyT, + groupname: GroupT, + *ids: StreamIdT, + ref_policy: Literal["KEEPREF", "DELREF", "ACKED"] = "KEEPREF", + ) -> ResponseT: + """ + Combines the functionality of XACK and XDEL. 
Acknowledges the specified + message IDs in the given consumer group and simultaneously attempts to + delete the corresponding entries from the stream. + """ + if not ids: + raise DataError("XACKDEL requires at least one message ID") + + if ref_policy not in {"KEEPREF", "DELREF", "ACKED"}: + raise DataError("XACKDEL ref_policy must be one of: KEEPREF, DELREF, ACKED") + + pieces = [name, groupname, ref_policy, "IDS", len(ids)] + pieces.extend(ids) + return self.execute_command("XACKDEL", *pieces) + def xadd( self, name: KeyT, @@ -3494,6 +3516,7 @@ def xadd( nomkstream: bool = False, minid: Union[StreamIdT, None] = None, limit: Optional[int] = None, + ref_policy: Optional[Literal["KEEPREF", "DELREF", "ACKED"]] = None, ) -> ResponseT: """ Add to a stream. @@ -3507,6 +3530,10 @@ def xadd( minid: the minimum id in the stream to query. Can't be specified with maxlen. limit: specifies the maximum number of entries to retrieve + ref_policy: optional reference policy for consumer groups when trimming: + - KEEPREF (default): When trimming, preserves references in consumer groups' PEL + - DELREF: When trimming, removes all references from consumer groups' PEL + - ACKED: When trimming, only removes entries acknowledged by all consumer groups For more information see https://redis.io/commands/xadd """ @@ -3514,6 +3541,9 @@ def xadd( if maxlen is not None and minid is not None: raise DataError("Only one of ```maxlen``` or ```minid``` may be specified") + if ref_policy is not None and ref_policy not in {"KEEPREF", "DELREF", "ACKED"}: + raise DataError("XADD ref_policy must be one of: KEEPREF, DELREF, ACKED") + if maxlen is not None: if not isinstance(maxlen, int) or maxlen < 0: raise DataError("XADD maxlen must be non-negative integer") @@ -3530,6 +3560,8 @@ def xadd( pieces.extend([b"LIMIT", limit]) if nomkstream: pieces.append(b"NOMKSTREAM") + if ref_policy is not None: + pieces.append(ref_policy) pieces.append(id) if not isinstance(fields, dict) or len(fields) == 0: raise DataError("XADD fields must be a non-empty dict") @@ -3683,6 +3715,26 @@ def xdel(self, name: KeyT, *ids: StreamIdT) -> ResponseT: """ return self.execute_command("XDEL", name, *ids) + def xdelex( + self, + name: KeyT, + *ids: StreamIdT, + ref_policy: Literal["KEEPREF", "DELREF", "ACKED"] = "KEEPREF", + ) -> ResponseT: + """ + Extended version of XDEL that provides more control over how message entries + are deleted concerning consumer groups. + """ + if not ids: + raise DataError("XDELEX requires at least one message ID") + + if ref_policy not in {"KEEPREF", "DELREF", "ACKED"}: + raise DataError("XDELEX ref_policy must be one of: KEEPREF, DELREF, ACKED") + + pieces = [name, ref_policy, "IDS", len(ids)] + pieces.extend(ids) + return self.execute_command("XDELEX", *pieces) + def xgroup_create( self, name: KeyT, @@ -4034,6 +4086,7 @@ def xtrim( approximate: bool = True, minid: Union[StreamIdT, None] = None, limit: Optional[int] = None, + ref_policy: Optional[Literal["KEEPREF", "DELREF", "ACKED"]] = None, ) -> ResponseT: """ Trims old messages from a stream. @@ -4044,6 +4097,10 @@ def xtrim( minid: the minimum id in the stream to query Can't be specified with maxlen. 
limit: specifies the maximum number of entries to retrieve + ref_policy: optional reference policy for consumer groups: + - KEEPREF (default): Trims entries but preserves references in consumer groups' PEL + - DELREF: Trims entries and removes all references from consumer groups' PEL + - ACKED: Only trims entries that were read and acknowledged by all consumer groups For more information see https://redis.io/commands/xtrim """ @@ -4054,6 +4111,9 @@ def xtrim( if maxlen is None and minid is None: raise DataError("One of ``maxlen`` or ``minid`` must be specified") + if ref_policy is not None and ref_policy not in {"KEEPREF", "DELREF", "ACKED"}: + raise DataError("XTRIM ref_policy must be one of: KEEPREF, DELREF, ACKED") + if maxlen is not None: pieces.append(b"MAXLEN") if minid is not None: @@ -4067,6 +4127,8 @@ def xtrim( if limit is not None: pieces.append(b"LIMIT") pieces.append(limit) + if ref_policy is not None: + pieces.append(ref_policy) return self.execute_command("XTRIM", name, *pieces) diff --git a/redis/commands/search/field.py b/redis/commands/search/field.py index 8af7777f19..45cd403e49 100644 --- a/redis/commands/search/field.py +++ b/redis/commands/search/field.py @@ -181,7 +181,7 @@ def __init__(self, name: str, algorithm: str, attributes: dict, **kwargs): ``name`` is the name of the field. - ``algorithm`` can be "FLAT" or "HNSW". + ``algorithm`` can be "FLAT", "HNSW", or "SVS-VAMANA". ``attributes`` each algorithm can have specific attributes. Some of them are mandatory and some of them are optional. See @@ -194,10 +194,10 @@ def __init__(self, name: str, algorithm: str, attributes: dict, **kwargs): if sort or noindex: raise DataError("Cannot set 'sortable' or 'no_index' in Vector fields.") - if algorithm.upper() not in ["FLAT", "HNSW"]: + if algorithm.upper() not in ["FLAT", "HNSW", "SVS-VAMANA"]: raise DataError( - "Realtime vector indexing supporting 2 Indexing Methods:" - "'FLAT' and 'HNSW'." + "Realtime vector indexing supporting 3 Indexing Methods:" + "'FLAT', 'HNSW', and 'SVS-VAMANA'." ) attr_li = [] diff --git a/redis/commands/sentinel.py b/redis/commands/sentinel.py index f745757955..b2879b2015 100644 --- a/redis/commands/sentinel.py +++ b/redis/commands/sentinel.py @@ -11,16 +11,35 @@ def sentinel(self, *args): """Redis Sentinel's SENTINEL command.""" warnings.warn(DeprecationWarning("Use the individual sentinel_* methods")) - def sentinel_get_master_addr_by_name(self, service_name): - """Returns a (host, port) pair for the given ``service_name``""" - return self.execute_command("SENTINEL GET-MASTER-ADDR-BY-NAME", service_name) - - def sentinel_master(self, service_name): - """Returns a dictionary containing the specified masters state.""" - return self.execute_command("SENTINEL MASTER", service_name) + def sentinel_get_master_addr_by_name(self, service_name, return_responses=False): + """ + Returns a (host, port) pair for the given ``service_name`` when return_responses is True, + otherwise returns a boolean value that indicates if the command was successful. + """ + return self.execute_command( + "SENTINEL GET-MASTER-ADDR-BY-NAME", + service_name, + once=True, + return_responses=return_responses, + ) + + def sentinel_master(self, service_name, return_responses=False): + """ + Returns a dictionary containing the specified masters state, when return_responses is True, + otherwise returns a boolean value that indicates if the command was successful. 
+ """ + return self.execute_command( + "SENTINEL MASTER", service_name, return_responses=return_responses + ) def sentinel_masters(self): - """Returns a list of dictionaries containing each master's state.""" + """ + Returns a list of dictionaries containing each master's state. + + Important: This function is called by the Sentinel implementation and is + called directly on the Redis standalone client for sentinels, + so it doesn't support the "once" and "return_responses" options. + """ return self.execute_command("SENTINEL MASTERS") def sentinel_monitor(self, name, ip, port, quorum): @@ -31,16 +50,27 @@ def sentinel_remove(self, name): """Remove a master from Sentinel's monitoring""" return self.execute_command("SENTINEL REMOVE", name) - def sentinel_sentinels(self, service_name): - """Returns a list of sentinels for ``service_name``""" - return self.execute_command("SENTINEL SENTINELS", service_name) + def sentinel_sentinels(self, service_name, return_responses=False): + """ + Returns a list of sentinels for ``service_name``, when return_responses is True, + otherwise returns a boolean value that indicates if the command was successful. + """ + return self.execute_command( + "SENTINEL SENTINELS", service_name, return_responses=return_responses + ) def sentinel_set(self, name, option, value): """Set Sentinel monitoring parameters for a given master""" return self.execute_command("SENTINEL SET", name, option, value) def sentinel_slaves(self, service_name): - """Returns a list of slaves for ``service_name``""" + """ + Returns a list of slaves for ``service_name`` + + Important: This function is called by the Sentinel implementation and is + called directly on the Redis standalone client for sentinels, + so it doesn't support the "once" and "return_responses" options. + """ return self.execute_command("SENTINEL SLAVES", service_name) def sentinel_reset(self, pattern): diff --git a/redis/connection.py b/redis/connection.py index e87e7976a1..a6fe9234a7 100644 --- a/redis/connection.py +++ b/redis/connection.py @@ -31,6 +31,7 @@ ChildDeadlockedError, ConnectionError, DataError, + MaxConnectionsError, RedisError, ResponseError, TimeoutError, @@ -378,13 +379,18 @@ def connect(self): "Connects to the Redis server if not already connected" self.connect_check_health(check_health=True) - def connect_check_health(self, check_health: bool = True): + def connect_check_health( + self, check_health: bool = True, retry_socket_connect: bool = True + ): if self._sock: return try: - sock = self.retry.call_with_retry( - lambda: self._connect(), lambda error: self.disconnect(error) - ) + if retry_socket_connect: + sock = self.retry.call_with_retry( + lambda: self._connect(), lambda error: self.disconnect(error) + ) + else: + sock = self._connect() except socket.timeout: raise TimeoutError("Timeout connecting to server") except OSError as e: @@ -1315,6 +1321,7 @@ class ConnectionPool: By default, TCP connections are created unless ``connection_class`` is specified. Use class:`.UnixDomainSocketConnection` for unix sockets. + :py:class:`~redis.SSLConnection` can be used for SSL enabled connections. Any additional keyword arguments are passed to the constructor of ``connection_class``. 
@@ -1432,10 +1439,12 @@ def __init__( self.reset() - def __repr__(self) -> (str, str): + def __repr__(self) -> str: + conn_kwargs = ",".join([f"{k}={v}" for k, v in self.connection_kwargs.items()]) return ( - f"<{type(self).__module__}.{type(self).__name__}" - f"({repr(self.connection_class(**self.connection_kwargs))})>" + f"<{self.__class__.__module__}.{self.__class__.__name__}" + f"(<{self.connection_class.__module__}.{self.connection_class.__name__}" + f"({conn_kwargs})>)>" ) def get_protocol(self): @@ -1560,7 +1569,7 @@ def get_encoder(self) -> Encoder: def make_connection(self) -> "ConnectionInterface": "Create a new connection" if self._created_connections >= self.max_connections: - raise ConnectionError("Too many connections") + raise MaxConnectionsError("Too many connections") self._created_connections += 1 if self.cache is not None: diff --git a/redis/exceptions.py b/redis/exceptions.py index a00ac65ac1..643444986b 100644 --- a/redis/exceptions.py +++ b/redis/exceptions.py @@ -220,7 +220,13 @@ class SlotNotCoveredError(RedisClusterException): pass -class MaxConnectionsError(ConnectionError): ... +class MaxConnectionsError(ConnectionError): + """ + Raised when a connection pool has reached its max_connections limit. + This indicates pool exhaustion rather than an actual connection failure. + """ + + pass class CrossSlotTransactionError(RedisClusterException): diff --git a/redis/retry.py b/redis/retry.py index c93f34e65f..75778635e8 100644 --- a/redis/retry.py +++ b/redis/retry.py @@ -1,27 +1,27 @@ +import abc import socket from time import sleep -from typing import TYPE_CHECKING, Any, Callable, Iterable, Tuple, Type, TypeVar +from typing import TYPE_CHECKING, Any, Callable, Generic, Iterable, Tuple, Type, TypeVar from redis.exceptions import ConnectionError, TimeoutError T = TypeVar("T") +E = TypeVar("E", bound=Exception, covariant=True) if TYPE_CHECKING: from redis.backoff import AbstractBackoff -class Retry: +class AbstractRetry(Generic[E], abc.ABC): """Retry a specific number of times after a failure""" + _supported_errors: Tuple[Type[E], ...] + def __init__( self, backoff: "AbstractBackoff", retries: int, - supported_errors: Tuple[Type[Exception], ...] = ( - ConnectionError, - TimeoutError, - socket.timeout, - ), + supported_errors: Tuple[Type[E], ...], ): """ Initialize a `Retry` object with a `Backoff` object @@ -34,22 +34,14 @@ def __init__( self._retries = retries self._supported_errors = supported_errors + @abc.abstractmethod def __eq__(self, other: Any) -> bool: - if not isinstance(other, Retry): - return NotImplemented - - return ( - self._backoff == other._backoff - and self._retries == other._retries - and set(self._supported_errors) == set(other._supported_errors) - ) + return NotImplemented def __hash__(self) -> int: return hash((self._backoff, self._retries, frozenset(self._supported_errors))) - def update_supported_errors( - self, specified_errors: Iterable[Type[Exception]] - ) -> None: + def update_supported_errors(self, specified_errors: Iterable[Type[E]]) -> None: """ Updates the supported errors with the specified error types """ @@ -69,6 +61,32 @@ def update_retries(self, value: int) -> None: """ self._retries = value + +class Retry(AbstractRetry[Exception]): + __hash__ = AbstractRetry.__hash__ + + def __init__( + self, + backoff: "AbstractBackoff", + retries: int, + supported_errors: Tuple[Type[Exception], ...] 
= ( + ConnectionError, + TimeoutError, + socket.timeout, + ), + ): + super().__init__(backoff, retries, supported_errors) + + def __eq__(self, other: Any) -> bool: + if not isinstance(other, Retry): + return NotImplemented + + return ( + self._backoff == other._backoff + and self._retries == other._retries + and set(self._supported_errors) == set(other._supported_errors) + ) + def call_with_retry( self, do: Callable[[], T], diff --git a/redis/sentinel.py b/redis/sentinel.py index 02aa244ede..f12bd8dd5d 100644 --- a/redis/sentinel.py +++ b/redis/sentinel.py @@ -5,8 +5,12 @@ from redis.client import Redis from redis.commands import SentinelCommands from redis.connection import Connection, ConnectionPool, SSLConnection -from redis.exceptions import ConnectionError, ReadOnlyError, ResponseError, TimeoutError -from redis.utils import str_if_bytes +from redis.exceptions import ( + ConnectionError, + ReadOnlyError, + ResponseError, + TimeoutError, +) class MasterNotFoundError(ConnectionError): @@ -35,11 +39,11 @@ def __repr__(self): def connect_to(self, address): self.host, self.port = address - super().connect() - if self.connection_pool.check_connection: - self.send_command("PING") - if str_if_bytes(self.read_response()) != "PONG": - raise ConnectionError("PING failed") + + self.connect_check_health( + check_health=self.connection_pool.check_connection, + retry_socket_connect=False, + ) def _connect_retry(self): if self._sock: @@ -254,16 +258,27 @@ def execute_command(self, *args, **kwargs): once - If set to True, then execute the resulting command on a single node at random, rather than across the entire sentinel cluster. """ - once = bool(kwargs.get("once", False)) - if "once" in kwargs.keys(): - kwargs.pop("once") + once = bool(kwargs.pop("once", False)) + + # Check if command is supposed to return the original + # responses instead of boolean value. + return_responses = bool(kwargs.pop("return_responses", False)) if once: - random.choice(self.sentinels).execute_command(*args, **kwargs) - else: - for sentinel in self.sentinels: - sentinel.execute_command(*args, **kwargs) - return True + response = random.choice(self.sentinels).execute_command(*args, **kwargs) + if return_responses: + return [response] + else: + return True if response else False + + responses = [] + for sentinel in self.sentinels: + responses.append(sentinel.execute_command(*args, **kwargs)) + + if return_responses: + return responses + + return all(responses) def __repr__(self): sentinel_addresses = [] diff --git a/redis/utils.py b/redis/utils.py index 715913e914..79c23c8bda 100644 --- a/redis/utils.py +++ b/redis/utils.py @@ -1,9 +1,10 @@ import datetime import logging import textwrap +from collections.abc import Callable from contextlib import contextmanager from functools import wraps -from typing import Any, Dict, List, Mapping, Optional, Union +from typing import Any, Dict, List, Mapping, Optional, TypeVar, Union from redis.exceptions import DataError from redis.typing import AbsExpiryT, EncodableT, ExpiryT @@ -150,18 +151,21 @@ def warn_deprecated_arg_usage( warnings.warn(msg, category=DeprecationWarning, stacklevel=stacklevel) +C = TypeVar("C", bound=Callable) + + def deprecated_args( args_to_warn: list = ["*"], allowed_args: list = [], reason: str = "", version: str = "", -): +) -> Callable[[C], C]: """ Decorator to mark specified args of a function as deprecated. If '*' is in args_to_warn, all arguments will be marked as deprecated. 
""" - def decorator(func): + def decorator(func: C) -> C: @wraps(func) def wrapper(*args, **kwargs): # Get function argument names diff --git a/tests/test_asyncio/compat.py b/tests/test_asyncio/compat.py index aa1dc49af0..97c62c53c3 100644 --- a/tests/test_asyncio/compat.py +++ b/tests/test_asyncio/compat.py @@ -1,10 +1,4 @@ import asyncio -from unittest import mock - -try: - mock.AsyncMock -except AttributeError: - from unittest import mock try: from contextlib import aclosing diff --git a/tests/test_asyncio/conftest.py b/tests/test_asyncio/conftest.py index 340d146ea3..7ebc2190df 100644 --- a/tests/test_asyncio/conftest.py +++ b/tests/test_asyncio/conftest.py @@ -1,6 +1,7 @@ import random from contextlib import asynccontextmanager as _asynccontextmanager from typing import Union +from unittest import mock import pytest import pytest_asyncio @@ -14,8 +15,6 @@ from redis.credentials import CredentialProvider from tests.conftest import REDIS_INFO, get_credential_provider -from .compat import mock - async def _get_info(redis_url): client = redis.Redis.from_url(redis_url) diff --git a/tests/test_asyncio/test_cluster.py b/tests/test_asyncio/test_cluster.py index 7f87131c7a..25f487fe4c 100644 --- a/tests/test_asyncio/test_cluster.py +++ b/tests/test_asyncio/test_cluster.py @@ -4,6 +4,7 @@ import ssl import warnings from typing import Any, Awaitable, Callable, Dict, List, Optional, Type, Union +from unittest import mock from urllib.parse import urlparse import pytest @@ -48,7 +49,7 @@ ) from ..ssl_utils import get_tls_certificates -from .compat import aclosing, mock +from .compat import aclosing pytestmark = pytest.mark.onlycluster diff --git a/tests/test_asyncio/test_commands.py b/tests/test_asyncio/test_commands.py index bfb6855a0f..9db7d200e6 100644 --- a/tests/test_asyncio/test_commands.py +++ b/tests/test_asyncio/test_commands.py @@ -879,6 +879,103 @@ async def test_bitop_string_operands(self, r: redis.Redis): assert int(binascii.hexlify(await r.get("res2")), 16) == 0x0102FFFF assert int(binascii.hexlify(await r.get("res3")), 16) == 0x000000FF + @pytest.mark.onlynoncluster + @skip_if_server_version_lt("8.1.224") + async def test_bitop_diff(self, r: redis.Redis): + await r.set("a", b"\xf0") + await r.set("b", b"\xc0") + await r.set("c", b"\x80") + + result = await r.bitop("DIFF", "result", "a", "b", "c") + assert result == 1 + assert await r.get("result") == b"\x30" + + await r.bitop("DIFF", "result2", "a", "nonexistent") + assert await r.get("result2") == b"\xf0" + + @pytest.mark.onlynoncluster + @skip_if_server_version_lt("8.1.224") + async def test_bitop_diff1(self, r: redis.Redis): + await r.set("a", b"\xf0") + await r.set("b", b"\xc0") + await r.set("c", b"\x80") + + result = await r.bitop("DIFF1", "result", "a", "b", "c") + assert result == 1 + assert await r.get("result") == b"\x00" + + await r.set("d", b"\x0f") + await r.set("e", b"\x03") + await r.bitop("DIFF1", "result2", "d", "e") + assert await r.get("result2") == b"\x00" + + @pytest.mark.onlynoncluster + @skip_if_server_version_lt("8.1.224") + async def test_bitop_andor(self, r: redis.Redis): + await r.set("a", b"\xf0") + await r.set("b", b"\xc0") + await r.set("c", b"\x80") + + result = await r.bitop("ANDOR", "result", "a", "b", "c") + assert result == 1 + assert await r.get("result") == b"\xc0" + + await r.set("x", b"\xf0") + await r.set("y", b"\x0f") + await r.bitop("ANDOR", "result2", "x", "y") + assert await r.get("result2") == b"\x00" + + @pytest.mark.onlynoncluster + @skip_if_server_version_lt("8.1.224") + async def 
test_bitop_one(self, r: redis.Redis): + await r.set("a", b"\xf0") + await r.set("b", b"\xc0") + await r.set("c", b"\x80") + + result = await r.bitop("ONE", "result", "a", "b", "c") + assert result == 1 + assert await r.get("result") == b"\x30" + + await r.set("x", b"\xf0") + await r.set("y", b"\x0f") + await r.bitop("ONE", "result2", "x", "y") + assert await r.get("result2") == b"\xff" + + @pytest.mark.onlynoncluster + @skip_if_server_version_lt("8.1.224") + async def test_bitop_new_operations_with_empty_keys(self, r: redis.Redis): + await r.set("a", b"\xff") + + await r.bitop("DIFF", "empty_result", "nonexistent", "a") + assert await r.get("empty_result") == b"\x00" + + await r.bitop("DIFF1", "empty_result2", "a", "nonexistent") + assert await r.get("empty_result2") == b"\x00" + + await r.bitop("ANDOR", "empty_result3", "a", "nonexistent") + assert await r.get("empty_result3") == b"\x00" + + await r.bitop("ONE", "empty_result4", "nonexistent") + assert await r.get("empty_result4") is None + + @pytest.mark.onlynoncluster + @skip_if_server_version_lt("8.1.224") + async def test_bitop_new_operations_return_values(self, r: redis.Redis): + await r.set("a", b"\xff\x00\xff") + await r.set("b", b"\x00\xff") + + result1 = await r.bitop("DIFF", "result1", "a", "b") + assert result1 == 3 + + result2 = await r.bitop("DIFF1", "result2", "a", "b") + assert result2 == 3 + + result3 = await r.bitop("ANDOR", "result3", "a", "b") + assert result3 == 3 + + result4 = await r.bitop("ONE", "result4", "a", "b") + assert result4 == 3 + @pytest.mark.onlynoncluster @skip_if_server_version_lt("2.8.7") async def test_bitpos(self, r: redis.Redis): @@ -3368,6 +3465,156 @@ async def test_xtrim(self, r: redis.Redis): # 1 message is trimmed assert await r.xtrim(stream, 3, approximate=False) == 1 + @skip_if_server_version_lt("8.1.224") + async def test_xdelex(self, r: redis.Redis): + stream = "stream" + + m1 = await r.xadd(stream, {"foo": "bar"}) + m2 = await r.xadd(stream, {"foo": "bar"}) + m3 = await r.xadd(stream, {"foo": "bar"}) + m4 = await r.xadd(stream, {"foo": "bar"}) + + # Test XDELEX with default ref_policy (KEEPREF) + result = await r.xdelex(stream, m1) + assert result == [1] + + # Test XDELEX with explicit KEEPREF + result = await r.xdelex(stream, m2, ref_policy="KEEPREF") + assert result == [1] + + # Test XDELEX with DELREF + result = await r.xdelex(stream, m3, ref_policy="DELREF") + assert result == [1] + + # Test XDELEX with ACKED + result = await r.xdelex(stream, m4, ref_policy="ACKED") + assert result == [1] + + # Test with non-existent ID + result = await r.xdelex(stream, "999999-0", ref_policy="KEEPREF") + assert result == [-1] + + # Test with multiple IDs + m5 = await r.xadd(stream, {"foo": "bar"}) + m6 = await r.xadd(stream, {"foo": "bar"}) + result = await r.xdelex(stream, m5, m6, ref_policy="KEEPREF") + assert result == [1, 1] + + # Test error cases + with pytest.raises(redis.DataError): + await r.xdelex(stream, "123-0", ref_policy="INVALID") + + with pytest.raises(redis.DataError): + await r.xdelex(stream) # No IDs provided + + @skip_if_server_version_lt("8.1.224") + async def test_xackdel(self, r: redis.Redis): + stream = "stream" + group = "group" + consumer = "consumer" + + m1 = await r.xadd(stream, {"foo": "bar"}) + m2 = await r.xadd(stream, {"foo": "bar"}) + m3 = await r.xadd(stream, {"foo": "bar"}) + m4 = await r.xadd(stream, {"foo": "bar"}) + await r.xgroup_create(stream, group, 0) + + await r.xreadgroup(group, consumer, streams={stream: ">"}) + + # Test XACKDEL with default ref_policy 
(KEEPREF) + result = await r.xackdel(stream, group, m1) + assert result == [1] + + # Test XACKDEL with explicit KEEPREF + result = await r.xackdel(stream, group, m2, ref_policy="KEEPREF") + assert result == [1] + + # Test XACKDEL with DELREF + result = await r.xackdel(stream, group, m3, ref_policy="DELREF") + assert result == [1] + + # Test XACKDEL with ACKED + result = await r.xackdel(stream, group, m4, ref_policy="ACKED") + assert result == [1] + + # Test with non-existent ID + result = await r.xackdel(stream, group, "999999-0", ref_policy="KEEPREF") + assert result == [-1] + + # Test error cases + with pytest.raises(redis.DataError): + await r.xackdel(stream, group, m1, ref_policy="INVALID") + + with pytest.raises(redis.DataError): + await r.xackdel(stream, group) # No IDs provided + + @skip_if_server_version_lt("8.1.224") + async def test_xtrim_with_options(self, r: redis.Redis): + stream = "stream" + + await r.xadd(stream, {"foo": "bar"}) + await r.xadd(stream, {"foo": "bar"}) + await r.xadd(stream, {"foo": "bar"}) + await r.xadd(stream, {"foo": "bar"}) + + # Test XTRIM with KEEPREF ref_policy + assert ( + await r.xtrim(stream, maxlen=2, approximate=False, ref_policy="KEEPREF") + == 2 + ) + + await r.xadd(stream, {"foo": "bar"}) + await r.xadd(stream, {"foo": "bar"}) + + # Test XTRIM with DELREF ref_policy + assert ( + await r.xtrim(stream, maxlen=2, approximate=False, ref_policy="DELREF") == 2 + ) + + await r.xadd(stream, {"foo": "bar"}) + await r.xadd(stream, {"foo": "bar"}) + + # Test XTRIM with ACKED ref_policy + assert ( + await r.xtrim(stream, maxlen=2, approximate=False, ref_policy="ACKED") == 2 + ) + + # Test error case + with pytest.raises(redis.DataError): + await r.xtrim(stream, maxlen=2, ref_policy="INVALID") + + @skip_if_server_version_lt("8.1.224") + async def test_xadd_with_options(self, r: redis.Redis): + stream = "stream" + + # Test XADD with KEEPREF ref_policy + await r.xadd( + stream, {"foo": "bar"}, maxlen=2, approximate=False, ref_policy="KEEPREF" + ) + await r.xadd( + stream, {"foo": "bar"}, maxlen=2, approximate=False, ref_policy="KEEPREF" + ) + await r.xadd( + stream, {"foo": "bar"}, maxlen=2, approximate=False, ref_policy="KEEPREF" + ) + assert await r.xlen(stream) == 2 + + # Test XADD with DELREF ref_policy + await r.xadd( + stream, {"foo": "bar"}, maxlen=2, approximate=False, ref_policy="DELREF" + ) + assert await r.xlen(stream) == 2 + + # Test XADD with ACKED ref_policy + await r.xadd( + stream, {"foo": "bar"}, maxlen=2, approximate=False, ref_policy="ACKED" + ) + assert await r.xlen(stream) == 2 + + # Test error case + with pytest.raises(redis.DataError): + await r.xadd(stream, {"foo": "bar"}, ref_policy="INVALID") + @pytest.mark.onlynoncluster async def test_bitfield_operations(self, r: redis.Redis): # comments show affected bits diff --git a/tests/test_asyncio/test_connection.py b/tests/test_asyncio/test_connection.py index 38764d30cd..35d404a36e 100644 --- a/tests/test_asyncio/test_connection.py +++ b/tests/test_asyncio/test_connection.py @@ -1,6 +1,7 @@ import asyncio import socket import types +from unittest import mock from errno import ECONNREFUSED from unittest.mock import patch @@ -25,7 +26,6 @@ from redis.utils import HIREDIS_AVAILABLE from tests.conftest import skip_if_server_version_lt -from .compat import mock from .mocks import MockStream diff --git a/tests/test_asyncio/test_connection_pool.py b/tests/test_asyncio/test_connection_pool.py index 09409e04a8..c30220fb1d 100644 --- a/tests/test_asyncio/test_connection_pool.py +++ 
b/tests/test_asyncio/test_connection_pool.py @@ -1,5 +1,6 @@ import asyncio import re +from unittest import mock import pytest import pytest_asyncio @@ -8,7 +9,7 @@ from redis.auth.token import TokenInterface from tests.conftest import skip_if_redis_enterprise, skip_if_server_version_lt -from .compat import aclosing, mock +from .compat import aclosing from .conftest import asynccontextmanager from .test_pubsub import wait_for_message @@ -294,13 +295,14 @@ def test_repr_contains_db_info_tcp(self): pool = redis.ConnectionPool( host="localhost", port=6379, client_name="test-client" ) - expected = "host=localhost,port=6379,db=0,client_name=test-client" + expected = "host=localhost,port=6379,client_name=test-client" assert expected in repr(pool) def test_repr_contains_db_info_unix(self): pool = redis.ConnectionPool( connection_class=redis.UnixDomainSocketConnection, path="abc", + db=0, client_name="test-client", ) expected = "path=abc,db=0,client_name=test-client" @@ -651,7 +653,7 @@ async def test_oom_error(self, r): await r.execute_command("DEBUG", "ERROR", "OOM blah blah") def test_connect_from_url_tcp(self): - connection = redis.Redis.from_url("redis://localhost") + connection = redis.Redis.from_url("redis://localhost:6379?db=0") pool = connection.connection_pool assert re.match( @@ -659,7 +661,7 @@ def test_connect_from_url_tcp(self): ).groups() == ( "ConnectionPool", "Connection", - "host=localhost,port=6379,db=0", + "db=0,host=localhost,port=6379", ) def test_connect_from_url_unix(self): @@ -671,7 +673,7 @@ def test_connect_from_url_unix(self): ).groups() == ( "ConnectionPool", "UnixDomainSocketConnection", - "path=/path/to/socket,db=0", + "path=/path/to/socket", ) @skip_if_redis_enterprise() diff --git a/tests/test_asyncio/test_pipeline.py b/tests/test_asyncio/test_pipeline.py index 19e11dc792..1f8c743b48 100644 --- a/tests/test_asyncio/test_pipeline.py +++ b/tests/test_asyncio/test_pipeline.py @@ -1,8 +1,10 @@ +from unittest import mock + import pytest import redis from tests.conftest import skip_if_server_version_lt -from .compat import aclosing, mock +from .compat import aclosing from .conftest import wait_for_command diff --git a/tests/test_asyncio/test_pubsub.py b/tests/test_asyncio/test_pubsub.py index d193cc9f2d..b281cb1281 100644 --- a/tests/test_asyncio/test_pubsub.py +++ b/tests/test_asyncio/test_pubsub.py @@ -12,6 +12,8 @@ else: from async_timeout import timeout as async_timeout +from unittest import mock + import pytest import pytest_asyncio import redis.asyncio as redis @@ -19,7 +21,7 @@ from redis.typing import EncodableT from tests.conftest import get_protocol_version, skip_if_server_version_lt -from .compat import aclosing, create_task, mock +from .compat import aclosing, create_task def with_timeout(t): diff --git a/tests/test_asyncio/test_search.py b/tests/test_asyncio/test_search.py index 932ece59b8..0004f9ba75 100644 --- a/tests/test_asyncio/test_search.py +++ b/tests/test_asyncio/test_search.py @@ -1815,3 +1815,181 @@ async def test_binary_and_text_fields(decoded_r: redis.Redis): assert docs[0]["first_name"] == mixed_data["first_name"], ( "The text field is not decoded correctly" ) + + +# SVS-VAMANA Async Tests +@pytest.mark.redismod +@skip_if_server_version_lt("8.1.224") +async def test_async_svs_vamana_basic_functionality(decoded_r: redis.Redis): + await decoded_r.ft().create_index( + ( + VectorField( + "v", + "SVS-VAMANA", + {"TYPE": "FLOAT32", "DIM": 4, "DISTANCE_METRIC": "L2"}, + ), + ) + ) + + vectors = [ + [1.0, 2.0, 3.0, 4.0], + [2.0, 3.0, 4.0, 5.0], + [3.0, 
4.0, 5.0, 6.0], + [10.0, 11.0, 12.0, 13.0], + ] + + for i, vec in enumerate(vectors): + await decoded_r.hset(f"doc{i}", "v", np.array(vec, dtype=np.float32).tobytes()) + + query = "*=>[KNN 3 @v $vec]" + q = Query(query).return_field("__v_score").sort_by("__v_score", True) + res = await decoded_r.ft().search( + q, query_params={"vec": np.array(vectors[0], dtype=np.float32).tobytes()} + ) + + if is_resp2_connection(decoded_r): + assert res.total == 3 + assert "doc0" == res.docs[0].id + else: + assert res["total_results"] == 3 + assert "doc0" == res["results"][0]["id"] + + +@pytest.mark.redismod +@skip_if_server_version_lt("8.1.224") +async def test_async_svs_vamana_distance_metrics(decoded_r: redis.Redis): + # Test COSINE distance + await decoded_r.ft().create_index( + ( + VectorField( + "v", + "SVS-VAMANA", + {"TYPE": "FLOAT32", "DIM": 3, "DISTANCE_METRIC": "COSINE"}, + ), + ) + ) + + vectors = [[1.0, 0.0, 0.0], [0.707, 0.707, 0.0], [0.0, 1.0, 0.0], [-1.0, 0.0, 0.0]] + + for i, vec in enumerate(vectors): + await decoded_r.hset(f"doc{i}", "v", np.array(vec, dtype=np.float32).tobytes()) + + query = Query("*=>[KNN 2 @v $vec as score]").sort_by("score").no_content() + query_params = {"vec": np.array(vectors[0], dtype=np.float32).tobytes()} + + res = await decoded_r.ft().search(query, query_params=query_params) + if is_resp2_connection(decoded_r): + assert res.total == 2 + assert "doc0" == res.docs[0].id + else: + assert res["total_results"] == 2 + assert "doc0" == res["results"][0]["id"] + + +@pytest.mark.redismod +@skip_if_server_version_lt("8.1.224") +async def test_async_svs_vamana_vector_types(decoded_r: redis.Redis): + # Test FLOAT16 + await decoded_r.ft("idx16").create_index( + ( + VectorField( + "v16", + "SVS-VAMANA", + {"TYPE": "FLOAT16", "DIM": 4, "DISTANCE_METRIC": "L2"}, + ), + ) + ) + + vectors = [[1.5, 2.5, 3.5, 4.5], [2.5, 3.5, 4.5, 5.5], [3.5, 4.5, 5.5, 6.5]] + + for i, vec in enumerate(vectors): + await decoded_r.hset( + f"doc16_{i}", "v16", np.array(vec, dtype=np.float16).tobytes() + ) + + query = Query("*=>[KNN 2 @v16 $vec as score]").no_content() + query_params = {"vec": np.array(vectors[0], dtype=np.float16).tobytes()} + + res = await decoded_r.ft("idx16").search(query, query_params=query_params) + if is_resp2_connection(decoded_r): + assert res.total == 2 + assert "doc16_0" == res.docs[0].id + else: + assert res["total_results"] == 2 + assert "doc16_0" == res["results"][0]["id"] + + +@pytest.mark.redismod +@skip_if_server_version_lt("8.1.224") +async def test_async_svs_vamana_compression(decoded_r: redis.Redis): + await decoded_r.ft().create_index( + ( + VectorField( + "v", + "SVS-VAMANA", + { + "TYPE": "FLOAT32", + "DIM": 8, + "DISTANCE_METRIC": "L2", + "COMPRESSION": "LVQ8", + "TRAINING_THRESHOLD": 1024, + }, + ), + ) + ) + + vectors = [] + for i in range(20): + vec = [float(i + j) for j in range(8)] + vectors.append(vec) + await decoded_r.hset(f"doc{i}", "v", np.array(vec, dtype=np.float32).tobytes()) + + query = Query("*=>[KNN 5 @v $vec as score]").no_content() + query_params = {"vec": np.array(vectors[0], dtype=np.float32).tobytes()} + + res = await decoded_r.ft().search(query, query_params=query_params) + if is_resp2_connection(decoded_r): + assert res.total == 5 + assert "doc0" == res.docs[0].id + else: + assert res["total_results"] == 5 + assert "doc0" == res["results"][0]["id"] + + +@pytest.mark.redismod +@skip_if_server_version_lt("8.1.224") +async def test_async_svs_vamana_build_parameters(decoded_r: redis.Redis): + await decoded_r.ft().create_index( + ( + 
VectorField( + "v", + "SVS-VAMANA", + { + "TYPE": "FLOAT32", + "DIM": 6, + "DISTANCE_METRIC": "COSINE", + "CONSTRUCTION_WINDOW_SIZE": 300, + "GRAPH_MAX_DEGREE": 64, + "SEARCH_WINDOW_SIZE": 20, + "EPSILON": 0.05, + }, + ), + ) + ) + + vectors = [] + for i in range(15): + vec = [float(i + j) for j in range(6)] + vectors.append(vec) + await decoded_r.hset(f"doc{i}", "v", np.array(vec, dtype=np.float32).tobytes()) + + query = Query("*=>[KNN 3 @v $vec as score]").no_content() + query_params = {"vec": np.array(vectors[0], dtype=np.float32).tobytes()} + + res = await decoded_r.ft().search(query, query_params=query_params) + if is_resp2_connection(decoded_r): + assert res.total == 3 + assert "doc0" == res.docs[0].id + else: + assert res["total_results"] == 3 + assert "doc0" == res["results"][0]["id"] diff --git a/tests/test_asyncio/test_sentinel.py b/tests/test_asyncio/test_sentinel.py index a27ba92bb8..867ff15405 100644 --- a/tests/test_asyncio/test_sentinel.py +++ b/tests/test_asyncio/test_sentinel.py @@ -13,7 +13,7 @@ ) -@pytest_asyncio.fixture(scope="module") +@pytest_asyncio.fixture(scope="module", loop_scope="module") def master_ip(master_host): yield socket.gethostbyname(master_host[0]) @@ -84,6 +84,35 @@ def sentinel(request, cluster): return Sentinel([("foo", 26379), ("bar", 26379)]) +@pytest.fixture() +async def deployed_sentinel(request): + sentinel_ips = request.config.getoption("--sentinels") + sentinel_endpoints = [ + (ip.strip(), int(port.strip())) + for ip, port in (endpoint.split(":") for endpoint in sentinel_ips.split(",")) + ] + kwargs = {} + decode_responses = True + + sentinel_kwargs = {"decode_responses": decode_responses} + force_master_ip = "localhost" + + protocol = request.config.getoption("--protocol", 2) + + sentinel = Sentinel( + sentinel_endpoints, + force_master_ip=force_master_ip, + sentinel_kwargs=sentinel_kwargs, + socket_timeout=0.1, + protocol=protocol, + decode_responses=decode_responses, + **kwargs, + ) + yield sentinel + for s in sentinel.sentinels: + await s.close() + + @pytest.mark.onlynoncluster async def test_discover_master(sentinel, master_ip): address = await sentinel.discover_master("mymaster") @@ -226,19 +255,22 @@ async def test_slave_round_robin(cluster, sentinel, master_ip): @pytest.mark.onlynoncluster -async def test_ckquorum(cluster, sentinel): - assert await sentinel.sentinel_ckquorum("mymaster") +async def test_ckquorum(sentinel): + resp = await sentinel.sentinel_ckquorum("mymaster") + assert resp is True @pytest.mark.onlynoncluster -async def test_flushconfig(cluster, sentinel): - assert await sentinel.sentinel_flushconfig() +async def test_flushconfig(sentinel): + resp = await sentinel.sentinel_flushconfig() + assert resp is True @pytest.mark.onlynoncluster async def test_reset(cluster, sentinel): cluster.master["is_odown"] = True - assert await sentinel.sentinel_reset("mymaster") + resp = await sentinel.sentinel_reset("mymaster") + assert resp is True @pytest.mark.onlynoncluster @@ -284,3 +316,50 @@ async def test_repr_correctly_represents_connection_object(sentinel): str(connection) == "" # noqa: E501 ) + + +# Tests against real sentinel instances +@pytest.mark.onlynoncluster +async def test_get_sentinels(deployed_sentinel): + resps = await deployed_sentinel.sentinel_sentinels( + "redis-py-test", return_responses=True + ) + + # validate that the original command response is returned + assert isinstance(resps, list) + + # validate that the command has been executed against all sentinels + # each response from each sentinel is returned + 
assert len(resps) > 1 + + # validate default behavior + resps = await deployed_sentinel.sentinel_sentinels("redis-py-test") + assert isinstance(resps, bool) + + +@pytest.mark.onlynoncluster +async def test_get_master_addr_by_name(deployed_sentinel): + resps = await deployed_sentinel.sentinel_get_master_addr_by_name( + "redis-py-test", + return_responses=True, + ) + + # validate that the original command response is returned + assert isinstance(resps, list) + + # validate that the command has been executed just once + # when executed once, only one response element is returned + assert len(resps) == 1 + + assert isinstance(resps[0], tuple) + + # validate default behavior + resps = await deployed_sentinel.sentinel_get_master_addr_by_name("redis-py-test") + assert isinstance(resps, bool) + + +@pytest.mark.onlynoncluster +async def test_redis_master_usage(deployed_sentinel): + r = await deployed_sentinel.master_for("redis-py-test", db=0) + await r.set("foo", "bar") + assert (await r.get("foo")) == "bar" diff --git a/tests/test_asyncio/test_sentinel_managed_connection.py b/tests/test_asyncio/test_sentinel_managed_connection.py index cae4b9581f..5a511b2793 100644 --- a/tests/test_asyncio/test_sentinel_managed_connection.py +++ b/tests/test_asyncio/test_sentinel_managed_connection.py @@ -1,12 +1,11 @@ import socket +from unittest import mock import pytest from redis.asyncio.retry import Retry from redis.asyncio.sentinel import SentinelManagedConnection from redis.backoff import NoBackoff -from .compat import mock - pytestmark = pytest.mark.asyncio @@ -34,4 +33,5 @@ async def mock_connect(): conn._connect.side_effect = mock_connect await conn.connect() assert conn._connect.call_count == 3 + assert connection_pool.get_master_address.call_count == 3 await conn.disconnect() diff --git a/tests/test_cluster.py b/tests/test_cluster.py index d360ab07f7..4883ba66c9 100644 --- a/tests/test_cluster.py +++ b/tests/test_cluster.py @@ -111,9 +111,13 @@ class NodeProxy: def __init__(self, addr, redis_addr): self.addr = addr self.redis_addr = redis_addr - self.server = socketserver.ThreadingTCPServer(self.addr, ProxyRequestHandler) + self.server = socketserver.ThreadingTCPServer( + self.addr, ProxyRequestHandler, bind_and_activate=False + ) self.server.proxy = self - self.server.socket_reuse_address = True + self.server.allow_reuse_address = True + self.server.server_bind() + self.server.server_activate() self.thread = None self.n_connections = 0 diff --git a/tests/test_commands.py b/tests/test_commands.py index 8758efa771..42530a47d2 100644 --- a/tests/test_commands.py +++ b/tests/test_commands.py @@ -1088,6 +1088,7 @@ def test_lastsave(self, r): @pytest.mark.onlynoncluster @skip_if_server_version_lt("5.0.0") + @skip_if_server_version_gte("8.0.0") def test_lolwut(self, r): lolwut = r.lolwut().decode("utf-8") assert "Redis ver." in lolwut @@ -1095,6 +1096,15 @@ def test_lolwut(self, r): lolwut = r.lolwut(5, 6, 7, 8).decode("utf-8") assert "Redis ver." 
in lolwut + @pytest.mark.onlynoncluster + @skip_if_server_version_lt("8.0.0") + def test_lolwut_v8_and_higher(self, r): + lolwut = r.lolwut().decode("utf-8") + assert lolwut + + lolwut = r.lolwut(5, 6, 7, 8).decode("utf-8") + assert lolwut + @pytest.mark.onlynoncluster @skip_if_server_version_lt("6.2.0") @skip_if_redis_enterprise() @@ -1303,6 +1313,103 @@ def test_bitop_string_operands(self, r): assert int(binascii.hexlify(r["res2"]), 16) == 0x0102FFFF assert int(binascii.hexlify(r["res3"]), 16) == 0x000000FF + @pytest.mark.onlynoncluster + @skip_if_server_version_lt("8.1.224") + def test_bitop_diff(self, r): + r["a"] = b"\xf0" + r["b"] = b"\xc0" + r["c"] = b"\x80" + + result = r.bitop("DIFF", "result", "a", "b", "c") + assert result == 1 + assert r["result"] == b"\x30" + + r.bitop("DIFF", "result2", "a", "nonexistent") + assert r["result2"] == b"\xf0" + + @pytest.mark.onlynoncluster + @skip_if_server_version_lt("8.1.224") + def test_bitop_diff1(self, r): + r["a"] = b"\xf0" + r["b"] = b"\xc0" + r["c"] = b"\x80" + + result = r.bitop("DIFF1", "result", "a", "b", "c") + assert result == 1 + assert r["result"] == b"\x00" + + r["d"] = b"\x0f" + r["e"] = b"\x03" + r.bitop("DIFF1", "result2", "d", "e") + assert r["result2"] == b"\x00" + + @pytest.mark.onlynoncluster + @skip_if_server_version_lt("8.1.224") + def test_bitop_andor(self, r): + r["a"] = b"\xf0" + r["b"] = b"\xc0" + r["c"] = b"\x80" + + result = r.bitop("ANDOR", "result", "a", "b", "c") + assert result == 1 + assert r["result"] == b"\xc0" + + r["x"] = b"\xf0" + r["y"] = b"\x0f" + r.bitop("ANDOR", "result2", "x", "y") + assert r["result2"] == b"\x00" + + @pytest.mark.onlynoncluster + @skip_if_server_version_lt("8.1.224") + def test_bitop_one(self, r): + r["a"] = b"\xf0" + r["b"] = b"\xc0" + r["c"] = b"\x80" + + result = r.bitop("ONE", "result", "a", "b", "c") + assert result == 1 + assert r["result"] == b"\x30" + + r["x"] = b"\xf0" + r["y"] = b"\x0f" + r.bitop("ONE", "result2", "x", "y") + assert r["result2"] == b"\xff" + + @pytest.mark.onlynoncluster + @skip_if_server_version_lt("8.1.224") + def test_bitop_new_operations_with_empty_keys(self, r): + r["a"] = b"\xff" + + r.bitop("DIFF", "empty_result", "nonexistent", "a") + assert r.get("empty_result") == b"\x00" + + r.bitop("DIFF1", "empty_result2", "a", "nonexistent") + assert r.get("empty_result2") == b"\x00" + + r.bitop("ANDOR", "empty_result3", "a", "nonexistent") + assert r.get("empty_result3") == b"\x00" + + r.bitop("ONE", "empty_result4", "nonexistent") + assert r.get("empty_result4") is None + + @pytest.mark.onlynoncluster + @skip_if_server_version_lt("8.1.224") + def test_bitop_new_operations_return_values(self, r): + r["a"] = b"\xff\x00\xff" + r["b"] = b"\x00\xff" + + result1 = r.bitop("DIFF", "result1", "a", "b") + assert result1 == 3 + + result2 = r.bitop("DIFF1", "result2", "a", "b") + assert result2 == 3 + + result3 = r.bitop("ANDOR", "result3", "a", "b") + assert result3 == 3 + + result4 = r.bitop("ONE", "result4", "a", "b") + assert result4 == 3 + @pytest.mark.onlynoncluster @skip_if_server_version_lt("2.8.7") def test_bitpos(self, r): @@ -4989,6 +5096,145 @@ def test_xtrim_minlen_and_length_args(self, r): r.xadd(stream, {"foo": "bar"}) assert r.xtrim(stream, None, approximate=True, minid=m3) == 0 + @skip_if_server_version_lt("8.1.224") + def test_xdelex(self, r): + stream = "stream" + + m1 = r.xadd(stream, {"foo": "bar"}) + m2 = r.xadd(stream, {"foo": "bar"}) + m3 = r.xadd(stream, {"foo": "bar"}) + m4 = r.xadd(stream, {"foo": "bar"}) + + # Test XDELEX with default 
ref_policy (KEEPREF) + result = r.xdelex(stream, m1) + assert result == [1] + + # Test XDELEX with explicit KEEPREF + result = r.xdelex(stream, m2, ref_policy="KEEPREF") + assert result == [1] + + # Test XDELEX with DELREF + result = r.xdelex(stream, m3, ref_policy="DELREF") + assert result == [1] + + # Test XDELEX with ACKED + result = r.xdelex(stream, m4, ref_policy="ACKED") + assert result == [1] + + # Test with non-existent ID + result = r.xdelex(stream, "999999-0", ref_policy="KEEPREF") + assert result == [-1] + + # Test with multiple IDs + m5 = r.xadd(stream, {"foo": "bar"}) + m6 = r.xadd(stream, {"foo": "bar"}) + result = r.xdelex(stream, m5, m6, ref_policy="KEEPREF") + assert result == [1, 1] # Both entries deleted + + # Test error cases + with pytest.raises(redis.DataError): + r.xdelex(stream, "123-0", ref_policy="INVALID") + + with pytest.raises(redis.DataError): + r.xdelex(stream) # No IDs provided + + @skip_if_server_version_lt("8.1.224") + def test_xackdel(self, r): + stream = "stream" + group = "group" + consumer = "consumer" + + m1 = r.xadd(stream, {"foo": "bar"}) + m2 = r.xadd(stream, {"foo": "bar"}) + m3 = r.xadd(stream, {"foo": "bar"}) + m4 = r.xadd(stream, {"foo": "bar"}) + r.xgroup_create(stream, group, 0) + + r.xreadgroup(group, consumer, streams={stream: ">"}) + + # Test XACKDEL with default ref_policy (KEEPREF) + result = r.xackdel(stream, group, m1) + assert result == [1] + + # Test XACKDEL with explicit KEEPREF + result = r.xackdel(stream, group, m2, ref_policy="KEEPREF") + assert result == [1] + + # Test XACKDEL with DELREF + result = r.xackdel(stream, group, m3, ref_policy="DELREF") + assert result == [1] + + # Test XACKDEL with ACKED + result = r.xackdel(stream, group, m4, ref_policy="ACKED") + assert result == [1] + + # Test with non-existent ID + result = r.xackdel(stream, group, "999999-0", ref_policy="KEEPREF") + assert result == [-1] + + # Test error cases + with pytest.raises(redis.DataError): + r.xackdel(stream, group, m1, ref_policy="INVALID") + + with pytest.raises(redis.DataError): + r.xackdel(stream, group) # No IDs provided + + @skip_if_server_version_lt("8.1.224") + def test_xtrim_with_options(self, r): + stream = "stream" + + r.xadd(stream, {"foo": "bar"}) + r.xadd(stream, {"foo": "bar"}) + r.xadd(stream, {"foo": "bar"}) + r.xadd(stream, {"foo": "bar"}) + + # Test XTRIM with KEEPREF ref_policy + assert r.xtrim(stream, maxlen=2, approximate=False, ref_policy="KEEPREF") == 2 + + r.xadd(stream, {"foo": "bar"}) + r.xadd(stream, {"foo": "bar"}) + + # Test XTRIM with DELREF ref_policy + assert r.xtrim(stream, maxlen=2, approximate=False, ref_policy="DELREF") == 2 + + r.xadd(stream, {"foo": "bar"}) + r.xadd(stream, {"foo": "bar"}) + + # Test XTRIM with ACKED ref_policy + assert r.xtrim(stream, maxlen=2, approximate=False, ref_policy="ACKED") == 2 + + # Test error case + with pytest.raises(redis.DataError): + r.xtrim(stream, maxlen=2, ref_policy="INVALID") + + @skip_if_server_version_lt("8.1.224") + def test_xadd_with_options(self, r): + stream = "stream" + + # Test XADD with KEEPREF ref_policy + r.xadd( + stream, {"foo": "bar"}, maxlen=2, approximate=False, ref_policy="KEEPREF" + ) + r.xadd( + stream, {"foo": "bar"}, maxlen=2, approximate=False, ref_policy="KEEPREF" + ) + r.xadd( + stream, {"foo": "bar"}, maxlen=2, approximate=False, ref_policy="KEEPREF" + ) + assert r.xlen(stream) == 2 + + # Test XADD with DELREF ref_policy + r.xadd(stream, {"foo": "bar"}, maxlen=2, approximate=False, ref_policy="DELREF") + assert r.xlen(stream) == 2 + + # Test XADD with 
ACKED ref_policy + r.xadd(stream, {"foo": "bar"}, maxlen=2, approximate=False, ref_policy="ACKED") + assert r.xlen(stream) == 2 + + # Test error case + with pytest.raises(redis.DataError): + r.xadd(stream, {"foo": "bar"}, ref_policy="INVALID") + def test_bitfield_operations(self, r): # comments show affected bits bf = r.bitfield("a") diff --git a/tests/test_connection_pool.py b/tests/test_connection_pool.py index 9e67659fa9..3a4896f2a3 100644 --- a/tests/test_connection_pool.py +++ b/tests/test_connection_pool.py @@ -76,11 +76,13 @@ def test_multiple_connections(self, master_host): assert c1 != c2 def test_max_connections(self, master_host): - connection_kwargs = {"host": master_host[0], "port": master_host[1]} - pool = self.get_pool(max_connections=2, connection_kwargs=connection_kwargs) + # Use DummyConnection to avoid actual connection to Redis + # This prevents authentication issues and makes the test more reliable + # while still properly testing the MaxConnectionsError behavior + pool = self.get_pool(max_connections=2, connection_class=DummyConnection) pool.get_connection() pool.get_connection() - with pytest.raises(redis.ConnectionError): + with pytest.raises(redis.MaxConnectionsError): pool.get_connection() def test_reuse_previously_released_connection(self, master_host): @@ -205,13 +207,14 @@ def test_repr_contains_db_info_tcp(self): pool = redis.ConnectionPool( host="localhost", port=6379, client_name="test-client" ) - expected = "host=localhost,port=6379,db=0,client_name=test-client" + expected = "host=localhost,port=6379,client_name=test-client" assert expected in repr(pool) def test_repr_contains_db_info_unix(self): pool = redis.ConnectionPool( connection_class=redis.UnixDomainSocketConnection, path="abc", + db=0, client_name="test-client", ) expected = "path=abc,db=0,client_name=test-client" @@ -598,7 +601,7 @@ def test_oom_error(self, r): r.execute_command("DEBUG", "ERROR", "OOM blah blah") def test_connect_from_url_tcp(self): - connection = redis.Redis.from_url("redis://localhost") + connection = redis.Redis.from_url("redis://localhost:6379?db=0") pool = connection.connection_pool assert re.match( @@ -606,7 +609,7 @@ def test_connect_from_url_tcp(self): ).groups() == ( "ConnectionPool", "Connection", - "host=localhost,port=6379,db=0", + "db=0,host=localhost,port=6379", ) def test_connect_from_url_unix(self): @@ -618,7 +621,7 @@ def test_connect_from_url_unix(self): ).groups() == ( "ConnectionPool", "UnixDomainSocketConnection", - "path=/path/to/socket,db=0", + "path=/path/to/socket", ) @skip_if_redis_enterprise() diff --git a/tests/test_max_connections_error.py b/tests/test_max_connections_error.py new file mode 100644 index 0000000000..4a4e09f8f8 --- /dev/null +++ b/tests/test_max_connections_error.py @@ -0,0 +1,112 @@ +import pytest +import redis +from unittest import mock +from redis.connection import ConnectionInterface + + +class DummyConnection(ConnectionInterface): + """A dummy connection class for testing that doesn't actually connect to Redis""" + + def __init__(self, *args, **kwargs): + self.connected = False + + def connect(self): + self.connected = True + + def disconnect(self): + self.connected = False + + def register_connect_callback(self, callback): + pass + + def deregister_connect_callback(self, callback): + pass + + def set_parser(self, parser_class): + pass + + def get_protocol(self): + return 2 + + def on_connect(self): + pass + + def check_health(self): + return True + + def send_packed_command(self, command, check_health=True): + pass + + def 
send_command(self, *args, **kwargs): + pass + + def can_read(self, timeout=0): + return False + + def read_response(self, disable_decoding=False, **kwargs): + return "PONG" + + +@pytest.mark.onlynoncluster +def test_max_connections_error_inheritance(): + """Test that MaxConnectionsError is a subclass of ConnectionError""" + assert issubclass(redis.MaxConnectionsError, redis.ConnectionError) + + +@pytest.mark.onlynoncluster +def test_connection_pool_raises_max_connections_error(): + """Test that ConnectionPool raises MaxConnectionsError and not ConnectionError""" + # Use a dummy connection class that doesn't try to connect to a real Redis server + pool = redis.ConnectionPool(max_connections=1, connection_class=DummyConnection) + pool.get_connection() + + with pytest.raises(redis.MaxConnectionsError): + pool.get_connection() + + +@pytest.mark.skipif( + not hasattr(redis, "RedisCluster"), reason="RedisCluster not available" +) +def test_cluster_handles_max_connections_error(): + """ + Test that RedisCluster doesn't reinitialize when MaxConnectionsError is raised + """ + # Create a more complete mock cluster + cluster = mock.MagicMock(spec=redis.RedisCluster) + cluster.cluster_response_callbacks = {} + cluster.RedisClusterRequestTTL = 3 # Set the TTL to avoid infinite loops + cluster.nodes_manager = mock.MagicMock() + node = mock.MagicMock() + + # Mock get_redis_connection to return a mock Redis client + redis_conn = mock.MagicMock() + cluster.get_redis_connection.return_value = redis_conn + + # Setup get_connection to be called and return a connection that will raise + connection = mock.MagicMock() + + # Patch the get_connection function in the cluster module + with mock.patch("redis.cluster.get_connection", return_value=connection): + # Test MaxConnectionsError + connection.send_command.side_effect = redis.MaxConnectionsError( + "Too many connections" + ) + + # Call the method and check that the exception is raised + with pytest.raises(redis.MaxConnectionsError): + redis.RedisCluster._execute_command(cluster, node, "GET", "key") + + # Verify nodes_manager.initialize was NOT called + cluster.nodes_manager.initialize.assert_not_called() + + # Reset the mock for the next test + cluster.nodes_manager.initialize.reset_mock() + + # Now test with regular ConnectionError to ensure it DOES reinitialize + connection.send_command.side_effect = redis.ConnectionError("Connection lost") + + with pytest.raises(redis.ConnectionError): + redis.RedisCluster._execute_command(cluster, node, "GET", "key") + + # Verify nodes_manager.initialize WAS called + cluster.nodes_manager.initialize.assert_called_once() diff --git a/tests/test_retry.py b/tests/test_retry.py index 4f4f04caca..9c0ca65d81 100644 --- a/tests/test_retry.py +++ b/tests/test_retry.py @@ -1,6 +1,7 @@ from unittest.mock import patch import pytest +from redis.asyncio.retry import Retry as AsyncRetry from redis.backoff import ( AbstractBackoff, ConstantBackoff, @@ -89,6 +90,7 @@ def test_retry_on_error_retry(self, Class, retries): assert c.retry._retries == retries +@pytest.mark.parametrize("retry_class", [Retry, AsyncRetry]) @pytest.mark.parametrize( "args", [ @@ -108,8 +110,8 @@ def test_retry_on_error_retry(self, Class, retries): for backoff in ((Backoff(), 2), (Backoff(25), 5), (Backoff(25, 5), 5)) ], ) -def test_retry_eq_and_hashable(args): - assert Retry(*args) == Retry(*args) +def test_retry_eq_and_hashable(retry_class, args): + assert retry_class(*args) == retry_class(*args) # create another retry object with different parameters copy = 
list(args) @@ -118,9 +120,19 @@ def test_retry_eq_and_hashable(args): else: copy[0] = ConstantBackoff(9000) - assert Retry(*args) != Retry(*copy) - assert Retry(*copy) != Retry(*args) - assert len({Retry(*args), Retry(*args), Retry(*copy), Retry(*copy)}) == 2 + assert retry_class(*args) != retry_class(*copy) + assert retry_class(*copy) != retry_class(*args) + assert ( + len( + { + retry_class(*args), + retry_class(*args), + retry_class(*copy), + retry_class(*copy), + } + ) + == 2 + ) class TestRetry: diff --git a/tests/test_search.py b/tests/test_search.py index 4af55e8a17..3460b56ca1 100644 --- a/tests/test_search.py +++ b/tests/test_search.py @@ -2863,6 +2863,100 @@ def test_vector_search_with_default_dialect(client): assert res["total_results"] == 2 +@pytest.mark.redismod +@skip_ifmodversion_lt("2.4.3", "search") +@skip_if_server_version_lt("8.1.224") +def test_svs_vamana_l2_distance_metric(client): + client.ft().create_index( + ( + VectorField( + "v", + "SVS-VAMANA", + {"TYPE": "FLOAT32", "DIM": 3, "DISTANCE_METRIC": "L2"}, + ), + ) + ) + + # L2 distance test vectors + vectors = [[1.0, 0.0, 0.0], [2.0, 0.0, 0.0], [0.0, 1.0, 0.0], [5.0, 0.0, 0.0]] + + for i, vec in enumerate(vectors): + client.hset(f"doc{i}", "v", np.array(vec, dtype=np.float32).tobytes()) + + query = Query("*=>[KNN 3 @v $vec as score]").sort_by("score").no_content() + query_params = {"vec": np.array(vectors[0], dtype=np.float32).tobytes()} + + res = client.ft().search(query, query_params=query_params) + if is_resp2_connection(client): + assert res.total == 3 + assert "doc0" == res.docs[0].id + else: + assert res["total_results"] == 3 + assert "doc0" == res["results"][0]["id"] + + +@pytest.mark.redismod +@skip_ifmodversion_lt("2.4.3", "search") +@skip_if_server_version_lt("8.1.224") +def test_svs_vamana_cosine_distance_metric(client): + client.ft().create_index( + ( + VectorField( + "v", + "SVS-VAMANA", + {"TYPE": "FLOAT32", "DIM": 3, "DISTANCE_METRIC": "COSINE"}, + ), + ) + ) + + vectors = [[1.0, 0.0, 0.0], [0.707, 0.707, 0.0], [0.0, 1.0, 0.0], [-1.0, 0.0, 0.0]] + + for i, vec in enumerate(vectors): + client.hset(f"doc{i}", "v", np.array(vec, dtype=np.float32).tobytes()) + + query = Query("*=>[KNN 3 @v $vec as score]").sort_by("score").no_content() + query_params = {"vec": np.array(vectors[0], dtype=np.float32).tobytes()} + + res = client.ft().search(query, query_params=query_params) + if is_resp2_connection(client): + assert res.total == 3 + assert "doc0" == res.docs[0].id + else: + assert res["total_results"] == 3 + assert "doc0" == res["results"][0]["id"] + + +@pytest.mark.redismod +@skip_ifmodversion_lt("2.4.3", "search") +@skip_if_server_version_lt("8.1.224") +def test_svs_vamana_ip_distance_metric(client): + client.ft().create_index( + ( + VectorField( + "v", + "SVS-VAMANA", + {"TYPE": "FLOAT32", "DIM": 3, "DISTANCE_METRIC": "IP"}, + ), + ) + ) + + vectors = [[1.0, 2.0, 3.0], [2.0, 1.0, 1.0], [3.0, 3.0, 3.0], [0.1, 0.1, 0.1]] + + for i, vec in enumerate(vectors): + client.hset(f"doc{i}", "v", np.array(vec, dtype=np.float32).tobytes()) + + query = Query("*=>[KNN 3 @v $vec as score]").sort_by("score").no_content() + query_params = {"vec": np.array(vectors[0], dtype=np.float32).tobytes()} + + res = client.ft().search(query, query_params=query_params) + if is_resp2_connection(client): + assert res.total == 3 + assert "doc2" == res.docs[0].id + else: + assert res["total_results"] == 3 + assert "doc2" == res["results"][0]["id"] + + @pytest.mark.redismod @skip_if_server_version_lt("7.9.0") def 
test_vector_search_with_int8_type(client): @@ -2878,7 +2972,7 @@ def test_vector_search_with_int8_type(client): client.hset("b", "v", np.array(b, dtype=np.int8).tobytes()) client.hset("c", "v", np.array(c, dtype=np.int8).tobytes()) - query = Query("*=>[KNN 2 @v $vec as score]") + query = Query("*=>[KNN 2 @v $vec as score]").no_content() query_params = {"vec": np.array(a, dtype=np.int8).tobytes()} assert 2 in query.get_args() @@ -2909,7 +3003,7 @@ def test_vector_search_with_uint8_type(client): client.hset("b", "v", np.array(b, dtype=np.uint8).tobytes()) client.hset("c", "v", np.array(c, dtype=np.uint8).tobytes()) - query = Query("*=>[KNN 2 @v $vec as score]") + query = Query("*=>[KNN 2 @v $vec as score]").no_content() query_params = {"vec": np.array(a, dtype=np.uint8).tobytes()} assert 2 in query.get_args() @@ -2966,3 +3060,745 @@ def _assert_search_result(client, result, expected_doc_ids): assert set([doc.id for doc in result.docs]) == set(expected_doc_ids) else: assert set([doc["id"] for doc in result["results"]]) == set(expected_doc_ids) + + +@pytest.mark.redismod +@skip_ifmodversion_lt("2.4.3", "search") +@skip_if_server_version_lt("8.1.224") +def test_svs_vamana_basic_functionality(client): + client.ft().create_index( + ( + VectorField( + "v", + "SVS-VAMANA", + {"TYPE": "FLOAT32", "DIM": 4, "DISTANCE_METRIC": "L2"}, + ), + ) + ) + + vectors = [ + [1.0, 2.0, 3.0, 4.0], + [2.0, 3.0, 4.0, 5.0], + [3.0, 4.0, 5.0, 6.0], + [10.0, 11.0, 12.0, 13.0], + ] + + for i, vec in enumerate(vectors): + client.hset(f"doc{i}", "v", np.array(vec, dtype=np.float32).tobytes()) + + query = "*=>[KNN 3 @v $vec]" + q = Query(query).return_field("__v_score").sort_by("__v_score", True) + res = client.ft().search( + q, query_params={"vec": np.array(vectors[0], dtype=np.float32).tobytes()} + ) + + if is_resp2_connection(client): + assert res.total == 3 + assert "doc0" == res.docs[0].id # Should be closest to itself + assert "0" == res.docs[0].__getattribute__("__v_score") + else: + assert res["total_results"] == 3 + assert "doc0" == res["results"][0]["id"] + assert "0" == res["results"][0]["extra_attributes"]["__v_score"] + + +@pytest.mark.redismod +@skip_ifmodversion_lt("2.4.3", "search") +@skip_if_server_version_lt("8.1.224") +def test_svs_vamana_float16_type(client): + client.ft().create_index( + ( + VectorField( + "v", + "SVS-VAMANA", + {"TYPE": "FLOAT16", "DIM": 4, "DISTANCE_METRIC": "L2"}, + ), + ) + ) + + vectors = [[1.5, 2.5, 3.5, 4.5], [2.5, 3.5, 4.5, 5.5], [3.5, 4.5, 5.5, 6.5]] + + for i, vec in enumerate(vectors): + client.hset(f"doc{i}", "v", np.array(vec, dtype=np.float16).tobytes()) + + query = Query("*=>[KNN 2 @v $vec as score]").no_content() + query_params = {"vec": np.array(vectors[0], dtype=np.float16).tobytes()} + + res = client.ft().search(query, query_params=query_params) + if is_resp2_connection(client): + assert res.total == 2 + assert "doc0" == res.docs[0].id + else: + assert res["total_results"] == 2 + assert "doc0" == res["results"][0]["id"] + + +@pytest.mark.redismod +@skip_ifmodversion_lt("2.4.3", "search") +@skip_if_server_version_lt("8.1.224") +def test_svs_vamana_float32_type(client): + client.ft().create_index( + ( + VectorField( + "v", + "SVS-VAMANA", + {"TYPE": "FLOAT32", "DIM": 4, "DISTANCE_METRIC": "L2"}, + ), + ) + ) + + vectors = [[1.0, 2.0, 3.0, 4.0], [2.0, 3.0, 4.0, 5.0], [3.0, 4.0, 5.0, 6.0]] + + for i, vec in enumerate(vectors): + client.hset(f"doc{i}", "v", np.array(vec, dtype=np.float32).tobytes()) + + query = Query("*=>[KNN 2 @v $vec as score]").no_content() + 
query_params = {"vec": np.array(vectors[0], dtype=np.float32).tobytes()} + + res = client.ft().search(query, query_params=query_params) + if is_resp2_connection(client): + assert res.total == 2 + assert "doc0" == res.docs[0].id + else: + assert res["total_results"] == 2 + assert "doc0" == res["results"][0]["id"] + + +@pytest.mark.redismod +@skip_ifmodversion_lt("2.4.3", "search") +@skip_if_server_version_lt("8.1.224") +def test_svs_vamana_vector_search_with_default_dialect(client): + client.ft().create_index( + ( + VectorField( + "v", + "SVS-VAMANA", + {"TYPE": "FLOAT32", "DIM": 2, "DISTANCE_METRIC": "L2"}, + ), + ) + ) + + client.hset("a", "v", "aaaaaaaa") + client.hset("b", "v", "aaaabaaa") + client.hset("c", "v", "aaaaabaa") + + query = "*=>[KNN 2 @v $vec]" + q = Query(query).return_field("__v_score").sort_by("__v_score", True) + res = client.ft().search(q, query_params={"vec": "aaaaaaaa"}) + + if is_resp2_connection(client): + assert res.total == 2 + else: + assert res["total_results"] == 2 + + +@pytest.mark.redismod +@skip_ifmodversion_lt("2.4.3", "search") +@skip_if_server_version_lt("8.1.224") +def test_svs_vamana_vector_field_basic(): + field = VectorField( + "v", "SVS-VAMANA", {"TYPE": "FLOAT32", "DIM": 128, "DISTANCE_METRIC": "COSINE"} + ) + + # Check that the field was created successfully + assert field.name == "v" + assert field.args[0] == "VECTOR" + assert field.args[1] == "SVS-VAMANA" + assert field.args[2] == 6 + assert "TYPE" in field.args + assert "FLOAT32" in field.args + assert "DIM" in field.args + assert 128 in field.args + assert "DISTANCE_METRIC" in field.args + assert "COSINE" in field.args + + +@pytest.mark.redismod +@skip_ifmodversion_lt("2.4.3", "search") +@skip_if_server_version_lt("8.1.224") +def test_svs_vamana_lvq8_compression(client): + client.ft().create_index( + ( + VectorField( + "v", + "SVS-VAMANA", + { + "TYPE": "FLOAT32", + "DIM": 8, + "DISTANCE_METRIC": "L2", + "COMPRESSION": "LVQ8", + "TRAINING_THRESHOLD": 1024, + }, + ), + ) + ) + + vectors = [] + for i in range(20): + vec = [float(i + j) for j in range(8)] + vectors.append(vec) + client.hset(f"doc{i}", "v", np.array(vec, dtype=np.float32).tobytes()) + + query = Query("*=>[KNN 5 @v $vec as score]").no_content() + query_params = {"vec": np.array(vectors[0], dtype=np.float32).tobytes()} + + res = client.ft().search(query, query_params=query_params) + if is_resp2_connection(client): + assert res.total == 5 + assert "doc0" == res.docs[0].id + else: + assert res["total_results"] == 5 + assert "doc0" == res["results"][0]["id"] + + +@pytest.mark.redismod +@skip_ifmodversion_lt("2.4.3", "search") +@skip_if_server_version_lt("8.1.224") +def test_svs_vamana_compression_with_both_vector_types(client): + # Test FLOAT16 with LVQ8 + client.ft("idx16").create_index( + ( + VectorField( + "v16", + "SVS-VAMANA", + { + "TYPE": "FLOAT16", + "DIM": 8, + "DISTANCE_METRIC": "L2", + "COMPRESSION": "LVQ8", + "TRAINING_THRESHOLD": 1024, + }, + ), + ) + ) + + # Test FLOAT32 with LVQ8 + client.ft("idx32").create_index( + ( + VectorField( + "v32", + "SVS-VAMANA", + { + "TYPE": "FLOAT32", + "DIM": 8, + "DISTANCE_METRIC": "L2", + "COMPRESSION": "LVQ8", + "TRAINING_THRESHOLD": 1024, + }, + ), + ) + ) + + # Add data to both indices + for i in range(15): + vec = [float(i + j) for j in range(8)] + client.hset(f"doc16_{i}", "v16", np.array(vec, dtype=np.float16).tobytes()) + client.hset(f"doc32_{i}", "v32", np.array(vec, dtype=np.float32).tobytes()) + + # Test both indices + query = Query("*=>[KNN 3 @v16 $vec as score]").no_content() 
+ res16 = client.ft("idx16").search( + query, + query_params={ + "vec": np.array( + [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0], dtype=np.float16 + ).tobytes() + }, + ) + + query = Query("*=>[KNN 3 @v32 $vec as score]").no_content() + res32 = client.ft("idx32").search( + query, + query_params={ + "vec": np.array( + [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0], dtype=np.float32 + ).tobytes() + }, + ) + + if is_resp2_connection(client): + assert res16.total == 3 + assert res32.total == 3 + else: + assert res16["total_results"] == 3 + assert res32["total_results"] == 3 + + +@pytest.mark.redismod +@skip_ifmodversion_lt("2.4.3", "search") +@skip_if_server_version_lt("8.1.224") +def test_svs_vamana_construction_window_size(client): + client.ft().create_index( + ( + VectorField( + "v", + "SVS-VAMANA", + { + "TYPE": "FLOAT32", + "DIM": 6, + "DISTANCE_METRIC": "L2", + "CONSTRUCTION_WINDOW_SIZE": 300, + }, + ), + ) + ) + + vectors = [] + for i in range(20): + vec = [float(i + j) for j in range(6)] + vectors.append(vec) + client.hset(f"doc{i}", "v", np.array(vec, dtype=np.float32).tobytes()) + + query = Query("*=>[KNN 5 @v $vec as score]").no_content() + query_params = {"vec": np.array(vectors[0], dtype=np.float32).tobytes()} + + res = client.ft().search(query, query_params=query_params) + if is_resp2_connection(client): + assert res.total == 5 + assert "doc0" == res.docs[0].id + else: + assert res["total_results"] == 5 + assert "doc0" == res["results"][0]["id"] + + +@pytest.mark.redismod +@skip_ifmodversion_lt("2.4.3", "search") +@skip_if_server_version_lt("8.1.224") +def test_svs_vamana_graph_max_degree(client): + client.ft().create_index( + ( + VectorField( + "v", + "SVS-VAMANA", + { + "TYPE": "FLOAT32", + "DIM": 6, + "DISTANCE_METRIC": "COSINE", + "GRAPH_MAX_DEGREE": 64, + }, + ), + ) + ) + + vectors = [] + for i in range(25): + vec = [float(i + j) for j in range(6)] + vectors.append(vec) + client.hset(f"doc{i}", "v", np.array(vec, dtype=np.float32).tobytes()) + + query = Query("*=>[KNN 6 @v $vec as score]").no_content() + query_params = {"vec": np.array(vectors[0], dtype=np.float32).tobytes()} + + res = client.ft().search(query, query_params=query_params) + if is_resp2_connection(client): + assert res.total == 6 + assert "doc0" == res.docs[0].id + else: + assert res["total_results"] == 6 + assert "doc0" == res["results"][0]["id"] + + +@pytest.mark.redismod +@skip_ifmodversion_lt("2.4.3", "search") +@skip_if_server_version_lt("8.1.224") +def test_svs_vamana_search_window_size(client): + client.ft().create_index( + ( + VectorField( + "v", + "SVS-VAMANA", + { + "TYPE": "FLOAT32", + "DIM": 6, + "DISTANCE_METRIC": "L2", + "SEARCH_WINDOW_SIZE": 20, + }, + ), + ) + ) + + vectors = [] + for i in range(30): + vec = [float(i + j) for j in range(6)] + vectors.append(vec) + client.hset(f"doc{i}", "v", np.array(vec, dtype=np.float32).tobytes()) + + query = Query("*=>[KNN 8 @v $vec as score]").no_content() + query_params = {"vec": np.array(vectors[0], dtype=np.float32).tobytes()} + + res = client.ft().search(query, query_params=query_params) + if is_resp2_connection(client): + assert res.total == 8 + assert "doc0" == res.docs[0].id + else: + assert res["total_results"] == 8 + assert "doc0" == res["results"][0]["id"] + + +@pytest.mark.redismod +@skip_ifmodversion_lt("2.4.3", "search") +@skip_if_server_version_lt("8.1.224") +def test_svs_vamana_epsilon_parameter(client): + client.ft().create_index( + ( + VectorField( + "v", + "SVS-VAMANA", + {"TYPE": "FLOAT32", "DIM": 6, "DISTANCE_METRIC": "L2", "EPSILON": 0.05}, + ), 
+        )
+    )
+
+    vectors = []
+    for i in range(20):
+        vec = [float(i + j) for j in range(6)]
+        vectors.append(vec)
+        client.hset(f"doc{i}", "v", np.array(vec, dtype=np.float32).tobytes())
+
+    query = Query("*=>[KNN 5 @v $vec as score]").no_content()
+    query_params = {"vec": np.array(vectors[0], dtype=np.float32).tobytes()}
+
+    res = client.ft().search(query, query_params=query_params)
+    if is_resp2_connection(client):
+        assert res.total == 5
+        assert "doc0" == res.docs[0].id
+    else:
+        assert res["total_results"] == 5
+        assert "doc0" == res["results"][0]["id"]
+
+
+@pytest.mark.redismod
+@skip_ifmodversion_lt("2.4.3", "search")
+@skip_if_server_version_lt("8.1.224")
+def test_svs_vamana_all_build_parameters_combined(client):
+    client.ft().create_index(
+        (
+            VectorField(
+                "v",
+                "SVS-VAMANA",
+                {
+                    "TYPE": "FLOAT32",
+                    "DIM": 8,
+                    "DISTANCE_METRIC": "IP",
+                    "CONSTRUCTION_WINDOW_SIZE": 250,
+                    "GRAPH_MAX_DEGREE": 48,
+                    "SEARCH_WINDOW_SIZE": 15,
+                    "EPSILON": 0.02,
+                },
+            ),
+        )
+    )
+
+    vectors = []
+    for i in range(35):
+        vec = [float(i + j) for j in range(8)]
+        vectors.append(vec)
+        client.hset(f"doc{i}", "v", np.array(vec, dtype=np.float32).tobytes())
+
+    query = Query("*=>[KNN 7 @v $vec as score]").no_content()
+    query_params = {"vec": np.array(vectors[0], dtype=np.float32).tobytes()}
+
+    res = client.ft().search(query, query_params=query_params)
+    if is_resp2_connection(client):
+        assert res.total == 7
+        doc_ids = [doc.id for doc in res.docs]
+        assert len(doc_ids) == 7
+    else:
+        assert res["total_results"] == 7
+        doc_ids = [doc["id"] for doc in res["results"]]
+        assert len(doc_ids) == 7
+
+
+@pytest.mark.redismod
+@skip_ifmodversion_lt("2.4.3", "search")
+@skip_if_server_version_lt("8.1.224")
+def test_svs_vamana_comprehensive_configuration(client):
+    client.flushdb()
+    client.ft().create_index(
+        (
+            VectorField(
+                "v",
+                "SVS-VAMANA",
+                {
+                    "TYPE": "FLOAT16",
+                    "DIM": 32,
+                    "DISTANCE_METRIC": "COSINE",
+                    "COMPRESSION": "LVQ8",
+                    "CONSTRUCTION_WINDOW_SIZE": 400,
+                    "GRAPH_MAX_DEGREE": 96,
+                    "SEARCH_WINDOW_SIZE": 25,
+                    "EPSILON": 0.03,
+                    "TRAINING_THRESHOLD": 2048,
+                },
+            ),
+        )
+    )
+
+    vectors = []
+    for i in range(60):
+        vec = [float(i + j) for j in range(32)]
+        vectors.append(vec)
+        client.hset(f"doc{i}", "v", np.array(vec, dtype=np.float16).tobytes())
+
+    query = Query("*=>[KNN 10 @v $vec as score]").no_content()
+    query_params = {"vec": np.array(vectors[0], dtype=np.float16).tobytes()}
+
+    res = client.ft().search(query, query_params=query_params)
+    if is_resp2_connection(client):
+        assert res.total == 10
+        assert "doc0" == res.docs[0].id
+    else:
+        assert res["total_results"] == 10
+        assert "doc0" == res["results"][0]["id"]
+
+
+@pytest.mark.redismod
+@skip_ifmodversion_lt("2.4.3", "search")
+@skip_if_server_version_lt("8.1.224")
+def test_svs_vamana_hybrid_text_vector_search(client):
+    client.flushdb()
+    client.ft().create_index(
+        (
+            TextField("title"),
+            TextField("content"),
+            VectorField(
+                "embedding",
+                "SVS-VAMANA",
+                {
+                    "TYPE": "FLOAT32",
+                    "DIM": 6,
+                    "DISTANCE_METRIC": "COSINE",
+                    "SEARCH_WINDOW_SIZE": 20,
+                },
+            ),
+        )
+    )
+
+    docs = [
+        {
+            "title": "AI Research",
+            "content": "machine learning algorithms",
+            "embedding": [1.0, 2.0, 3.0, 4.0, 5.0, 6.0],
+        },
+        {
+            "title": "Data Science",
+            "content": "statistical analysis methods",
+            "embedding": [2.0, 3.0, 4.0, 5.0, 6.0, 7.0],
+        },
+        {
+            "title": "Deep Learning",
+            "content": "neural network architectures",
+            "embedding": [3.0, 4.0, 5.0, 6.0, 7.0, 8.0],
+        },
+        {
+            "title": "Computer Vision",
+            "content": "image processing techniques",
+            "embedding": [10.0, 11.0, 12.0, 13.0, 14.0, 15.0],
+        },
+    ]
+
+    for i, doc in enumerate(docs):
+        client.hset(
+            f"doc{i}",
+            mapping={
+                "title": doc["title"],
+                "content": doc["content"],
+                "embedding": np.array(doc["embedding"], dtype=np.float32).tobytes(),
+            },
+        )
+
+    # Hybrid query - text filter + vector similarity
+    query = "(@title:AI|@content:machine)=>[KNN 2 @embedding $vec]"
+    q = (
+        Query(query)
+        .return_field("__embedding_score")
+        .sort_by("__embedding_score", True)
+    )
+    res = client.ft().search(
+        q,
+        query_params={
+            "vec": np.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0], dtype=np.float32).tobytes()
+        },
+    )
+
+    if is_resp2_connection(client):
+        assert res.total >= 1
+        doc_ids = [doc.id for doc in res.docs]
+        assert "doc0" in doc_ids
+    else:
+        assert res["total_results"] >= 1
+        doc_ids = [doc["id"] for doc in res["results"]]
+        assert "doc0" in doc_ids
+
+
+@pytest.mark.redismod
+@skip_ifmodversion_lt("2.4.3", "search")
+@skip_if_server_version_lt("8.1.224")
+def test_svs_vamana_large_dimension_vectors(client):
+    client.flushdb()
+    client.ft().create_index(
+        (
+            VectorField(
+                "v",
+                "SVS-VAMANA",
+                {
+                    "TYPE": "FLOAT32",
+                    "DIM": 512,
+                    "DISTANCE_METRIC": "L2",
+                    "CONSTRUCTION_WINDOW_SIZE": 300,
+                    "GRAPH_MAX_DEGREE": 64,
+                },
+            ),
+        )
+    )
+
+    vectors = []
+    for i in range(10):
+        vec = [float(i + j) for j in range(512)]
+        vectors.append(vec)
+        client.hset(f"doc{i}", "v", np.array(vec, dtype=np.float32).tobytes())
+
+    query = Query("*=>[KNN 5 @v $vec as score]").no_content()
+    query_params = {"vec": np.array(vectors[0], dtype=np.float32).tobytes()}
+
+    res = client.ft().search(query, query_params=query_params)
+    if is_resp2_connection(client):
+        assert res.total == 5
+        assert "doc0" == res.docs[0].id
+    else:
+        assert res["total_results"] == 5
+        assert "doc0" == res["results"][0]["id"]
+
+
+@pytest.mark.redismod
+@skip_ifmodversion_lt("2.4.3", "search")
+@skip_if_server_version_lt("8.1.224")
+def test_svs_vamana_training_threshold_behavior(client):
+    client.ft().create_index(
+        (
+            VectorField(
+                "v",
+                "SVS-VAMANA",
+                {
+                    "TYPE": "FLOAT32",
+                    "DIM": 8,
+                    "DISTANCE_METRIC": "L2",
+                    "COMPRESSION": "LVQ8",
+                    "TRAINING_THRESHOLD": 1024,
+                },
+            ),
+        )
+    )
+
+    vectors = []
+    for i in range(20):
+        vec = [float(i + j) for j in range(8)]
+        vectors.append(vec)
+        client.hset(f"doc{i}", "v", np.array(vec, dtype=np.float32).tobytes())
+
+        if i >= 5:
+            query = Query("*=>[KNN 3 @v $vec as score]").no_content()
+            query_params = {"vec": np.array(vectors[0], dtype=np.float32).tobytes()}
+            res = client.ft().search(query, query_params=query_params)
+
+            if is_resp2_connection(client):
+                assert res.total >= 1
+            else:
+                assert res["total_results"] >= 1
+
+
+@pytest.mark.redismod
+@skip_ifmodversion_lt("2.4.3", "search")
+@skip_if_server_version_lt("8.1.224")
+def test_svs_vamana_different_k_values(client):
+    client.ft().create_index(
+        (
+            VectorField(
+                "v",
+                "SVS-VAMANA",
+                {
+                    "TYPE": "FLOAT32",
+                    "DIM": 6,
+                    "DISTANCE_METRIC": "L2",
+                    "SEARCH_WINDOW_SIZE": 15,
+                },
+            ),
+        )
+    )
+
+    vectors = []
+    for i in range(25):
+        vec = [float(i + j) for j in range(6)]
+        vectors.append(vec)
+        client.hset(f"doc{i}", "v", np.array(vec, dtype=np.float32).tobytes())
+
+    for k in [1, 3, 5, 10, 15]:
+        query = Query(f"*=>[KNN {k} @v $vec as score]").no_content()
+        query_params = {"vec": np.array(vectors[0], dtype=np.float32).tobytes()}
+        res = client.ft().search(query, query_params=query_params)
+
+        if is_resp2_connection(client):
+            assert res.total == k
+            assert "doc0" == res.docs[0].id
+        else:
+            assert res["total_results"] == k
+            assert "doc0" == res["results"][0]["id"]
+
+
+@pytest.mark.redismod
+@skip_ifmodversion_lt("2.4.3", "search")
+@skip_if_server_version_lt("8.1.224")
+def test_svs_vamana_vector_field_error(client):
+    # sortable tag
+    with pytest.raises(Exception):
+        client.ft().create_index((VectorField("v", "SVS-VAMANA", {}, sortable=True),))
+
+    # no_index tag
+    with pytest.raises(Exception):
+        client.ft().create_index((VectorField("v", "SVS-VAMANA", {}, no_index=True),))
+
+
+@pytest.mark.redismod
+@skip_ifmodversion_lt("2.4.3", "search")
+@skip_if_server_version_lt("8.1.224")
+def test_svs_vamana_vector_search_with_parameters(client):
+    client.ft().create_index(
+        (
+            VectorField(
+                "v",
+                "SVS-VAMANA",
+                {
+                    "TYPE": "FLOAT32",
+                    "DIM": 4,
+                    "DISTANCE_METRIC": "L2",
+                    "CONSTRUCTION_WINDOW_SIZE": 200,
+                    "GRAPH_MAX_DEGREE": 64,
+                    "SEARCH_WINDOW_SIZE": 40,
+                    "EPSILON": 0.01,
+                },
+            ),
+        )
+    )
+
+    # Create test vectors
+    vectors = [
+        [1.0, 2.0, 3.0, 4.0],
+        [2.0, 3.0, 4.0, 5.0],
+        [3.0, 4.0, 5.0, 6.0],
+        [4.0, 5.0, 6.0, 7.0],
+        [5.0, 6.0, 7.0, 8.0],
+    ]
+
+    for i, vec in enumerate(vectors):
+        client.hset(f"doc{i}", "v", np.array(vec, dtype=np.float32).tobytes())
+
+    query = Query("*=>[KNN 3 @v $vec as score]").no_content()
+    query_params = {"vec": np.array(vectors[0], dtype=np.float32).tobytes()}
+
+    res = client.ft().search(query, query_params=query_params)
+    if is_resp2_connection(client):
+        assert res.total == 3
+        assert "doc0" == res.docs[0].id
+    else:
+        assert res["total_results"] == 3
+        assert "doc0" == res["results"][0]["id"]
diff --git a/tests/test_sentinel.py b/tests/test_sentinel.py
index 93455f3290..0e7624c836 100644
--- a/tests/test_sentinel.py
+++ b/tests/test_sentinel.py
@@ -86,6 +86,35 @@ def sentinel(request, cluster):
     return Sentinel([("foo", 26379), ("bar", 26379)])
 
 
+@pytest.fixture()
+def deployed_sentinel(request):
+    sentinel_ips = request.config.getoption("--sentinels")
+    sentinel_endpoints = [
+        (ip.strip(), int(port.strip()))
+        for ip, port in (endpoint.split(":") for endpoint in sentinel_ips.split(","))
+    ]
+    kwargs = {}
+    decode_responses = True
+
+    sentinel_kwargs = {"decode_responses": decode_responses}
+    force_master_ip = "localhost"
+
+    protocol = request.config.getoption("--protocol", 2)
+
+    sentinel = Sentinel(
+        sentinel_endpoints,
+        force_master_ip=force_master_ip,
+        sentinel_kwargs=sentinel_kwargs,
+        socket_timeout=0.1,
+        protocol=protocol,
+        decode_responses=decode_responses,
+        **kwargs,
+    )
+    yield sentinel
+    for s in sentinel.sentinels:
+        s.close()
+
+
 @pytest.mark.onlynoncluster
 def test_discover_master(sentinel, master_ip):
     address = sentinel.discover_master("mymaster")
@@ -184,7 +213,7 @@ def test_discover_slaves(cluster, sentinel):
 
 
 @pytest.mark.onlynoncluster
-def test_master_for(cluster, sentinel, master_ip):
+def test_master_for(sentinel, master_ip):
     master = sentinel.master_for("mymaster", db=9)
     assert master.ping()
     assert master.connection_pool.master_address == (master_ip, 6379)
@@ -228,19 +257,22 @@ def test_slave_round_robin(cluster, sentinel, master_ip):
 
 
 @pytest.mark.onlynoncluster
-def test_ckquorum(cluster, sentinel):
-    assert sentinel.sentinel_ckquorum("mymaster")
+def test_ckquorum(sentinel):
+    resp = sentinel.sentinel_ckquorum("mymaster")
+    assert resp is True
 
 
 @pytest.mark.onlynoncluster
-def test_flushconfig(cluster, sentinel):
-    assert sentinel.sentinel_flushconfig()
+def test_flushconfig(sentinel):
+    resp = sentinel.sentinel_flushconfig()
+    assert resp is True
 
 
 @pytest.mark.onlynoncluster
 def test_reset(cluster, sentinel):
     cluster.master["is_odown"] = True
-    assert sentinel.sentinel_reset("mymaster")
+    resp = sentinel.sentinel_reset("mymaster")
+    assert resp is True
 
 
 @pytest.mark.onlynoncluster
@@ -266,3 +298,47 @@ def mock_disconnect():
 
     assert calls == 1
     pool.disconnect()
+
+
+# Tests against real sentinel instances
+@pytest.mark.onlynoncluster
+def test_get_sentinels(deployed_sentinel):
+    resps = deployed_sentinel.sentinel_sentinels("redis-py-test", return_responses=True)
+
+    # validate that the original command response is returned
+    assert isinstance(resps, list)
+
+    # validate that the command has been executed against all sentinels
+    # each response from each sentinel is returned
+    assert len(resps) > 1
+
+    # validate default behavior
+    resps = deployed_sentinel.sentinel_sentinels("redis-py-test")
+    assert isinstance(resps, bool)
+
+
+@pytest.mark.onlynoncluster
+def test_get_master_addr_by_name(deployed_sentinel):
+    resps = deployed_sentinel.sentinel_get_master_addr_by_name(
+        "redis-py-test", return_responses=True
+    )
+
+    # validate that the original command response is returned
+    assert isinstance(resps, list)
+
+    # validate that the command has been executed just once
+    # when executed once, only one response element is returned
+    assert len(resps) == 1
+
+    assert isinstance(resps[0], tuple)
+
+    # validate default behavior
+    resps = deployed_sentinel.sentinel_get_master_addr_by_name("redis-py-test")
+    assert isinstance(resps, bool)
+
+
+@pytest.mark.onlynoncluster
+def test_redis_master_usage(deployed_sentinel):
+    r = deployed_sentinel.master_for("redis-py-test", db=0)
+    r.set("foo", "bar")
+    assert r.get("foo") == "bar"
diff --git a/tests/test_sentinel_managed_connection.py b/tests/test_sentinel_managed_connection.py
new file mode 100644
index 0000000000..6fe5f7cd5b
--- /dev/null
+++ b/tests/test_sentinel_managed_connection.py
@@ -0,0 +1,34 @@
+import socket
+
+from redis.retry import Retry
+from redis.sentinel import SentinelManagedConnection
+from redis.backoff import NoBackoff
+from unittest import mock
+
+
+def test_connect_retry_on_timeout_error(master_host):
+    """Test that the _connect function is retried in case of a timeout"""
+    connection_pool = mock.Mock()
+    connection_pool.get_master_address = mock.Mock(
+        return_value=(master_host[0], master_host[1])
+    )
+    conn = SentinelManagedConnection(
+        retry_on_timeout=True,
+        retry=Retry(NoBackoff(), 3),
+        connection_pool=connection_pool,
+    )
+    origin_connect = conn._connect
+    conn._connect = mock.Mock()
+
+    def mock_connect():
+        # connect only on the last retry
+        if conn._connect.call_count <= 2:
+            raise socket.timeout
+        else:
+            return origin_connect()
+
+    conn._connect.side_effect = mock_connect
+    conn.connect()
+    assert conn._connect.call_count == 3
+    assert connection_pool.get_master_address.call_count == 3
+    conn.disconnect()