Skip to content

Commit 583aa98

Browse files
feat: Enhance query handling with IN operator support and related tests
- Updated `CrudRepositoryImplementationService` to support IN operations in query filtering, including validation for parameter types and handling of empty lists. - Modified `_Query` class to include `field_operations` for mapping field names to their respective operations. - Enhanced `_MetodQueryBuilder` to parse method names with IN operations and generate appropriate query structures. - Added comprehensive unit tests for IN operator functionality, including cases for single fields, multiple fields with AND/OR conditions, and error handling for invalid types and empty lists. - Updated existing tests to reflect changes in query structure and ensure correct SQL statement generation.
1 parent 926398e commit 583aa98

File tree

4 files changed

+157
-15
lines changed

4 files changed

+157
-15
lines changed

py_spring_model/py_spring_model_rest/service/curd_repository_implementation_service/crud_repository_implementation_service.py

Lines changed: 28 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -133,10 +133,34 @@ def _get_sql_statement(
133133
parsed_query: _Query,
134134
params: dict[str, Any],
135135
) -> SelectOfScalar[PySpringModelT]:
136-
filter_condition_stack: list[ColumnElement[bool]] = [
137-
getattr(model_type, field) == params[field]
138-
for field in parsed_query.required_fields
139-
]
136+
filter_condition_stack: list[ColumnElement[bool]] = []
137+
138+
for field in parsed_query.required_fields:
139+
column = getattr(model_type, field)
140+
param_value = params[field]
141+
142+
# Check if this field has a specific operation
143+
if field in parsed_query.field_operations:
144+
operation = parsed_query.field_operations[field]
145+
if operation == "in":
146+
# Handle IN operation
147+
if not isinstance(param_value, (list, tuple, set)):
148+
raise ValueError(f"Parameter for IN operation must be a collection (list, tuple, or set), got {type(param_value)}")
149+
150+
# Handle empty list case - return no results
151+
if len(param_value) == 0:
152+
# Create a condition that's always false
153+
filter_condition_stack.append(column == None)
154+
continue
155+
156+
filter_condition_stack.append(column.in_(param_value))
157+
else:
158+
# Default to equality for unknown operations
159+
filter_condition_stack.append(column == param_value)
160+
else:
161+
# Default equality operation
162+
filter_condition_stack.append(column == param_value)
163+
140164
for notation in parsed_query.notations:
141165
right_condition = filter_condition_stack.pop(0)
142166
left_condition = filter_condition_stack.pop(0)

py_spring_model/py_spring_model_rest/service/curd_repository_implementation_service/method_query_builder.py

Lines changed: 20 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import re
2+
from typing import Dict, Any
23

34
from pydantic import BaseModel
45

@@ -9,12 +10,14 @@ class _Query(BaseModel):
910
- `conditions`: A list of string conditions that will be used to filter the query.
1011
- `is_one_result`: A boolean indicating whether the query should return a single result or a list of results.
1112
- `required_fields`: A list of string field names that should be included in the query result.
13+
- `field_operations`: A dictionary mapping field names to their operations (e.g., "in" for IN operator).
1214
"""
1315

1416
raw_query_list: list[str]
1517
is_one_result: bool
1618
notations: list[str]
1719
required_fields: list[str]
20+
field_operations: Dict[str, str] = {}
1821

1922

2023
class _MetodQueryBuilder:
@@ -36,6 +39,7 @@ def parse_query(self) -> _Query:
3639
Example:
3740
- 'find_by_name_and_age' -> Query(raw_query_list=['name', '_and_', 'age'], is_one_result=True, required_fields=['name', 'age'])
3841
- 'find_all_by_name_or_age' -> Query(raw_query_list=['name', '_or_', 'age'], is_one_result=False, required_fields=['name', 'age'])
42+
- 'find_by_status_in' -> Query(raw_query_list=['status'], is_one_result=True, required_fields=['status'], field_operations={'status': 'in'})
3943
"""
4044
is_one = False
4145
pattern = ""
@@ -53,7 +57,6 @@ def parse_query(self) -> _Query:
5357
if len(pattern) == 0:
5458
raise ValueError(f"Method name must start with 'get_by', 'find_by', 'find_all_by', or 'get_all_by': {self.method_name}")
5559

56-
5760
match = re.match(pattern, self.method_name)
5861
if not match:
5962
raise ValueError(f"Invalid method name: {self.method_name}")
@@ -62,13 +65,26 @@ def parse_query(self) -> _Query:
6265
# Split fields by '_and_' or '_or_' and keep logical operators
6366
raw_query_list = re.split(r"(_and_|_or_)", raw_query)
6467

68+
# Extract required fields and detect operations
69+
required_fields = []
70+
field_operations = {}
71+
72+
for field in raw_query_list:
73+
if field not in ["_and_", "_or_"]:
74+
# Check for IN operation
75+
if field.endswith("_in"):
76+
base_field = field[:-3] # Remove "_in" suffix
77+
required_fields.append(base_field)
78+
field_operations[base_field] = "in"
79+
else:
80+
required_fields.append(field)
81+
6582
return _Query(
6683
raw_query_list=raw_query_list,
6784
is_one_result=is_one,
68-
required_fields=[
69-
field for field in raw_query_list if field not in ["_and_", "_or_"]
70-
],
85+
required_fields=required_fields,
7186
notations=[
7287
notation for notation in raw_query_list if notation in ["_and_", "_or_"]
7388
],
89+
field_operations=field_operations,
7490
)

tests/test_crud_repository_implementation_service.py

Lines changed: 70 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -15,12 +15,17 @@ class User(PySpringModel, table=True):
1515
id: int = Field(default=None, primary_key=True)
1616
name: str
1717
email: str
18+
status: str = Field(default="active")
19+
category: str = Field(default="general")
1820

1921
class UserView(BaseModel):
2022
name: str
2123

2224
class UserRepository(CrudRepository[int,User]):
2325
def find_by_name(self, name: str) -> User: ...
26+
def find_all_by_status_in(self, status: list[str]) -> list[User]: ...
27+
def find_all_by_id_in_and_name(self, id: list[int], name: str) -> list[User]: ...
28+
def find_all_by_status_in_or_category_in(self, status: list[str], category: list[str]) -> list[User]: ...
2429
@Query("SELECT * FROM user WHERE name = {name}")
2530
def query_uery_by_name(self, name: str) -> User: ...
2631

@@ -53,17 +58,17 @@ def implementation_service(self) -> CrudRepositoryImplementationService:
5358
def test_query_single_annotation(self, implementation_service: CrudRepositoryImplementationService):
5459
parsed_query = _MetodQueryBuilder("find_by_name").parse_query()
5560
statement = implementation_service._get_sql_statement(User, parsed_query, {"name": "John Doe"})
56-
assert str(statement).replace("\n", "") == 'SELECT "user".id, "user".name, "user".email FROM "user" WHERE "user".name = :name_1'
61+
assert str(statement).replace("\n", "") == 'SELECT "user".id, "user".name, "user".email, "user".status, "user".category FROM "user" WHERE "user".name = :name_1'
5762

5863
def test_query_and_annotation(self, implementation_service: CrudRepositoryImplementationService):
5964
parsed_query = _MetodQueryBuilder("find_by_name_and_email").parse_query()
6065
statement = implementation_service._get_sql_statement(User, parsed_query, {"name": "John Doe", "email": "john@example.com"})
61-
assert str(statement).replace("\n", "") == 'SELECT "user".id, "user".name, "user".email FROM "user" WHERE "user".email = :email_1 AND "user".name = :name_1'
66+
assert str(statement).replace("\n", "") == 'SELECT "user".id, "user".name, "user".email, "user".status, "user".category FROM "user" WHERE "user".email = :email_1 AND "user".name = :name_1'
6267

6368
def test_query_or_annotation(self, implementation_service: CrudRepositoryImplementationService):
6469
parsed_query = _MetodQueryBuilder("find_by_name_or_email").parse_query()
6570
statement = implementation_service._get_sql_statement(User, parsed_query, {"name": "John Doe", "email": "john@example.com"})
66-
assert str(statement).replace("\n", "") == 'SELECT "user".id, "user".name, "user".email FROM "user" WHERE "user".email = :email_1 OR "user".name = :name_1'
71+
assert str(statement).replace("\n", "") == 'SELECT "user".id, "user".name, "user".email, "user".status, "user".category FROM "user" WHERE "user".email = :email_1 OR "user".name = :name_1'
6772

6873
def test_did_implement_query(self, user_repository: UserRepository, implementation_service: CrudRepositoryImplementationService):
6974
user = User(name="John Doe", email="john@example.com")
@@ -103,9 +108,68 @@ def test_query_user_view_by_name_invalid_argument_type(self, user_repository: Us
103108
with pytest.raises(ValueError, match=".*"):
104109
user_repository.query_user_view_by_name(name=None) # `name` should not be None # type: ignore
105110

111+
def test_in_operator_single_field(self, implementation_service: CrudRepositoryImplementationService):
112+
parsed_query = _MetodQueryBuilder("find_by_status_in").parse_query()
113+
statement = implementation_service._get_sql_statement(User, parsed_query, {"status": ["active", "pending"]})
114+
assert "IN" in str(statement).upper()
115+
assert "status" in str(statement).lower()
116+
117+
def test_in_operator_with_and(self, implementation_service: CrudRepositoryImplementationService):
118+
parsed_query = _MetodQueryBuilder("find_by_id_in_and_name").parse_query()
119+
statement = implementation_service._get_sql_statement(User, parsed_query, {"id": [1, 2, 3], "name": "John"})
120+
assert "IN" in str(statement).upper()
121+
assert "AND" in str(statement).upper()
122+
123+
def test_in_operator_with_or(self, implementation_service: CrudRepositoryImplementationService):
124+
parsed_query = _MetodQueryBuilder("find_by_status_in_or_category_in").parse_query()
125+
statement = implementation_service._get_sql_statement(User, parsed_query, {"status": ["active"], "category": ["premium"]})
126+
assert "IN" in str(statement).upper()
127+
assert "OR" in str(statement).upper()
128+
129+
def test_in_operator_empty_list(self, implementation_service: CrudRepositoryImplementationService):
130+
parsed_query = _MetodQueryBuilder("find_by_status_in").parse_query()
131+
statement = implementation_service._get_sql_statement(User, parsed_query, {"status": []})
132+
# Empty list should result in a condition that's always false
133+
assert "IS NULL" in str(statement) or "= NULL" in str(statement)
134+
135+
def test_in_operator_invalid_type(self, implementation_service: CrudRepositoryImplementationService):
136+
parsed_query = _MetodQueryBuilder("find_by_status_in").parse_query()
137+
with pytest.raises(ValueError, match="Parameter for IN operation must be a collection"):
138+
implementation_service._get_sql_statement(User, parsed_query, {"status": "not_a_list"})
139+
140+
def test_in_operator_implementation(self, user_repository: UserRepository, implementation_service: CrudRepositoryImplementationService):
141+
# Create test users
142+
user1 = User(name="John", email="john@example.com", status="active", category="premium")
143+
user2 = User(name="Jane", email="jane@example.com", status="pending", category="premium")
144+
user3 = User(name="Bob", email="bob@example.com", status="active", category="basic")
106145

146+
user_repository.save(user1)
147+
user_repository.save(user2)
148+
user_repository.save(user3)
107149

108-
109-
110-
150+
# Implement the query
151+
implementation_service._implemenmt_query(user_repository.__class__)
152+
153+
# Test IN operator
154+
active_users = user_repository.find_all_by_status_in(status=["active"])
155+
assert len(active_users) == 2
156+
assert all(user.status == "active" for user in active_users)
157+
158+
# Test IN with AND
159+
premium_active_users = user_repository.find_all_by_id_in_and_name(id=[user1.id, user2.id], name="John")
160+
assert len(premium_active_users) == 1
161+
assert premium_active_users[0].name == "John"
162+
163+
# Test IN with OR
164+
active_or_premium = user_repository.find_all_by_status_in_or_category_in(status=["active"], category=["premium"])
165+
assert len(active_or_premium) == 3 # All users are either active or premium
111166

167+
def test_in_operator_empty_list_returns_no_results(self, user_repository: UserRepository, implementation_service: CrudRepositoryImplementationService):
168+
user = User(name="John", email="john@example.com", status="active")
169+
user_repository.save(user)
170+
171+
implementation_service._implemenmt_query(user_repository.__class__)
172+
173+
# Empty list should return no results
174+
results = user_repository.find_all_by_status_in(status=[])
175+
assert len(results) == 0

tests/test_method_query_builder.py

Lines changed: 39 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,35 +5,71 @@
55

66
class TestMetodQueryBuilder:
77
@pytest.mark.parametrize(
8-
"method_name, expected_raw_query_list, expected_is_one_result, expected_required_fields, expected_notations",
8+
"method_name, expected_raw_query_list, expected_is_one_result, expected_required_fields, expected_notations, expected_field_operations",
99
[
1010
(
1111
"get_by_name_and_age",
1212
["name", "_and_", "age"],
1313
True,
1414
["name", "age"],
1515
["_and_"],
16+
{},
1617
),
1718
(
1819
"find_by_name_or_age",
1920
["name", "_or_", "age"],
2021
True,
2122
["name", "age"],
2223
["_or_"],
24+
{},
2325
),
2426
(
2527
"find_all_by_name_and_age",
2628
["name", "_and_", "age"],
2729
False,
2830
["name", "age"],
2931
["_and_"],
32+
{},
3033
),
3134
(
3235
"get_all_by_city_or_country",
3336
["city", "_or_", "country"],
3437
False,
3538
["city", "country"],
3639
["_or_"],
40+
{},
41+
),
42+
(
43+
"find_by_status_in",
44+
["status_in"],
45+
True,
46+
["status"],
47+
[],
48+
{"status": "in"},
49+
),
50+
(
51+
"find_all_by_id_in",
52+
["id_in"],
53+
False,
54+
["id"],
55+
[],
56+
{"id": "in"},
57+
),
58+
(
59+
"find_by_status_in_and_name",
60+
["status_in", "_and_", "name"],
61+
True,
62+
["status", "name"],
63+
["_and_"],
64+
{"status": "in"},
65+
),
66+
(
67+
"find_by_status_in_or_category_in",
68+
["status_in", "_or_", "category_in"],
69+
True,
70+
["status", "category"],
71+
["_or_"],
72+
{"status": "in", "category": "in"},
3773
),
3874
],
3975
)
@@ -44,6 +80,7 @@ def test_parse_query(
4480
expected_is_one_result,
4581
expected_required_fields,
4682
expected_notations,
83+
expected_field_operations,
4784
):
4885
builder = _MetodQueryBuilder(method_name)
4986
query = builder.parse_query()
@@ -53,6 +90,7 @@ def test_parse_query(
5390
assert query.is_one_result == expected_is_one_result
5491
assert query.required_fields == expected_required_fields
5592
assert query.notations == expected_notations
93+
assert query.field_operations == expected_field_operations
5694

5795
def test_invalid_method_name(self):
5896
invalid_method_name = "invalid_method_name"

0 commit comments

Comments
 (0)