Skip to content

Commit af18ae4

Browse files
committed
fix: statement splitter
1 parent d03c295 commit af18ae4

File tree

5 files changed

+132
-11
lines changed

5 files changed

+132
-11
lines changed

crates/pgt_statement_splitter/src/lib.rs

Lines changed: 80 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,32 @@ mod tests {
6262
}
6363

6464
impl Tester {
65+
fn assert_single_statement(&self) -> &Self {
66+
assert_eq!(
67+
self.result.ranges.len(),
68+
1,
69+
"Expected a single statement for input {}, got {}: {:?}",
70+
self.input,
71+
self.result.ranges.len(),
72+
self.result
73+
.ranges
74+
.iter()
75+
.map(|r| &self.input[*r])
76+
.collect::<Vec<_>>()
77+
);
78+
self
79+
}
80+
81+
fn assert_no_errors(&self) -> &Self {
82+
assert!(
83+
self.result.errors.is_empty(),
84+
"Expected no errors, got {}: {:?}",
85+
self.result.errors.len(),
86+
self.result.errors
87+
);
88+
self
89+
}
90+
6591
fn expect_statements(&self, expected: Vec<&str>) -> &Self {
6692
assert_eq!(
6793
self.result.ranges.len(),
@@ -114,6 +140,16 @@ mod tests {
114140
);
115141
}
116142

143+
#[test]
144+
fn test_for_no_key_update() {
145+
Tester::from(
146+
"SELECT 1 FROM assessments AS a WHERE a.id = $assessment_id FOR NO KEY UPDATE;",
147+
)
148+
.expect_statements(vec![
149+
"SELECT 1 FROM assessments AS a WHERE a.id = $assessment_id FOR NO KEY UPDATE;",
150+
]);
151+
}
152+
117153
#[test]
118154
fn test_crash_eof() {
119155
Tester::from("CREATE INDEX \"idx_analytics_read_ratio\" ON \"public\".\"message\" USING \"btree\" (\"inbox_id\", \"timestamp\") INCLUDE (\"status\") WHERE (\"is_inbound\" = false and channel_type not in ('postal'', 'sms'));")
@@ -241,19 +277,52 @@ mod tests {
241277
}
242278

243279
#[test]
244-
fn trigger_instead_of() {
280+
fn with_recursive() {
245281
Tester::from(
246-
"CREATE OR REPLACE TRIGGER my_trigger
247-
INSTEAD OF INSERT ON my_table
248-
FOR EACH ROW
249-
EXECUTE FUNCTION my_table_trigger_fn();",
282+
"
283+
WITH RECURSIVE
284+
template_questions AS (
285+
-- non-recursive term that finds the ID of the template question (if any) for question_id
286+
SELECT
287+
tq.id,
288+
tq.qid,
289+
tq.course_id,
290+
tq.template_directory
291+
FROM
292+
questions AS q
293+
JOIN questions AS tq ON (
294+
tq.qid = q.template_directory
295+
AND tq.course_id = q.course_id
296+
)
297+
WHERE
298+
q.id = $question_id
299+
AND tq.deleted_at IS NULL
300+
-- required UNION for a recursive WITH statement
301+
UNION
302+
-- recursive term that references template_questions again
303+
SELECT
304+
tq.id,
305+
tq.qid,
306+
tq.course_id,
307+
tq.template_directory
308+
FROM
309+
template_questions AS q
310+
JOIN questions AS tq ON (
311+
tq.qid = q.template_directory
312+
AND tq.course_id = q.course_id
313+
)
314+
WHERE
315+
tq.deleted_at IS NULL
316+
)
317+
SELECT
318+
id
319+
FROM
320+
template_questions
321+
LIMIT
322+
100;",
250323
)
251-
.expect_statements(vec![
252-
"CREATE OR REPLACE TRIGGER my_trigger
253-
INSTEAD OF INSERT ON my_table
254-
FOR EACH ROW
255-
EXECUTE FUNCTION my_table_trigger_fn();",
256-
]);
324+
.assert_single_statement()
325+
.assert_no_errors();
257326
}
258327

259328
#[test]

crates/pgt_statement_splitter/src/splitter.rs

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -102,6 +102,15 @@ impl<'a> Splitter<'a> {
102102
self.lexed.kind(self.current_pos)
103103
}
104104

105+
fn eat(&mut self, kind: SyntaxKind) -> bool {
106+
if self.current() == kind {
107+
self.advance();
108+
true
109+
} else {
110+
false
111+
}
112+
}
113+
105114
fn kind(&self, idx: usize) -> SyntaxKind {
106115
self.lexed.kind(idx)
107116
}

crates/pgt_statement_splitter/src/splitter/common.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -224,6 +224,8 @@ pub(crate) fn unknown(p: &mut Splitter, exclude: &[SyntaxKind]) {
224224
SyntaxKind::COMMA,
225225
// Do update in INSERT stmt
226226
SyntaxKind::DO_KW,
227+
// FOR NO KEY UPDATE
228+
SyntaxKind::KEY_KW,
227229
]
228230
.iter()
229231
.all(|x| Some(x) != prev.as_ref())

crates/pgt_statement_splitter/src/splitter/dml.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ use super::{
77

88
pub(crate) fn cte(p: &mut Splitter) {
99
p.expect(SyntaxKind::WITH_KW);
10+
p.eat(SyntaxKind::RECURSIVE_KW);
1011

1112
loop {
1213
p.expect(SyntaxKind::IDENT);

crates/pgt_workspace/src/workspace/server.tests.rs

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -386,6 +386,46 @@ async fn test_positional_params(test_db: PgPool) {
386386
assert_eq!(diagnostics.len(), 0, "Expected no diagnostic");
387387
}
388388

389+
#[sqlx::test(migrator = "pgt_test_utils::MIGRATIONS")]
390+
async fn test_named_params(_test_db: PgPool) {
391+
let conf = PartialConfiguration::init();
392+
393+
let workspace = get_test_workspace(Some(conf)).expect("Unable to create test workspace");
394+
395+
let path = PgTPath::new("test.sql");
396+
397+
let content = r#"
398+
SELECT
399+
1
400+
FROM
401+
assessments AS a
402+
WHERE
403+
a.id = $assessment_id
404+
FOR NO KEY UPDATE;
405+
"#;
406+
407+
workspace
408+
.open_file(OpenFileParams {
409+
path: path.clone(),
410+
content: content.into(),
411+
version: 1,
412+
})
413+
.expect("Unable to open test file");
414+
415+
let diagnostics = workspace
416+
.pull_diagnostics(crate::workspace::PullDiagnosticsParams {
417+
path: path.clone(),
418+
categories: RuleCategories::all(),
419+
max_diagnostics: 100,
420+
only: vec![],
421+
skip: vec![],
422+
})
423+
.expect("Unable to pull diagnostics")
424+
.diagnostics;
425+
426+
assert_eq!(diagnostics.len(), 0, "Expected no diagnostic");
427+
}
428+
389429
#[sqlx::test(migrator = "pgt_test_utils::MIGRATIONS")]
390430
async fn test_cstyle_comments(test_db: PgPool) {
391431
let mut conf = PartialConfiguration::init();

0 commit comments

Comments
 (0)