Skip to content

Commit 9432c1d

Browse files
authored
fix: statement splitter (#496)
fixes #494
1 parent f02b57e commit 9432c1d

File tree

5 files changed

+144
-0
lines changed

5 files changed

+144
-0
lines changed

crates/pgt_statement_splitter/src/lib.rs

Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,32 @@ mod tests {
6262
}
6363

6464
impl Tester {
65+
fn assert_single_statement(&self) -> &Self {
66+
assert_eq!(
67+
self.result.ranges.len(),
68+
1,
69+
"Expected a single statement for input {}, got {}: {:?}",
70+
self.input,
71+
self.result.ranges.len(),
72+
self.result
73+
.ranges
74+
.iter()
75+
.map(|r| &self.input[*r])
76+
.collect::<Vec<_>>()
77+
);
78+
self
79+
}
80+
81+
fn assert_no_errors(&self) -> &Self {
82+
assert!(
83+
self.result.errors.is_empty(),
84+
"Expected no errors, got {}: {:?}",
85+
self.result.errors.len(),
86+
self.result.errors
87+
);
88+
self
89+
}
90+
6591
fn expect_statements(&self, expected: Vec<&str>) -> &Self {
6692
assert_eq!(
6793
self.result.ranges.len(),
@@ -114,6 +140,16 @@ mod tests {
114140
);
115141
}
116142

143+
#[test]
144+
fn test_for_no_key_update() {
145+
Tester::from(
146+
"SELECT 1 FROM assessments AS a WHERE a.id = $assessment_id FOR NO KEY UPDATE;",
147+
)
148+
.expect_statements(vec![
149+
"SELECT 1 FROM assessments AS a WHERE a.id = $assessment_id FOR NO KEY UPDATE;",
150+
]);
151+
}
152+
117153
#[test]
118154
fn test_crash_eof() {
119155
Tester::from("CREATE INDEX \"idx_analytics_read_ratio\" ON \"public\".\"message\" USING \"btree\" (\"inbox_id\", \"timestamp\") INCLUDE (\"status\") WHERE (\"is_inbound\" = false and channel_type not in ('postal'', 'sms'));")
@@ -256,6 +292,55 @@ mod tests {
256292
]);
257293
}
258294

295+
#[test]
296+
fn with_recursive() {
297+
Tester::from(
298+
"
299+
WITH RECURSIVE
300+
template_questions AS (
301+
-- non-recursive term that finds the ID of the template question (if any) for question_id
302+
SELECT
303+
tq.id,
304+
tq.qid,
305+
tq.course_id,
306+
tq.template_directory
307+
FROM
308+
questions AS q
309+
JOIN questions AS tq ON (
310+
tq.qid = q.template_directory
311+
AND tq.course_id = q.course_id
312+
)
313+
WHERE
314+
q.id = $question_id
315+
AND tq.deleted_at IS NULL
316+
-- required UNION for a recursive WITH statement
317+
UNION
318+
-- recursive term that references template_questions again
319+
SELECT
320+
tq.id,
321+
tq.qid,
322+
tq.course_id,
323+
tq.template_directory
324+
FROM
325+
template_questions AS q
326+
JOIN questions AS tq ON (
327+
tq.qid = q.template_directory
328+
AND tq.course_id = q.course_id
329+
)
330+
WHERE
331+
tq.deleted_at IS NULL
332+
)
333+
SELECT
334+
id
335+
FROM
336+
template_questions
337+
LIMIT
338+
100;",
339+
)
340+
.assert_single_statement()
341+
.assert_no_errors();
342+
}
343+
259344
#[test]
260345
fn with_check() {
261346
Tester::from("create policy employee_insert on journey_execution for insert to authenticated with check ((select private.organisation_id()) = organisation_id);")

crates/pgt_statement_splitter/src/splitter.rs

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -102,6 +102,15 @@ impl<'a> Splitter<'a> {
102102
self.lexed.kind(self.current_pos)
103103
}
104104

105+
fn eat(&mut self, kind: SyntaxKind) -> bool {
106+
if self.current() == kind {
107+
self.advance();
108+
true
109+
} else {
110+
false
111+
}
112+
}
113+
105114
fn kind(&self, idx: usize) -> SyntaxKind {
106115
self.lexed.kind(idx)
107116
}

crates/pgt_statement_splitter/src/splitter/common.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -224,6 +224,8 @@ pub(crate) fn unknown(p: &mut Splitter, exclude: &[SyntaxKind]) {
224224
SyntaxKind::COMMA,
225225
// Do update in INSERT stmt
226226
SyntaxKind::DO_KW,
227+
// FOR NO KEY UPDATE
228+
SyntaxKind::KEY_KW,
227229
]
228230
.iter()
229231
.all(|x| Some(x) != prev.as_ref())

crates/pgt_statement_splitter/src/splitter/dml.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ use super::{
77

88
pub(crate) fn cte(p: &mut Splitter) {
99
p.expect(SyntaxKind::WITH_KW);
10+
p.eat(SyntaxKind::RECURSIVE_KW);
1011

1112
loop {
1213
p.expect(SyntaxKind::IDENT);

crates/pgt_workspace/src/workspace/server.tests.rs

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -579,6 +579,53 @@ async fn test_disable_typecheck(test_db: PgPool) {
579579
);
580580
}
581581

582+
#[sqlx::test(migrator = "pgt_test_utils::MIGRATIONS")]
583+
async fn test_named_params(_test_db: PgPool) {
584+
let conf = PartialConfiguration::init();
585+
586+
let workspace = get_test_workspace(Some(conf)).expect("Unable to create test workspace");
587+
588+
let path = PgTPath::new("test.sql");
589+
590+
let content = r#"
591+
SELECT
592+
1
593+
FROM
594+
assessments AS a
595+
WHERE
596+
a.id = $assessment_id
597+
FOR NO KEY UPDATE;
598+
"#;
599+
600+
workspace
601+
.open_file(OpenFileParams {
602+
path: path.clone(),
603+
content: content.into(),
604+
version: 1,
605+
})
606+
.expect("Unable to open test file");
607+
608+
let diagnostics = workspace
609+
.pull_diagnostics(crate::workspace::PullDiagnosticsParams {
610+
path: path.clone(),
611+
categories: RuleCategories::all(),
612+
max_diagnostics: 100,
613+
only: vec![],
614+
skip: vec![],
615+
})
616+
.expect("Unable to pull diagnostics")
617+
.diagnostics;
618+
619+
assert_eq!(
620+
diagnostics
621+
.iter()
622+
.filter(|d| d.category().is_some_and(|c| c.name() == "syntax"))
623+
.count(),
624+
0,
625+
"Expected no syntax diagnostic"
626+
);
627+
}
628+
582629
#[sqlx::test(migrator = "pgt_test_utils::MIGRATIONS")]
583630
async fn test_cstyle_comments(test_db: PgPool) {
584631
let mut conf = PartialConfiguration::init();

0 commit comments

Comments
 (0)