From bb468cbcaf3912b0093178036505587ac0f0158b Mon Sep 17 00:00:00 2001 From: Heikki Linnakangas Date: Mon, 9 Nov 2015 16:03:47 +0200 Subject: [PATCH] Improve parsing of INSERT INTO statements to support @@identity Patch by Christian Ullrich, some extra comments and regression test by me. --- execute.c | 112 ++++++++++++++++++++++++++----------- test/expected/identity.out | 17 ++++++ test/src/identity-test.c | 90 +++++++++++++++++++++++++++++ test/tests | 1 + 4 files changed, 188 insertions(+), 32 deletions(-) create mode 100644 test/expected/identity.out create mode 100644 test/src/identity-test.c diff --git a/execute.c b/execute.c index 2eeffce..111230b 100644 --- a/execute.c +++ b/execute.c @@ -33,6 +33,8 @@ #include "lobj.h" #include "pgapifunc.h" +static const char *next_name_token(const char *s, size_t *len); + /*extern GLOBAL_VALUES globals;*/ @@ -732,16 +734,22 @@ cleanup: return ret; } +/* + * Given a SQL statement, see if it is an INSERT INTO statement and extract + * the name of the table (with schema) of the table that was inserted to. + * (It is needed to resolve any @@identity references in the future.) + */ void SC_setInsertedTable(StatementClass *stmt, RETCODE retval) { const char *cmd = stmt->statement, *ptr; ConnectionClass *conn; size_t len; + const char *token = NULL; if (STMT_TYPE_INSERT != stmt->statement_type) return; - if (SQL_NEED_DATA == retval) + if (!SQL_SUCCEEDED(retval)) return; conn = SC_get_conn(stmt); #ifdef NOT_USED /* give up the use of lastval() */ @@ -750,6 +758,15 @@ SC_setInsertedTable(StatementClass *stmt, RETCODE retval) #endif /* NOT_USED */ /*if (!CC_fake_mss(conn)) return;*/ + + /* + * Parse a statement that was just executed. If it looks like an INSERT INTO + * statement, try to extract the table name (and schema) of the table that + * we inserted into. + * + * This is by no means fool-proof, we don't implement the whole backend + * lexer and grammar here, but should handle most simple INSERT statements. + */ while (isspace((UCHAR) *cmd)) cmd++; if (!*cmd) return; @@ -769,41 +786,72 @@ SC_setInsertedTable(StatementClass *stmt, RETCODE retval) return; NULL_THE_NAME(conn->schemaIns); NULL_THE_NAME(conn->tableIns); - if (!SQL_SUCCEEDED(retval)) - return; - while (TRUE) - { - if (IDENTIFIER_QUOTE == *cmd) - { - if (ptr = strchr(cmd + 1, IDENTIFIER_QUOTE), NULL == ptr) /* syntax error */ - { - NULL_THE_NAME(conn->schemaIns); - NULL_THE_NAME(conn->tableIns); - break; - } - len = ptr - cmd - 1; - cmd++; - ptr++; - } - else - { - if (ptr = strchr(cmd + 1, '.'), NULL != ptr) - len = ptr - cmd; + + len = 0; + token = next_name_token(cmd, &len); + if (token && *token == IDENTIFIER_QUOTE) + STRN_TO_NAME(conn->tableIns, token + 1, len - 2); + else + STRN_TO_NAME(conn->tableIns, token, len); + token = next_name_token(token, &len); + if (token && *token == '.') + { + token = next_name_token(token, &len); + if (token) { + if (NAME_IS_VALID(conn->tableIns)) + MOVE_NAME(conn->schemaIns, conn->tableIns); + if (*token == IDENTIFIER_QUOTE) + STRN_TO_NAME(conn->tableIns, token + 1, len - 2); else - { - ptr = cmd; - while (*ptr && !isspace((UCHAR) *ptr)) ptr++; - len = ptr - cmd; - } + STRN_TO_NAME(conn->tableIns, token, len); } - if (NAME_IS_VALID(conn->tableIns)) - MOVE_NAME(conn->schemaIns, conn->tableIns); - STRN_TO_NAME(conn->tableIns, cmd, len); - if ('.' == *ptr) - cmd = ptr + 1; - else + } + + if (!NAME_IS_VALID(conn->tableIns)) + NULL_THE_NAME(conn->schemaIns); +} + +/* + * Returns the next token from a qualified or unqualified name. + * + * s is the previous token. On entry, *len is the length of the previous + * token; on return, it is the length of the next one. If the token is + * quoted, the quotes are included in the result. If no valid token is found, + * NULL is returned. + */ +static const char * +next_name_token(const char *s, size_t *len) +{ + const char *p; + + s += *len; + while (*s && isspace(*s)) ++s; + + switch (*s) { + case '\0': + break; + case '.': + *len = 1; + return s; + case IDENTIFIER_QUOTE: + p = strchr(s + 1, IDENTIFIER_QUOTE); + if (p) { + *len = p - s + 1; + return s; + } + break; + default: + p = s; + while (*p && !isspace(*p) && *p != '.') ++p; + if (p) { + *len = p - s; + return s; + } break; } + + *len = 0; + return NULL; } /* Execute a prepared SQL statement */ diff --git a/test/expected/identity.out b/test/expected/identity.out new file mode 100644 index 0000000..fe3369b --- /dev/null +++ b/test/expected/identity.out @@ -0,0 +1,17 @@ +connected +# of rows inserted: 1 +Result set: +1 +Result set: +1 simple +# of rows inserted: 1 +Result set: +1 +Result set: +1 value.with.dots +# of rows inserted: 1 +Result set: +1 +Result set: +1 space in table name +disconnecting diff --git a/test/src/identity-test.c b/test/src/identity-test.c new file mode 100644 index 0000000..642e9fc --- /dev/null +++ b/test/src/identity-test.c @@ -0,0 +1,90 @@ +/* + * Test @@identity + */ +#include +#include + +#include "common.h" + +static HSTMT hstmt = SQL_NULL_HSTMT; + +/* + * The driver has to parse the insert statement to extract the table name, + * so it might or might not work depending on the table name. This function + * tests it with a table whose name is given as argument. + * + * NOTE: The table name is injected directly into the SQL statement. It should + * already contain any escaping or quoting needed! + */ +void +test_identity_col(const char *createsql, const char *insertsql, const char *verifysql) +{ + SQLRETURN rc; + char sql[1000]; + SQLLEN rowcount; + + /* Create the test table */ + rc = SQLExecDirect(hstmt, (SQLCHAR *) createsql, SQL_NTS); + CHECK_STMT_RESULT(rc, "SQLExecDirect failed while creating temp table", hstmt); + + /* Insert some rows using the statement given by the caller */ + rc = SQLExecDirect(hstmt, (SQLCHAR *) insertsql, SQL_NTS); + CHECK_STMT_RESULT(rc, "SQLExecDirect failed", hstmt); + + rc = SQLRowCount(hstmt, &rowcount); + CHECK_STMT_RESULT(rc, "SQLRowCount failed", hstmt); + printf("# of rows inserted: %d\n", (int) rowcount); + + rc = SQLExecDirect(hstmt, (SQLCHAR *) "SELECT @@IDENTITY AS last_insert", SQL_NTS); + CHECK_STMT_RESULT(rc, "SQLExecDirect failed", hstmt); + + /* Fetch result */ + print_result(hstmt); + + rc = SQLFreeStmt(hstmt, SQL_CLOSE); + CHECK_STMT_RESULT(rc, "SQLFreeStmt failed", hstmt); + + /* Print the contents of the table after the updates to verify */ + rc = SQLExecDirect(hstmt, (SQLCHAR *) verifysql, SQL_NTS); + CHECK_STMT_RESULT(rc, "SQLExecDirect failed", hstmt); + print_result(hstmt); + + rc = SQLFreeStmt(hstmt, SQL_CLOSE); + CHECK_STMT_RESULT(rc, "SQLFreeStmt failed", hstmt); +} + + +int main(int argc, char **argv) +{ + SQLRETURN rc; + SQLCHAR *sql; + + test_connect(); + + rc = SQLAllocHandle(SQL_HANDLE_STMT, conn, &hstmt); + if (!SQL_SUCCEEDED(rc)) + { + print_diag("failed to allocate stmt handle", SQL_HANDLE_DBC, conn); + exit(1); + } + + test_identity_col( + "CREATE TEMPORARY TABLE tmptable (id serial primary key, t text)", + "INSERT INTO tmptable (t) VALUES ('simple')", + "SELECT * FROM tmptable"); + + test_identity_col( + "CREATE TEMPORARY TABLE tmptable2 (id serial primary key, t text)", + "INSERT INTO tmptable2 (t) VALUES ('value.with.dots')", + "SELECT * FROM tmptable2"); + + test_identity_col( + "CREATE TEMPORARY TABLE \"tmp table\" (id serial primary key, t text)", + "INSERT INTO \"tmp table\" (t) VALUES ('space in table name')", + "SELECT * FROM \"tmp table\""); + + /* Clean up */ + test_disconnect(); + + return 0; +} diff --git a/test/tests b/test/tests index e908771..eb5fbe9 100644 --- a/test/tests +++ b/test/tests @@ -21,6 +21,7 @@ TESTBINS = src/connect-test \ src/params-test \ src/param-conversions-test \ src/parse-test \ + src/identity-test \ src/notice-test \ src/arraybinding-test \ src/insertreturning-test \ -- 2.39.5