diff --git a/.github/workflows/codeql.yml b/.github/workflows/codeql.yml new file mode 100644 index 000000000..d9d29a8b7 --- /dev/null +++ b/.github/workflows/codeql.yml @@ -0,0 +1,41 @@ +name: "CodeQL" + +on: + push: + branches: [ "master" ] + pull_request: + branches: [ "master" ] + schedule: + - cron: "18 19 * * 1" + +jobs: + analyze: + name: Analyze + runs-on: ubuntu-latest + permissions: + actions: read + contents: read + security-events: write + + strategy: + fail-fast: false + matrix: + language: [ go ] + + steps: + - name: Checkout + uses: actions/checkout@v3 + + - name: Initialize CodeQL + uses: github/codeql-action/init@v2 + with: + languages: ${{ matrix.language }} + queries: +security-and-quality + + - name: Autobuild + uses: github/codeql-action/autobuild@v2 + + - name: Perform CodeQL Analysis + uses: github/codeql-action/analyze@v2 + with: + category: "/language:${{ matrix.language }}" diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml new file mode 100644 index 000000000..d45ed0fa9 --- /dev/null +++ b/.github/workflows/test.yml @@ -0,0 +1,110 @@ +name: test +on: + pull_request: + push: + workflow_dispatch: + +env: + MYSQL_TEST_USER: gotest + MYSQL_TEST_PASS: secret + MYSQL_TEST_ADDR: 127.0.0.1:3306 + MYSQL_TEST_CONCURRENT: 1 + +jobs: + list: + runs-on: ubuntu-latest + outputs: + matrix: ${{ steps.set-matrix.outputs.matrix }} + steps: + - name: list + id: set-matrix + run: | + import json + import os + go = [ + # Keep the most recent production release at the top + '1.20', + # Older production releases + '1.19', + '1.18', + '1.17', + '1.16', + '1.15', + '1.14', + '1.13', + ] + mysql = [ + '8.0', + '5.7', + '5.6', + 'mariadb-10.11', + 'mariadb-10.6', + 'mariadb-10.5', + 'mariadb-10.4', + 'mariadb-10.3', + ] + + includes = [] + # Go versions compatibility check + for v in go[1:]: + includes.append({'os': 'ubuntu-latest', 'go': v, 'mysql': mysql[0]}) + + matrix = { + # OS vs MySQL versions + 'os': [ 'ubuntu-latest', 'macos-latest', 'windows-latest' ], + 'go': [ go[0] ], + 'mysql': mysql, + + 'include': includes + } + output = json.dumps(matrix, separators=(',', ':')) + with open(os.environ["GITHUB_OUTPUT"], 'a', encoding="utf-8") as f: + f.write('matrix={0}\n'.format(output)) + shell: python + test: + needs: list + runs-on: ${{ matrix.os }} + strategy: + fail-fast: false + matrix: ${{ fromJSON(needs.list.outputs.matrix) }} + steps: + - uses: actions/checkout@v3 + - uses: actions/setup-go@v3 + with: + go-version: ${{ matrix.go }} + - uses: shogo82148/actions-setup-mysql@v1.15.0 + with: + mysql-version: ${{ matrix.mysql }} + user: ${{ env.MYSQL_TEST_USER }} + password: ${{ env.MYSQL_TEST_PASS }} + my-cnf: | + innodb_log_file_size=256MB + innodb_buffer_pool_size=512MB + max_allowed_packet=16MB + ; TestConcurrent fails if max_connections is too large + max_connections=50 + local_infile=1 + - name: setup database + run: | + mysql --user 'root' --host '127.0.0.1' -e 'create database gotest;' + + - name: test + run: | + go test -v '-covermode=count' '-coverprofile=coverage.out' + + - name: Send coverage + uses: shogo82148/actions-goveralls@v1 + with: + path-to-profile: coverage.out + flag-name: ${{ runner.os }}-Go-${{ matrix.go }}-DB-${{ matrix.mysql }} + parallel: true + + # notifies that all test jobs are finished. + finish: + needs: test + if: always() + runs-on: ubuntu-latest + steps: + - uses: shogo82148/actions-goveralls@v1 + with: + parallel-finished: true diff --git a/.travis.yml b/.travis.yml deleted file mode 100644 index 56fcf25f2..000000000 --- a/.travis.yml +++ /dev/null @@ -1,129 +0,0 @@ -sudo: false -language: go -go: - - 1.10.x - - 1.11.x - - 1.12.x - - 1.13.x - - master - -before_install: - - go get golang.org/x/tools/cmd/cover - - go get github.com/mattn/goveralls - -before_script: - - echo -e "[server]\ninnodb_log_file_size=256MB\ninnodb_buffer_pool_size=512MB\nmax_allowed_packet=16MB" | sudo tee -a /etc/mysql/my.cnf - - sudo service mysql restart - - .travis/wait_mysql.sh - - mysql -e 'create database gotest;' - -matrix: - include: - - env: DB=MYSQL8 - sudo: required - dist: trusty - go: 1.10.x - services: - - docker - before_install: - - go get golang.org/x/tools/cmd/cover - - go get github.com/mattn/goveralls - - docker pull mysql:8.0 - - docker run -d -p 127.0.0.1:3307:3306 --name mysqld -e MYSQL_DATABASE=gotest -e MYSQL_USER=gotest -e MYSQL_PASSWORD=secret -e MYSQL_ROOT_PASSWORD=verysecret - mysql:8.0 --innodb_log_file_size=256MB --innodb_buffer_pool_size=512MB --max_allowed_packet=16MB --local-infile=1 - - cp .travis/docker.cnf ~/.my.cnf - - .travis/wait_mysql.sh - before_script: - - export MYSQL_TEST_USER=gotest - - export MYSQL_TEST_PASS=secret - - export MYSQL_TEST_ADDR=127.0.0.1:3307 - - export MYSQL_TEST_CONCURRENT=1 - - - env: DB=MYSQL57 - sudo: required - dist: trusty - go: 1.10.x - services: - - docker - before_install: - - go get golang.org/x/tools/cmd/cover - - go get github.com/mattn/goveralls - - docker pull mysql:5.7 - - docker run -d -p 127.0.0.1:3307:3306 --name mysqld -e MYSQL_DATABASE=gotest -e MYSQL_USER=gotest -e MYSQL_PASSWORD=secret -e MYSQL_ROOT_PASSWORD=verysecret - mysql:5.7 --innodb_log_file_size=256MB --innodb_buffer_pool_size=512MB --max_allowed_packet=16MB --local-infile=1 - - cp .travis/docker.cnf ~/.my.cnf - - .travis/wait_mysql.sh - before_script: - - export MYSQL_TEST_USER=gotest - - export MYSQL_TEST_PASS=secret - - export MYSQL_TEST_ADDR=127.0.0.1:3307 - - export MYSQL_TEST_CONCURRENT=1 - - - env: DB=MARIA55 - sudo: required - dist: trusty - go: 1.10.x - services: - - docker - before_install: - - go get golang.org/x/tools/cmd/cover - - go get github.com/mattn/goveralls - - docker pull mariadb:5.5 - - docker run -d -p 127.0.0.1:3307:3306 --name mysqld -e MYSQL_DATABASE=gotest -e MYSQL_USER=gotest -e MYSQL_PASSWORD=secret -e MYSQL_ROOT_PASSWORD=verysecret - mariadb:5.5 --innodb_log_file_size=256MB --innodb_buffer_pool_size=512MB --max_allowed_packet=16MB --local-infile=1 - - cp .travis/docker.cnf ~/.my.cnf - - .travis/wait_mysql.sh - before_script: - - export MYSQL_TEST_USER=gotest - - export MYSQL_TEST_PASS=secret - - export MYSQL_TEST_ADDR=127.0.0.1:3307 - - export MYSQL_TEST_CONCURRENT=1 - - - env: DB=MARIA10_1 - sudo: required - dist: trusty - go: 1.10.x - services: - - docker - before_install: - - go get golang.org/x/tools/cmd/cover - - go get github.com/mattn/goveralls - - docker pull mariadb:10.1 - - docker run -d -p 127.0.0.1:3307:3306 --name mysqld -e MYSQL_DATABASE=gotest -e MYSQL_USER=gotest -e MYSQL_PASSWORD=secret -e MYSQL_ROOT_PASSWORD=verysecret - mariadb:10.1 --innodb_log_file_size=256MB --innodb_buffer_pool_size=512MB --max_allowed_packet=16MB --local-infile=1 - - cp .travis/docker.cnf ~/.my.cnf - - .travis/wait_mysql.sh - before_script: - - export MYSQL_TEST_USER=gotest - - export MYSQL_TEST_PASS=secret - - export MYSQL_TEST_ADDR=127.0.0.1:3307 - - export MYSQL_TEST_CONCURRENT=1 - - - os: osx - osx_image: xcode10.1 - addons: - homebrew: - packages: - - mysql - update: true - go: 1.12.x - before_install: - - go get golang.org/x/tools/cmd/cover - - go get github.com/mattn/goveralls - before_script: - - echo -e "[server]\ninnodb_log_file_size=256MB\ninnodb_buffer_pool_size=512MB\nmax_allowed_packet=16MB\nlocal_infile=1" >> /usr/local/etc/my.cnf - - mysql.server start - - mysql -uroot -e 'CREATE USER gotest IDENTIFIED BY "secret"' - - mysql -uroot -e 'GRANT ALL ON *.* TO gotest' - - mysql -uroot -e 'create database gotest;' - - export MYSQL_TEST_USER=gotest - - export MYSQL_TEST_PASS=secret - - export MYSQL_TEST_ADDR=127.0.0.1:3306 - - export MYSQL_TEST_CONCURRENT=1 - -script: - - go test -v -covermode=count -coverprofile=coverage.out - - go vet ./... - - .travis/gofmt.sh -after_script: - - $HOME/gopath/bin/goveralls -coverprofile=coverage.out -service=travis-ci diff --git a/.travis/docker.cnf b/.travis/docker.cnf deleted file mode 100644 index e57754e5a..000000000 --- a/.travis/docker.cnf +++ /dev/null @@ -1,5 +0,0 @@ -[client] -user = gotest -password = secret -host = 127.0.0.1 -port = 3307 diff --git a/.travis/gofmt.sh b/.travis/gofmt.sh deleted file mode 100755 index 9bf0d1684..000000000 --- a/.travis/gofmt.sh +++ /dev/null @@ -1,7 +0,0 @@ -#!/bin/bash -set -ev - -# Only check for go1.10+ since the gofmt style changed -if [[ $(go version) =~ go1\.([0-9]+) ]] && ((${BASH_REMATCH[1]} >= 10)); then - test -z "$(gofmt -d -s . | tee /dev/stderr)" -fi diff --git a/.travis/wait_mysql.sh b/.travis/wait_mysql.sh deleted file mode 100755 index e87993e57..000000000 --- a/.travis/wait_mysql.sh +++ /dev/null @@ -1,8 +0,0 @@ -#!/bin/sh -while : -do - if mysql -e 'select version()' 2>&1 | grep 'version()\|ERROR 2059 (HY000):'; then - break - fi - sleep 3 -done diff --git a/AUTHORS b/AUTHORS index ad5989800..fb1478c3b 100644 --- a/AUTHORS +++ b/AUTHORS @@ -13,12 +13,17 @@ Aaron Hopkins Achille Roussel +Alex Snast Alexey Palazhchenko Andrew Reid +Animesh Ray Arne Hormann +Ariel Mashraki Asta Xie Bulat Gaifullin +Caine Jette Carlos Nieto +Chris Kirkland Chris Moos Craig Wilson Daniel Montoya @@ -41,6 +46,7 @@ Ilia Cimpoes INADA Naoki Jacek Szwec James Harr +Janek Vedock Jeff Hodges Jeffrey Charles Jerome Meyer @@ -52,14 +58,17 @@ Julien Schmidt Justin Li Justin Nuß Kamil Dziedzic +Kei Kamikawa Kevin Malachowski Kieron Woodhouse +Lance Tian Lennart Rudolph Leonardo YongUk Kim Linh Tran Tuan Lion Yang Luca Looz Lucas Liu +Lunny Xiao Luke Scott Maciej Zimnoch Michael Woolnough @@ -69,31 +78,42 @@ Olivier Mengué oscarzhao Paul Bonser Peter Schultz +Phil Porada Rebecca Chin Reed Allman Richard Wilkes Robert Russell Runrioter Wung +Samantha Frank +Santhosh Kumar Tekuri +Sho Iizuka +Sho Ikeda Shuode Li Simon J Mudd Soroush Pour Stan Putrya Stanley Gunawan Steven Hartland +Tan Jinhua <312841925 at qq.com> Thomas Wodarek Tim Ruffles Tom Jenkinson Vladimir Kovpak +Vladyslav Zhelezniak Xiangyu Hu Xiaobing Jiang Xiuming Chen +Xuehong Chan Zhenye Xie +Zhixin Wen +Ziheng Lyu # Organizations Barracuda Networks, Inc. Counting Ltd. DigitalOcean Inc. +dyves labs AG Facebook Inc. GitHub Inc. Google Inc. @@ -103,3 +123,4 @@ Multiplay Ltd. Percona LLC Pivotal Inc. Stripe Inc. +Zendesk Inc. diff --git a/CHANGELOG.md b/CHANGELOG.md index 9cb97b38d..5166e4adb 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,3 +1,63 @@ +## Version 1.7.1 (2023-04-25) + +Changes: + + - bump actions/checkout@v3 and actions/setup-go@v3 (#1375) + - Add go1.20 and mariadb10.11 to the testing matrix (#1403) + - Increase default maxAllowedPacket size. (#1411) + +Bugfixes: + + - Use SET syntax as specified in the MySQL documentation (#1402) + + +## Version 1.7 (2022-11-29) + +Changes: + + - Drop support of Go 1.12 (#1211) + - Refactoring `(*textRows).readRow` in a more clear way (#1230) + - util: Reduce boundary check in escape functions. (#1316) + - enhancement for mysqlConn handleAuthResult (#1250) + +New Features: + + - support Is comparison on MySQLError (#1210) + - return unsigned in database type name when necessary (#1238) + - Add API to express like a --ssl-mode=PREFERRED MySQL client (#1370) + - Add SQLState to MySQLError (#1321) + +Bugfixes: + + - Fix parsing 0 year. (#1257) + + +## Version 1.6 (2021-04-01) + +Changes: + + - Migrate the CI service from travis-ci to GitHub Actions (#1176, #1183, #1190) + - `NullTime` is deprecated (#960, #1144) + - Reduce allocations when building SET command (#1111) + - Performance improvement for time formatting (#1118) + - Performance improvement for time parsing (#1098, #1113) + +New Features: + + - Implement `driver.Validator` interface (#1106, #1174) + - Support returning `uint64` from `Valuer` in `ConvertValue` (#1143) + - Add `json.RawMessage` for converter and prepared statement (#1059) + - Interpolate `json.RawMessage` as `string` (#1058) + - Implements `CheckNamedValue` (#1090) + +Bugfixes: + + - Stop rounding times (#1121, #1172) + - Put zero filler into the SSL handshake packet (#1066) + - Fix checking cancelled connections back into the connection pool (#1095) + - Fix remove last 0 byte for mysql_old_password when password is empty (#1133) + + ## Version 1.5 (2020-01-07) Changes: diff --git a/README.md b/README.md index d2627a41a..3b5d229aa 100644 --- a/README.md +++ b/README.md @@ -35,12 +35,12 @@ A MySQL-Driver for Go's [database/sql](https://golang.org/pkg/database/sql/) pac * Supports queries larger than 16MB * Full [`sql.RawBytes`](https://golang.org/pkg/database/sql/#RawBytes) support. * Intelligent `LONG DATA` handling in prepared statements - * Secure `LOAD DATA LOCAL INFILE` support with file Whitelisting and `io.Reader` support + * Secure `LOAD DATA LOCAL INFILE` support with file allowlisting and `io.Reader` support * Optional `time.Time` parsing * Optional placeholder interpolation ## Requirements - * Go 1.10 or higher. We aim to support the 3 latest versions of Go. + * Go 1.13 or higher. We aim to support the 3 latest versions of Go. * MySQL (4.1+), MariaDB, Percona Server, Google CloudSQL or Sphinx (2.2.3+) --------------------------------------- @@ -56,15 +56,37 @@ Make sure [Git is installed](https://git-scm.com/downloads) on your machine and _Go MySQL Driver_ is an implementation of Go's `database/sql/driver` interface. You only need to import the driver and can use the full [`database/sql`](https://golang.org/pkg/database/sql/) API then. Use `mysql` as `driverName` and a valid [DSN](#dsn-data-source-name) as `dataSourceName`: + ```go -import "database/sql" -import _ "github.com/go-sql-driver/mysql" +import ( + "database/sql" + "time" + + _ "github.com/go-sql-driver/mysql" +) + +// ... db, err := sql.Open("mysql", "user:password@/dbname") +if err != nil { + panic(err) +} +// See "Important settings" section. +db.SetConnMaxLifetime(time.Minute * 3) +db.SetMaxOpenConns(10) +db.SetMaxIdleConns(10) ``` [Examples are available in our Wiki](https://github.com/go-sql-driver/mysql/wiki/Examples "Go-MySQL-Driver Examples"). +### Important settings + +`db.SetConnMaxLifetime()` is required to ensure connections are closed by the driver safely before connection is closed by MySQL server, OS, or other middlewares. Since some middlewares close idle connections by 5 minutes, we recommend timeout shorter than 5 minutes. This setting helps load balancing and changing system variables too. + +`db.SetMaxOpenConns()` is highly recommended to limit the number of connection used by the application. There is no recommended limit number because it depends on application and MySQL server. + +`db.SetMaxIdleConns()` is recommended to be set same to `db.SetMaxOpenConns()`. When it is smaller than `SetMaxOpenConns()`, connections can be opened and closed much more frequently than you expect. Idle connections can be closed by the `db.SetConnMaxLifetime()`. If you want to close idle connections more rapidly, you can use `db.SetConnMaxIdleTime()` since Go 1.15. + ### DSN (Data Source Name) @@ -122,7 +144,7 @@ Valid Values: true, false Default: false ``` -`allowAllFiles=true` disables the file Whitelist for `LOAD DATA LOCAL INFILE` and allows *all* files. +`allowAllFiles=true` disables the file allowlist for `LOAD DATA LOCAL INFILE` and allows *all* files. [*Might be insecure!*](http://dev.mysql.com/doc/refman/5.7/en/load-data-local.html) ##### `allowCleartextPasswords` @@ -133,7 +155,18 @@ Valid Values: true, false Default: false ``` -`allowCleartextPasswords=true` allows using the [cleartext client side plugin](http://dev.mysql.com/doc/en/cleartext-authentication-plugin.html) if required by an account, such as one defined with the [PAM authentication plugin](http://dev.mysql.com/doc/en/pam-authentication-plugin.html). Sending passwords in clear text may be a security problem in some configurations. To avoid problems if there is any possibility that the password would be intercepted, clients should connect to MySQL Server using a method that protects the password. Possibilities include [TLS / SSL](#tls), IPsec, or a private network. +`allowCleartextPasswords=true` allows using the [cleartext client side plugin](https://dev.mysql.com/doc/en/cleartext-pluggable-authentication.html) if required by an account, such as one defined with the [PAM authentication plugin](http://dev.mysql.com/doc/en/pam-authentication-plugin.html). Sending passwords in clear text may be a security problem in some configurations. To avoid problems if there is any possibility that the password would be intercepted, clients should connect to MySQL Server using a method that protects the password. Possibilities include [TLS / SSL](#tls), IPsec, or a private network. + + +##### `allowFallbackToPlaintext` + +``` +Type: bool +Valid Values: true, false +Default: false +``` + +`allowFallbackToPlaintext=true` acts like a `--ssl-mode=PREFERRED` MySQL client as described in [Command Options for Connecting to the Server](https://dev.mysql.com/doc/refman/5.7/en/connection-options.html#option_general_ssl-mode) ##### `allowNativePasswords` @@ -230,7 +263,7 @@ Default: false If `interpolateParams` is true, placeholders (`?`) in calls to `db.Query()` and `db.Exec()` are interpolated into a single query string with given parameters. This reduces the number of roundtrips, since the driver has to prepare a statement, execute it with given parameters and close the statement again with `interpolateParams=false`. -*This can not be used together with the multibyte encodings BIG5, CP932, GB2312, GBK or SJIS. These are blacklisted as they may [introduce a SQL injection vulnerability](http://stackoverflow.com/a/12118602/3430118)!* +*This can not be used together with the multibyte encodings BIG5, CP932, GB2312, GBK or SJIS. These are rejected as they may [introduce a SQL injection vulnerability](http://stackoverflow.com/a/12118602/3430118)!* ##### `loc` @@ -249,10 +282,10 @@ Please keep in mind, that param values must be [url.QueryEscape](https://golang. ##### `maxAllowedPacket` ``` Type: decimal number -Default: 4194304 +Default: 64*1024*1024 ``` -Max packet size allowed in bytes. The default value is 4 MiB and should be adjusted to match the server settings. `maxAllowedPacket=0` can be used to automatically fetch the `max_allowed_packet` variable from server *on every connection*. +Max packet size allowed in bytes. The default value is 64 MiB and should be adjusted to match the server settings. `maxAllowedPacket=0` can be used to automatically fetch the `max_allowed_packet` variable from server *on every connection*. ##### `multiStatements` @@ -376,7 +409,7 @@ Rules: Examples: * `autocommit=1`: `SET autocommit=1` * [`time_zone=%27Europe%2FParis%27`](https://dev.mysql.com/doc/refman/5.5/en/time-zone-support.html): `SET time_zone='Europe/Paris'` - * [`tx_isolation=%27REPEATABLE-READ%27`](https://dev.mysql.com/doc/refman/5.5/en/server-system-variables.html#sysvar_tx_isolation): `SET tx_isolation='REPEATABLE-READ'` + * [`transaction_isolation=%27REPEATABLE-READ%27`](https://dev.mysql.com/doc/refman/5.7/en/server-system-variables.html#sysvar_transaction_isolation): `SET transaction_isolation='REPEATABLE-READ'` #### Examples @@ -432,7 +465,7 @@ user:password@/ The connection pool is managed by Go's database/sql package. For details on how to configure the size of the pool and how long connections stay in the pool see `*DB.SetMaxOpenConns`, `*DB.SetMaxIdleConns`, and `*DB.SetConnMaxLifetime` in the [database/sql documentation](https://golang.org/pkg/database/sql/). The read, write, and dial timeouts for each individual connection are configured with the DSN parameters [`readTimeout`](#readtimeout), [`writeTimeout`](#writetimeout), and [`timeout`](#timeout), respectively. ## `ColumnType` Support -This driver supports the [`ColumnType` interface](https://golang.org/pkg/database/sql/#ColumnType) introduced in Go 1.8, with the exception of [`ColumnType.Length()`](https://golang.org/pkg/database/sql/#ColumnType.Length), which is currently not supported. +This driver supports the [`ColumnType` interface](https://golang.org/pkg/database/sql/#ColumnType) introduced in Go 1.8, with the exception of [`ColumnType.Length()`](https://golang.org/pkg/database/sql/#ColumnType.Length), which is currently not supported. All Unsigned database type names will be returned `UNSIGNED ` with `INT`, `TINYINT`, `SMALLINT`, `BIGINT`. ## `context.Context` Support Go 1.8 added `database/sql` support for `context.Context`. This driver supports query timeouts and cancellation via contexts. @@ -445,7 +478,7 @@ For this feature you need direct access to the package. Therefore you must chang import "github.com/go-sql-driver/mysql" ``` -Files must be whitelisted by registering them with `mysql.RegisterLocalFile(filepath)` (recommended) or the Whitelist check must be deactivated by using the DSN parameter `allowAllFiles=true` ([*Might be insecure!*](http://dev.mysql.com/doc/refman/5.7/en/load-data-local.html)). +Files must be explicitly allowed by registering them with `mysql.RegisterLocalFile(filepath)` (recommended) or the allowlist check must be deactivated by using the DSN parameter `allowAllFiles=true` ([*Might be insecure!*](http://dev.mysql.com/doc/refman/5.7/en/load-data-local.html)). To use a `io.Reader` a handler function must be registered with `mysql.RegisterReaderHandler(name, handler)` which returns a `io.Reader` or `io.ReadCloser`. The Reader is available with the filepath `Reader::` then. Choose different names for different handlers and `DeregisterReaderHandler` when you don't need it anymore. @@ -459,8 +492,6 @@ However, many want to scan MySQL `DATE` and `DATETIME` values into `time.Time` v **Caution:** As of Go 1.1, this makes `time.Time` the only variable type you can scan `DATE` and `DATETIME` values into. This breaks for example [`sql.RawBytes` support](https://github.com/go-sql-driver/mysql/wiki/Examples#rawbytes). -Alternatively you can use the [`NullTime`](https://godoc.org/github.com/go-sql-driver/mysql#NullTime) type as the scan destination, which works with both `time.Time` and `string` / `[]byte`. - ### Unicode support Since version 1.5 Go-MySQL-Driver automatically uses the collation ` utf8mb4_general_ci` by default. @@ -477,7 +508,7 @@ To run the driver tests you may need to adjust the configuration. See the [Testi Go-MySQL-Driver is not feature-complete yet. Your help is very appreciated. If you want to contribute, you can work on an [open issue](https://github.com/go-sql-driver/mysql/issues?state=open) or review a [pull request](https://github.com/go-sql-driver/mysql/pulls). -See the [Contribution Guidelines](https://github.com/go-sql-driver/mysql/blob/master/CONTRIBUTING.md) for details. +See the [Contribution Guidelines](https://github.com/go-sql-driver/mysql/blob/master/.github/CONTRIBUTING.md) for details. --------------------------------------- @@ -498,4 +529,3 @@ Please read the [MPL 2.0 FAQ](https://www.mozilla.org/en-US/MPL/2.0/FAQ/) if you You can read the full terms here: [LICENSE](https://raw.github.com/go-sql-driver/mysql/master/LICENSE). ![Go Gopher and MySQL Dolphin](https://raw.github.com/wiki/go-sql-driver/mysql/go-mysql-driver_m.jpg "Golang Gopher transporting the MySQL Dolphin in a wheelbarrow") - diff --git a/atomic_bool.go b/atomic_bool.go new file mode 100644 index 000000000..1b7e19f3e --- /dev/null +++ b/atomic_bool.go @@ -0,0 +1,19 @@ +// Go MySQL Driver - A MySQL-Driver for Go's database/sql package. +// +// Copyright 2022 The Go-MySQL-Driver Authors. All rights reserved. +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this file, +// You can obtain one at http://mozilla.org/MPL/2.0/. +//go:build go1.19 +// +build go1.19 + +package mysql + +import "sync/atomic" + +/****************************************************************************** +* Sync utils * +******************************************************************************/ + +type atomicBool = atomic.Bool diff --git a/atomic_bool_go118.go b/atomic_bool_go118.go new file mode 100644 index 000000000..2e9a7f0b6 --- /dev/null +++ b/atomic_bool_go118.go @@ -0,0 +1,47 @@ +// Go MySQL Driver - A MySQL-Driver for Go's database/sql package. +// +// Copyright 2022 The Go-MySQL-Driver Authors. All rights reserved. +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this file, +// You can obtain one at http://mozilla.org/MPL/2.0/. +//go:build !go1.19 +// +build !go1.19 + +package mysql + +import "sync/atomic" + +/****************************************************************************** +* Sync utils * +******************************************************************************/ + +// atomicBool is an implementation of atomic.Bool for older version of Go. +// it is a wrapper around uint32 for usage as a boolean value with +// atomic access. +type atomicBool struct { + _ noCopy + value uint32 +} + +// Load returns whether the current boolean value is true +func (ab *atomicBool) Load() bool { + return atomic.LoadUint32(&ab.value) > 0 +} + +// Store sets the value of the bool regardless of the previous value +func (ab *atomicBool) Store(value bool) { + if value { + atomic.StoreUint32(&ab.value, 1) + } else { + atomic.StoreUint32(&ab.value, 0) + } +} + +// Swap sets the value of the bool and returns the old value. +func (ab *atomicBool) Swap(value bool) bool { + if value { + return atomic.SwapUint32(&ab.value, 1) > 0 + } + return atomic.SwapUint32(&ab.value, 0) > 0 +} diff --git a/atomic_bool_test.go b/atomic_bool_test.go new file mode 100644 index 000000000..a3b4ea0e8 --- /dev/null +++ b/atomic_bool_test.go @@ -0,0 +1,71 @@ +// Go MySQL Driver - A MySQL-Driver for Go's database/sql package. +// +// Copyright 2022 The Go-MySQL-Driver Authors. All rights reserved. +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this file, +// You can obtain one at http://mozilla.org/MPL/2.0/. +//go:build !go1.19 +// +build !go1.19 + +package mysql + +import ( + "testing" +) + +func TestAtomicBool(t *testing.T) { + var ab atomicBool + if ab.Load() { + t.Fatal("Expected value to be false") + } + + ab.Store(true) + if ab.value != 1 { + t.Fatal("Set(true) did not set value to 1") + } + if !ab.Load() { + t.Fatal("Expected value to be true") + } + + ab.Store(true) + if !ab.Load() { + t.Fatal("Expected value to be true") + } + + ab.Store(false) + if ab.value != 0 { + t.Fatal("Set(false) did not set value to 0") + } + if ab.Load() { + t.Fatal("Expected value to be false") + } + + ab.Store(false) + if ab.Load() { + t.Fatal("Expected value to be false") + } + if ab.Swap(false) { + t.Fatal("Expected the old value to be false") + } + if ab.Swap(true) { + t.Fatal("Expected the old value to be false") + } + if !ab.Load() { + t.Fatal("Expected value to be true") + } + + ab.Store(true) + if !ab.Load() { + t.Fatal("Expected value to be true") + } + if !ab.Swap(true) { + t.Fatal("Expected the old value to be true") + } + if !ab.Swap(false) { + t.Fatal("Expected the old value to be true") + } + if ab.Load() { + t.Fatal("Expected value to be false") + } +} diff --git a/auth.go b/auth.go index fec7040d4..1ff203e57 100644 --- a/auth.go +++ b/auth.go @@ -15,6 +15,7 @@ import ( "crypto/sha256" "crypto/x509" "encoding/pem" + "fmt" "sync" ) @@ -32,27 +33,26 @@ var ( // Note: The provided rsa.PublicKey instance is exclusively owned by the driver // after registering it and may not be modified. // -// data, err := ioutil.ReadFile("mykey.pem") -// if err != nil { -// log.Fatal(err) -// } +// data, err := ioutil.ReadFile("mykey.pem") +// if err != nil { +// log.Fatal(err) +// } // -// block, _ := pem.Decode(data) -// if block == nil || block.Type != "PUBLIC KEY" { -// log.Fatal("failed to decode PEM block containing public key") -// } +// block, _ := pem.Decode(data) +// if block == nil || block.Type != "PUBLIC KEY" { +// log.Fatal("failed to decode PEM block containing public key") +// } // -// pub, err := x509.ParsePKIXPublicKey(block.Bytes) -// if err != nil { -// log.Fatal(err) -// } -// -// if rsaPubKey, ok := pub.(*rsa.PublicKey); ok { -// mysql.RegisterServerPubKey("mykey", rsaPubKey) -// } else { -// log.Fatal("not a RSA public key") -// } +// pub, err := x509.ParsePKIXPublicKey(block.Bytes) +// if err != nil { +// log.Fatal(err) +// } // +// if rsaPubKey, ok := pub.(*rsa.PublicKey); ok { +// mysql.RegisterServerPubKey("mykey", rsaPubKey) +// } else { +// log.Fatal("not a RSA public key") +// } func RegisterServerPubKey(name string, pubKey *rsa.PublicKey) { serverPubKeyLock.Lock() if serverPubKeyRegistry == nil { @@ -136,10 +136,6 @@ func pwHash(password []byte) (result [2]uint32) { // Hash password using insecure pre 4.1 method func scrambleOldPassword(scramble []byte, password string) []byte { - if len(password) == 0 { - return nil - } - scramble = scramble[:8] hashPw := pwHash([]byte(password)) @@ -247,6 +243,9 @@ func (mc *mysqlConn) auth(authData []byte, plugin string) ([]byte, error) { if !mc.cfg.AllowOldPasswords { return nil, ErrOldPassword } + if len(mc.cfg.Passwd) == 0 { + return nil, nil + } // Note: there are edge cases where this should work but doesn't; // this is currently "wontfix": // https://github.com/go-sql-driver/mysql/issues/184 @@ -274,7 +273,9 @@ func (mc *mysqlConn) auth(authData []byte, plugin string) ([]byte, error) { if len(mc.cfg.Passwd) == 0 { return []byte{0}, nil } - if mc.cfg.tls != nil || mc.cfg.Net == "unix" { + // unlike caching_sha2_password, sha256_password does not accept + // cleartext password on unix transport. + if mc.cfg.TLS != nil { // write cleartext auth packet return append([]byte(mc.cfg.Passwd), 0), nil } @@ -350,7 +351,7 @@ func (mc *mysqlConn) handleAuthResult(oldAuthData []byte, plugin string) error { } case cachingSha2PasswordPerformFullAuthentication: - if mc.cfg.tls != nil || mc.cfg.Net == "unix" { + if mc.cfg.TLS != nil || mc.cfg.Net == "unix" { // write cleartext auth packet err = mc.writeAuthSwitchPacket(append([]byte(mc.cfg.Passwd), 0)) if err != nil { @@ -365,14 +366,24 @@ func (mc *mysqlConn) handleAuthResult(oldAuthData []byte, plugin string) error { return err } data[4] = cachingSha2PasswordRequestPublicKey - mc.writePacket(data) + err = mc.writePacket(data) + if err != nil { + return err + } - // parse public key if data, err = mc.readPacket(); err != nil { return err } - block, _ := pem.Decode(data[1:]) + if data[0] != iAuthMoreData { + return fmt.Errorf("unexpect resp from server for caching_sha2_password perform full authentication") + } + + // parse public key + block, rest := pem.Decode(data[1:]) + if block == nil { + return fmt.Errorf("No Pem data found, data: %s", rest) + } pkix, err := x509.ParsePKIXPublicKey(block.Bytes) if err != nil { return err @@ -401,6 +412,10 @@ func (mc *mysqlConn) handleAuthResult(oldAuthData []byte, plugin string) error { return nil // auth successful default: block, _ := pem.Decode(authData) + if block == nil { + return fmt.Errorf("no Pem data found, data: %s", authData) + } + pub, err := x509.ParsePKIXPublicKey(block.Bytes) if err != nil { return err diff --git a/auth_test.go b/auth_test.go index 1920ef39f..3ce0ea6e0 100644 --- a/auth_test.go +++ b/auth_test.go @@ -291,7 +291,7 @@ func TestAuthFastCachingSHA256PasswordFullSecure(t *testing.T) { // Hack to make the caching_sha2_password plugin believe that the connection // is secure - mc.cfg.tls = &tls.Config{InsecureSkipVerify: true} + mc.cfg.TLS = &tls.Config{InsecureSkipVerify: true} // check written auth response authRespStart := 4 + 4 + 4 + 1 + 23 + len(mc.cfg.User) + 1 @@ -663,7 +663,7 @@ func TestAuthFastSHA256PasswordSecure(t *testing.T) { // hack to make the caching_sha2_password plugin believe that the connection // is secure - mc.cfg.tls = &tls.Config{InsecureSkipVerify: true} + mc.cfg.TLS = &tls.Config{InsecureSkipVerify: true} authData := []byte{6, 81, 96, 114, 14, 42, 50, 30, 76, 47, 1, 95, 126, 81, 62, 94, 83, 80, 52, 85} @@ -676,7 +676,7 @@ func TestAuthFastSHA256PasswordSecure(t *testing.T) { } // unset TLS config to prevent the actual establishment of a TLS wrapper - mc.cfg.tls = nil + mc.cfg.TLS = nil err = mc.writeHandshakeResponsePacket(authResp, plugin) if err != nil { @@ -866,7 +866,7 @@ func TestAuthSwitchCachingSHA256PasswordFullSecure(t *testing.T) { // Hack to make the caching_sha2_password plugin believe that the connection // is secure - mc.cfg.tls = &tls.Config{InsecureSkipVerify: true} + mc.cfg.TLS = &tls.Config{InsecureSkipVerify: true} // auth switch request conn.data = []byte{44, 0, 0, 2, 254, 99, 97, 99, 104, 105, 110, 103, 95, @@ -1157,7 +1157,7 @@ func TestAuthSwitchOldPasswordEmpty(t *testing.T) { t.Errorf("got error: %v", err) } - expectedReply := []byte{1, 0, 0, 3, 0} + expectedReply := []byte{0, 0, 0, 3} if !bytes.Equal(conn.written, expectedReply) { t.Errorf("got unexpected data: %v", conn.written) } @@ -1184,7 +1184,7 @@ func TestOldAuthSwitchPasswordEmpty(t *testing.T) { t.Errorf("got error: %v", err) } - expectedReply := []byte{1, 0, 0, 3, 0} + expectedReply := []byte{0, 0, 0, 3} if !bytes.Equal(conn.written, expectedReply) { t.Errorf("got unexpected data: %v", conn.written) } @@ -1299,7 +1299,7 @@ func TestAuthSwitchSHA256PasswordSecure(t *testing.T) { // Hack to make the caching_sha2_password plugin believe that the connection // is secure - mc.cfg.tls = &tls.Config{InsecureSkipVerify: true} + mc.cfg.TLS = &tls.Config{InsecureSkipVerify: true} // auth switch request conn.data = []byte{38, 0, 0, 2, 254, 115, 104, 97, 50, 53, 54, 95, 112, 97, diff --git a/benchmark_test.go b/benchmark_test.go index 3e25a3bf2..97ed781f8 100644 --- a/benchmark_test.go +++ b/benchmark_test.go @@ -127,7 +127,8 @@ func BenchmarkExec(b *testing.B) { } if _, err := stmt.Exec(); err != nil { - b.Fatal(err.Error()) + b.Logf("stmt.Exec failed: %v", err) + b.Fail() } } }() @@ -313,7 +314,7 @@ func BenchmarkExecContext(b *testing.B) { defer db.Close() for _, p := range []int{1, 2, 3, 4} { b.Run(fmt.Sprintf("%d", p), func(b *testing.B) { - benchmarkQueryContext(b, db, p) + benchmarkExecContext(b, db, p) }) } } diff --git a/collations.go b/collations.go index 8d2b55676..295bfbe52 100644 --- a/collations.go +++ b/collations.go @@ -13,7 +13,8 @@ const binaryCollation = "binary" // A list of available collations mapped to the internal ID. // To update this map use the following MySQL query: -// SELECT COLLATION_NAME, ID FROM information_schema.COLLATIONS WHERE ID<256 ORDER BY ID +// +// SELECT COLLATION_NAME, ID FROM information_schema.COLLATIONS WHERE ID<256 ORDER BY ID // // Handshake packet have only 1 byte for collation_id. So we can't use collations with ID > 255. // @@ -247,7 +248,7 @@ var collations = map[string]byte{ "utf8mb4_0900_ai_ci": 255, } -// A blacklist of collations which is unsafe to interpolate parameters. +// A denylist of collations which is unsafe to interpolate parameters. // These multibyte encodings may contains 0x5c (`\`) in their trailing bytes. var unsafeCollations = map[string]bool{ "big5_chinese_ci": true, diff --git a/conncheck.go b/conncheck.go index 024eb2858..0ea721720 100644 --- a/conncheck.go +++ b/conncheck.go @@ -6,6 +6,7 @@ // License, v. 2.0. If a copy of the MPL was not distributed with this file, // You can obtain one at http://mozilla.org/MPL/2.0/. +//go:build linux || darwin || dragonfly || freebsd || netbsd || openbsd || solaris || illumos // +build linux darwin dragonfly freebsd netbsd openbsd solaris illumos package mysql diff --git a/conncheck_dummy.go b/conncheck_dummy.go index ea7fb607a..a56c138f2 100644 --- a/conncheck_dummy.go +++ b/conncheck_dummy.go @@ -6,6 +6,7 @@ // License, v. 2.0. If a copy of the MPL was not distributed with this file, // You can obtain one at http://mozilla.org/MPL/2.0/. +//go:build !linux && !darwin && !dragonfly && !freebsd && !netbsd && !openbsd && !solaris && !illumos // +build !linux,!darwin,!dragonfly,!freebsd,!netbsd,!openbsd,!solaris,!illumos package mysql diff --git a/conncheck_test.go b/conncheck_test.go index 53995517b..f7e025680 100644 --- a/conncheck_test.go +++ b/conncheck_test.go @@ -6,6 +6,7 @@ // License, v. 2.0. If a copy of the MPL was not distributed with this file, // You can obtain one at http://mozilla.org/MPL/2.0/. +//go:build linux || darwin || dragonfly || freebsd || netbsd || openbsd || solaris || illumos // +build linux darwin dragonfly freebsd netbsd openbsd solaris illumos package mysql diff --git a/connection.go b/connection.go index e4bb59e67..947a883e3 100644 --- a/connection.go +++ b/connection.go @@ -12,6 +12,7 @@ import ( "context" "database/sql" "database/sql/driver" + "encoding/json" "io" "net" "strconv" @@ -46,9 +47,10 @@ type mysqlConn struct { // Handles parameters set in DSN after the connection is established func (mc *mysqlConn) handleParams() (err error) { + var cmdSet strings.Builder for param, val := range mc.cfg.Params { switch param { - // Charset + // Charset: character_set_connection, character_set_client, character_set_results case "charset": charsets := strings.Split(val, ",") for i := range charsets { @@ -62,12 +64,25 @@ func (mc *mysqlConn) handleParams() (err error) { return } - // System Vars + // Other system vars accumulated in a single SET command default: - err = mc.exec("SET " + param + "=" + val + "") - if err != nil { - return + if cmdSet.Len() == 0 { + // Heuristic: 29 chars for each other key=value to reduce reallocations + cmdSet.Grow(4 + len(param) + 1 + len(val) + 30*(len(mc.cfg.Params)-1)) + cmdSet.WriteString("SET ") + } else { + cmdSet.WriteString(", ") } + cmdSet.WriteString(param) + cmdSet.WriteString(" = ") + cmdSet.WriteString(val) + } + } + + if cmdSet.Len() > 0 { + err = mc.exec(cmdSet.String()) + if err != nil { + return } } @@ -89,7 +104,7 @@ func (mc *mysqlConn) Begin() (driver.Tx, error) { } func (mc *mysqlConn) begin(readOnly bool) (driver.Tx, error) { - if mc.closed.IsSet() { + if mc.closed.Load() { errLog.Print(ErrInvalidConn) return nil, driver.ErrBadConn } @@ -108,7 +123,7 @@ func (mc *mysqlConn) begin(readOnly bool) (driver.Tx, error) { func (mc *mysqlConn) Close() (err error) { // Makes Close idempotent - if !mc.closed.IsSet() { + if !mc.closed.Load() { err = mc.writeCommandPacket(comQuit) } @@ -122,7 +137,7 @@ func (mc *mysqlConn) Close() (err error) { // is called before auth or on auth failure because MySQL will have already // closed the network connection. func (mc *mysqlConn) cleanup() { - if !mc.closed.TrySet(true) { + if mc.closed.Swap(true) { return } @@ -137,7 +152,7 @@ func (mc *mysqlConn) cleanup() { } func (mc *mysqlConn) error() error { - if mc.closed.IsSet() { + if mc.closed.Load() { if err := mc.canceled.Value(); err != nil { return err } @@ -147,7 +162,7 @@ func (mc *mysqlConn) error() error { } func (mc *mysqlConn) Prepare(query string) (driver.Stmt, error) { - if mc.closed.IsSet() { + if mc.closed.Load() { errLog.Print(ErrInvalidConn) return nil, driver.ErrBadConn } @@ -230,47 +245,21 @@ func (mc *mysqlConn) interpolateParams(query string, args []driver.Value) (strin if v.IsZero() { buf = append(buf, "'0000-00-00'"...) } else { - v := v.In(mc.cfg.Loc) - v = v.Add(time.Nanosecond * 500) // To round under microsecond - year := v.Year() - year100 := year / 100 - year1 := year % 100 - month := v.Month() - day := v.Day() - hour := v.Hour() - minute := v.Minute() - second := v.Second() - micro := v.Nanosecond() / 1000 - - buf = append(buf, []byte{ - '\'', - digits10[year100], digits01[year100], - digits10[year1], digits01[year1], - '-', - digits10[month], digits01[month], - '-', - digits10[day], digits01[day], - ' ', - digits10[hour], digits01[hour], - ':', - digits10[minute], digits01[minute], - ':', - digits10[second], digits01[second], - }...) - - if micro != 0 { - micro10000 := micro / 10000 - micro100 := micro / 100 % 100 - micro1 := micro % 100 - buf = append(buf, []byte{ - '.', - digits10[micro10000], digits01[micro10000], - digits10[micro100], digits01[micro100], - digits10[micro1], digits01[micro1], - }...) + buf = append(buf, '\'') + buf, err = appendDateTime(buf, v.In(mc.cfg.Loc)) + if err != nil { + return "", err } buf = append(buf, '\'') } + case json.RawMessage: + buf = append(buf, '\'') + if mc.status&statusNoBackslashEscapes == 0 { + buf = escapeBytesBackslash(buf, v) + } else { + buf = escapeBytesQuotes(buf, v) + } + buf = append(buf, '\'') case []byte: if v == nil { buf = append(buf, "NULL"...) @@ -306,7 +295,7 @@ func (mc *mysqlConn) interpolateParams(query string, args []driver.Value) (strin } func (mc *mysqlConn) Exec(query string, args []driver.Value) (driver.Result, error) { - if mc.closed.IsSet() { + if mc.closed.Load() { errLog.Print(ErrInvalidConn) return nil, driver.ErrBadConn } @@ -367,7 +356,7 @@ func (mc *mysqlConn) Query(query string, args []driver.Value) (driver.Rows, erro } func (mc *mysqlConn) query(query string, args []driver.Value) (*textRows, error) { - if mc.closed.IsSet() { + if mc.closed.Load() { errLog.Print(ErrInvalidConn) return nil, driver.ErrBadConn } @@ -461,7 +450,7 @@ func (mc *mysqlConn) finish() { // Ping implements driver.Pinger interface func (mc *mysqlConn) Ping(ctx context.Context) (err error) { - if mc.closed.IsSet() { + if mc.closed.Load() { errLog.Print(ErrInvalidConn) return driver.ErrBadConn } @@ -480,6 +469,10 @@ func (mc *mysqlConn) Ping(ctx context.Context) (err error) { // BeginTx implements driver.ConnBeginTx interface func (mc *mysqlConn) BeginTx(ctx context.Context, opts driver.TxOptions) (driver.Tx, error) { + if mc.closed.Load() { + return nil, driver.ErrBadConn + } + if err := mc.watchCancel(ctx); err != nil { return nil, err } @@ -643,9 +636,15 @@ func (mc *mysqlConn) CheckNamedValue(nv *driver.NamedValue) (err error) { // ResetSession implements driver.SessionResetter. // (From Go 1.10) func (mc *mysqlConn) ResetSession(ctx context.Context) error { - if mc.closed.IsSet() { + if mc.closed.Load() { return driver.ErrBadConn } mc.reset = true return nil } + +// IsValid implements driver.Validator interface +// (From Go 1.15) +func (mc *mysqlConn) IsValid() bool { + return !mc.closed.Load() +} diff --git a/connection_test.go b/connection_test.go index 19c17ff8b..b6764a2f6 100644 --- a/connection_test.go +++ b/connection_test.go @@ -11,6 +11,7 @@ package mysql import ( "context" "database/sql/driver" + "encoding/json" "errors" "net" "testing" @@ -36,6 +37,33 @@ func TestInterpolateParams(t *testing.T) { } } +func TestInterpolateParamsJSONRawMessage(t *testing.T) { + mc := &mysqlConn{ + buf: newBuffer(nil), + maxAllowedPacket: maxPacketSize, + cfg: &Config{ + InterpolateParams: true, + }, + } + + buf, err := json.Marshal(struct { + Value int `json:"value"` + }{Value: 42}) + if err != nil { + t.Errorf("Expected err=nil, got %#v", err) + return + } + q, err := mc.interpolateParams("SELECT ?", []driver.Value{json.RawMessage(buf)}) + if err != nil { + t.Errorf("Expected err=nil, got %#v", err) + return + } + expected := `SELECT '{\"value\":42}'` + if q != expected { + t.Errorf("Expected: %q\nGot: %q", expected, q) + } +} + func TestInterpolateParamsTooManyPlaceholders(t *testing.T) { mc := &mysqlConn{ buf: newBuffer(nil), @@ -119,7 +147,7 @@ func TestCleanCancel(t *testing.T) { t.Errorf("expected context.Canceled, got %#v", err) } - if mc.closed.IsSet() { + if mc.closed.Load() { t.Error("expected mc is not closed, closed actually") } diff --git a/const.go b/const.go index b1e6b85ef..64e2bced6 100644 --- a/const.go +++ b/const.go @@ -10,7 +10,7 @@ package mysql const ( defaultAuthPlugin = "mysql_native_password" - defaultMaxAllowedPacket = 4 << 20 // 4 MiB + defaultMaxAllowedPacket = 64 << 20 // 64 MiB. See https://github.com/go-sql-driver/mysql/issues/1355 minProtocolVersion = 10 maxPacketSize = 1<<24 - 1 timeFormat = "2006-01-02 15:04:05.999999" diff --git a/driver.go b/driver.go index c1bdf1199..ad7aec215 100644 --- a/driver.go +++ b/driver.go @@ -8,10 +8,10 @@ // // The driver should be used via the database/sql package: // -// import "database/sql" -// import _ "github.com/go-sql-driver/mysql" +// import "database/sql" +// import _ "github.com/go-sql-driver/mysql" // -// db, err := sql.Open("mysql", "user:password@/dbname") +// db, err := sql.Open("mysql", "user:password@/dbname") // // See https://github.com/go-sql-driver/mysql#usage for details package mysql diff --git a/driver_test.go b/driver_test.go index ace083dfc..a1c776728 100644 --- a/driver_test.go +++ b/driver_test.go @@ -14,6 +14,7 @@ import ( "crypto/tls" "database/sql" "database/sql/driver" + "encoding/json" "fmt" "io" "io/ioutil" @@ -23,6 +24,7 @@ import ( "net/url" "os" "reflect" + "runtime" "strings" "sync" "sync/atomic" @@ -559,6 +561,29 @@ func TestRawBytes(t *testing.T) { }) } +func TestRawMessage(t *testing.T) { + runTests(t, dsn, func(dbt *DBTest) { + v1 := json.RawMessage("{}") + v2 := json.RawMessage("[]") + rows := dbt.mustQuery("SELECT ?, ?", v1, v2) + defer rows.Close() + if rows.Next() { + var o1, o2 json.RawMessage + if err := rows.Scan(&o1, &o2); err != nil { + dbt.Errorf("Got error: %v", err) + } + if !bytes.Equal(v1, o1) { + dbt.Errorf("expected %v, got %v", v1, o1) + } + if !bytes.Equal(v2, o2) { + dbt.Errorf("expected %v, got %v", v2, o2) + } + } else { + dbt.Errorf("no data") + } + }) +} + type testValuer struct { value string } @@ -1425,11 +1450,11 @@ func TestCharset(t *testing.T) { mustSetCharset("charset=ascii", "ascii") // when the first charset is invalid, use the second - mustSetCharset("charset=none,utf8", "utf8") + mustSetCharset("charset=none,utf8mb4", "utf8mb4") // when the first charset is valid, use it - mustSetCharset("charset=ascii,utf8", "ascii") - mustSetCharset("charset=utf8,ascii", "utf8") + mustSetCharset("charset=ascii,utf8mb4", "ascii") + mustSetCharset("charset=utf8mb4,ascii", "utf8mb4") } func TestFailingCharset(t *testing.T) { @@ -1454,7 +1479,7 @@ func TestCollation(t *testing.T) { defaultCollation, // driver default "latin1_general_ci", "binary", - "utf8_unicode_ci", + "utf8mb4_unicode_ci", "cp1257_bin", } @@ -1782,6 +1807,14 @@ func TestConcurrent(t *testing.T) { } runTests(t, dsn, func(dbt *DBTest) { + var version string + if err := dbt.db.QueryRow("SELECT @@version").Scan(&version); err != nil { + dbt.Fatalf("%s", err.Error()) + } + if strings.Contains(strings.ToLower(version), "mariadb") { + t.Skip(`TODO: "fix commands out of sync. Did you run multiple statements at once?" on MariaDB`) + } + var max int err := dbt.db.QueryRow("SELECT @@max_connections").Scan(&max) if err != nil { @@ -2581,10 +2614,19 @@ func TestContextCancelStmtQuery(t *testing.T) { } func TestContextCancelBegin(t *testing.T) { + if runtime.GOOS == "windows" || runtime.GOOS == "darwin" { + t.Skip(`FIXME: it sometime fails with "expected driver.ErrBadConn, got sql: connection is already closed" on windows and macOS`) + } + runTests(t, dsn, func(dbt *DBTest) { dbt.mustExec("CREATE TABLE test (v INTEGER)") ctx, cancel := context.WithCancel(context.Background()) - tx, err := dbt.db.BeginTx(ctx, nil) + conn, err := dbt.db.Conn(ctx) + if err != nil { + dbt.Fatal(err) + } + defer conn.Close() + tx, err := conn.BeginTx(ctx, nil) if err != nil { dbt.Fatal(err) } @@ -2614,7 +2656,17 @@ func TestContextCancelBegin(t *testing.T) { dbt.Errorf("expected sql.ErrTxDone or context.Canceled, got %v", err) } - // Context is canceled, so cannot begin a transaction. + // The connection is now in an inoperable state - so performing other + // operations should fail with ErrBadConn + // Important to exercise isolation level too - it runs SET TRANSACTION ISOLATION + // LEVEL XXX first, which needs to return ErrBadConn if the connection's context + // is cancelled + _, err = conn.BeginTx(context.Background(), &sql.TxOptions{Isolation: sql.LevelReadCommitted}) + if err != driver.ErrBadConn { + dbt.Errorf("expected driver.ErrBadConn, got %v", err) + } + + // cannot begin a transaction (on a different conn) with a canceled context if _, err := dbt.db.BeginTx(ctx, nil); err != context.Canceled { dbt.Errorf("expected context.Canceled, got %v", err) } @@ -2651,7 +2703,7 @@ func TestContextBeginIsolationLevel(t *testing.T) { if err := row.Scan(&v); err != nil { dbt.Fatal(err) } - // Because writer transaction wasn't commited yet, it should be available + // Because writer transaction wasn't committed yet, it should be available if v != 0 { dbt.Errorf("expected val to be 0, got %d", v) } @@ -2665,7 +2717,7 @@ func TestContextBeginIsolationLevel(t *testing.T) { if err := row.Scan(&v); err != nil { dbt.Fatal(err) } - // Data written by writer transaction is already commited, it should be selectable + // Data written by writer transaction is already committed, it should be selectable if v != 1 { dbt.Errorf("expected val to be 1, got %d", v) } @@ -2719,13 +2771,13 @@ func TestRowsColumnTypes(t *testing.T) { nfNULL := sql.NullFloat64{Float64: 0.0, Valid: false} nf0 := sql.NullFloat64{Float64: 0.0, Valid: true} nf1337 := sql.NullFloat64{Float64: 13.37, Valid: true} - nt0 := NullTime{Time: time.Date(2006, 01, 02, 15, 04, 05, 0, time.UTC), Valid: true} - nt1 := NullTime{Time: time.Date(2006, 01, 02, 15, 04, 05, 100000000, time.UTC), Valid: true} - nt2 := NullTime{Time: time.Date(2006, 01, 02, 15, 04, 05, 110000000, time.UTC), Valid: true} - nt6 := NullTime{Time: time.Date(2006, 01, 02, 15, 04, 05, 111111000, time.UTC), Valid: true} - nd1 := NullTime{Time: time.Date(2006, 01, 02, 0, 0, 0, 0, time.UTC), Valid: true} - nd2 := NullTime{Time: time.Date(2006, 03, 04, 0, 0, 0, 0, time.UTC), Valid: true} - ndNULL := NullTime{Time: time.Time{}, Valid: false} + nt0 := sql.NullTime{Time: time.Date(2006, 01, 02, 15, 04, 05, 0, time.UTC), Valid: true} + nt1 := sql.NullTime{Time: time.Date(2006, 01, 02, 15, 04, 05, 100000000, time.UTC), Valid: true} + nt2 := sql.NullTime{Time: time.Date(2006, 01, 02, 15, 04, 05, 110000000, time.UTC), Valid: true} + nt6 := sql.NullTime{Time: time.Date(2006, 01, 02, 15, 04, 05, 111111000, time.UTC), Valid: true} + nd1 := sql.NullTime{Time: time.Date(2006, 01, 02, 0, 0, 0, 0, time.UTC), Valid: true} + nd2 := sql.NullTime{Time: time.Date(2006, 03, 04, 0, 0, 0, 0, time.UTC), Valid: true} + ndNULL := sql.NullTime{Time: time.Time{}, Valid: false} rbNULL := sql.RawBytes(nil) rb0 := sql.RawBytes("0") rb42 := sql.RawBytes("42") @@ -2756,10 +2808,10 @@ func TestRowsColumnTypes(t *testing.T) { {"mediumintnull", "MEDIUMINT", "MEDIUMINT", scanTypeNullInt, true, 0, 0, [3]string{"0", "42", "NULL"}, [3]interface{}{ni0, ni42, niNULL}}, {"bigint", "BIGINT NOT NULL", "BIGINT", scanTypeInt64, false, 0, 0, [3]string{"0", "65535", "-42"}, [3]interface{}{int64(0), int64(65535), int64(-42)}}, {"bigintnull", "BIGINT", "BIGINT", scanTypeNullInt, true, 0, 0, [3]string{"NULL", "1", "42"}, [3]interface{}{niNULL, ni1, ni42}}, - {"tinyuint", "TINYINT UNSIGNED NOT NULL", "TINYINT", scanTypeUint8, false, 0, 0, [3]string{"0", "255", "42"}, [3]interface{}{uint8(0), uint8(255), uint8(42)}}, - {"smalluint", "SMALLINT UNSIGNED NOT NULL", "SMALLINT", scanTypeUint16, false, 0, 0, [3]string{"0", "65535", "42"}, [3]interface{}{uint16(0), uint16(65535), uint16(42)}}, - {"biguint", "BIGINT UNSIGNED NOT NULL", "BIGINT", scanTypeUint64, false, 0, 0, [3]string{"0", "65535", "42"}, [3]interface{}{uint64(0), uint64(65535), uint64(42)}}, - {"uint13", "INT(13) UNSIGNED NOT NULL", "INT", scanTypeUint32, false, 0, 0, [3]string{"0", "1337", "42"}, [3]interface{}{uint32(0), uint32(1337), uint32(42)}}, + {"tinyuint", "TINYINT UNSIGNED NOT NULL", "UNSIGNED TINYINT", scanTypeUint8, false, 0, 0, [3]string{"0", "255", "42"}, [3]interface{}{uint8(0), uint8(255), uint8(42)}}, + {"smalluint", "SMALLINT UNSIGNED NOT NULL", "UNSIGNED SMALLINT", scanTypeUint16, false, 0, 0, [3]string{"0", "65535", "42"}, [3]interface{}{uint16(0), uint16(65535), uint16(42)}}, + {"biguint", "BIGINT UNSIGNED NOT NULL", "UNSIGNED BIGINT", scanTypeUint64, false, 0, 0, [3]string{"0", "65535", "42"}, [3]interface{}{uint64(0), uint64(65535), uint64(42)}}, + {"uint13", "INT(13) UNSIGNED NOT NULL", "UNSIGNED INT", scanTypeUint32, false, 0, 0, [3]string{"0", "1337", "42"}, [3]interface{}{uint32(0), uint32(1337), uint32(42)}}, {"float", "FLOAT NOT NULL", "FLOAT", scanTypeFloat32, false, math.MaxInt64, math.MaxInt64, [3]string{"0", "42", "13.37"}, [3]interface{}{float32(0), float32(42), float32(13.37)}}, {"floatnull", "FLOAT", "FLOAT", scanTypeNullFloat, true, math.MaxInt64, math.MaxInt64, [3]string{"0", "NULL", "13.37"}, [3]interface{}{nf0, nfNULL, nf1337}}, {"float74null", "FLOAT(7,4)", "FLOAT", scanTypeNullFloat, true, math.MaxInt64, 4, [3]string{"0", "NULL", "13.37"}, [3]interface{}{nf0, nfNULL, nf1337}}, @@ -2805,131 +2857,125 @@ func TestRowsColumnTypes(t *testing.T) { values2 = values2[:len(values2)-2] values3 = values3[:len(values3)-2] - dsns := []string{ - dsn + "&parseTime=true", - dsn + "&parseTime=false", - } - for _, testdsn := range dsns { - runTests(t, testdsn, func(dbt *DBTest) { - dbt.mustExec("CREATE TABLE test (" + schema + ")") - dbt.mustExec("INSERT INTO test VALUES (" + values1 + "), (" + values2 + "), (" + values3 + ")") + runTests(t, dsn+"&parseTime=true", func(dbt *DBTest) { + dbt.mustExec("CREATE TABLE test (" + schema + ")") + dbt.mustExec("INSERT INTO test VALUES (" + values1 + "), (" + values2 + "), (" + values3 + ")") - rows, err := dbt.db.Query("SELECT * FROM test") - if err != nil { - t.Fatalf("Query: %v", err) - } + rows, err := dbt.db.Query("SELECT * FROM test") + if err != nil { + t.Fatalf("Query: %v", err) + } - tt, err := rows.ColumnTypes() - if err != nil { - t.Fatalf("ColumnTypes: %v", err) - } + tt, err := rows.ColumnTypes() + if err != nil { + t.Fatalf("ColumnTypes: %v", err) + } - if len(tt) != len(columns) { - t.Fatalf("unexpected number of columns: expected %d, got %d", len(columns), len(tt)) - } + if len(tt) != len(columns) { + t.Fatalf("unexpected number of columns: expected %d, got %d", len(columns), len(tt)) + } - types := make([]reflect.Type, len(tt)) - for i, tp := range tt { - column := columns[i] + types := make([]reflect.Type, len(tt)) + for i, tp := range tt { + column := columns[i] - // Name - name := tp.Name() - if name != column.name { - t.Errorf("column name mismatch %s != %s", name, column.name) - continue - } + // Name + name := tp.Name() + if name != column.name { + t.Errorf("column name mismatch %s != %s", name, column.name) + continue + } - // DatabaseTypeName - databaseTypeName := tp.DatabaseTypeName() - if databaseTypeName != column.databaseTypeName { - t.Errorf("databasetypename name mismatch for column %q: %s != %s", name, databaseTypeName, column.databaseTypeName) - continue - } + // DatabaseTypeName + databaseTypeName := tp.DatabaseTypeName() + if databaseTypeName != column.databaseTypeName { + t.Errorf("databasetypename name mismatch for column %q: %s != %s", name, databaseTypeName, column.databaseTypeName) + continue + } - // ScanType - scanType := tp.ScanType() - if scanType != column.scanType { - if scanType == nil { - t.Errorf("scantype is null for column %q", name) - } else { - t.Errorf("scantype mismatch for column %q: %s != %s", name, scanType.Name(), column.scanType.Name()) - } - continue + // ScanType + scanType := tp.ScanType() + if scanType != column.scanType { + if scanType == nil { + t.Errorf("scantype is null for column %q", name) + } else { + t.Errorf("scantype mismatch for column %q: %s != %s", name, scanType.Name(), column.scanType.Name()) } - types[i] = scanType - - // Nullable - nullable, ok := tp.Nullable() + continue + } + types[i] = scanType + + // Nullable + nullable, ok := tp.Nullable() + if !ok { + t.Errorf("nullable not ok %q", name) + continue + } + if nullable != column.nullable { + t.Errorf("nullable mismatch for column %q: %t != %t", name, nullable, column.nullable) + } + + // Length + // length, ok := tp.Length() + // if length != column.length { + // if !ok { + // t.Errorf("length not ok for column %q", name) + // } else { + // t.Errorf("length mismatch for column %q: %d != %d", name, length, column.length) + // } + // continue + // } + + // Precision and Scale + precision, scale, ok := tp.DecimalSize() + if precision != column.precision { if !ok { - t.Errorf("nullable not ok %q", name) - continue + t.Errorf("precision not ok for column %q", name) + } else { + t.Errorf("precision mismatch for column %q: %d != %d", name, precision, column.precision) } - if nullable != column.nullable { - t.Errorf("nullable mismatch for column %q: %t != %t", name, nullable, column.nullable) - } - - // Length - // length, ok := tp.Length() - // if length != column.length { - // if !ok { - // t.Errorf("length not ok for column %q", name) - // } else { - // t.Errorf("length mismatch for column %q: %d != %d", name, length, column.length) - // } - // continue - // } - - // Precision and Scale - precision, scale, ok := tp.DecimalSize() - if precision != column.precision { - if !ok { - t.Errorf("precision not ok for column %q", name) - } else { - t.Errorf("precision mismatch for column %q: %d != %d", name, precision, column.precision) - } - continue - } - if scale != column.scale { - if !ok { - t.Errorf("scale not ok for column %q", name) - } else { - t.Errorf("scale mismatch for column %q: %d != %d", name, scale, column.scale) - } - continue + continue + } + if scale != column.scale { + if !ok { + t.Errorf("scale not ok for column %q", name) + } else { + t.Errorf("scale mismatch for column %q: %d != %d", name, scale, column.scale) } + continue } + } - values := make([]interface{}, len(tt)) - for i := range values { - values[i] = reflect.New(types[i]).Interface() + values := make([]interface{}, len(tt)) + for i := range values { + values[i] = reflect.New(types[i]).Interface() + } + i := 0 + for rows.Next() { + err = rows.Scan(values...) + if err != nil { + t.Fatalf("failed to scan values in %v", err) } - i := 0 - for rows.Next() { - err = rows.Scan(values...) - if err != nil { - t.Fatalf("failed to scan values in %v", err) - } - for j := range values { - value := reflect.ValueOf(values[j]).Elem().Interface() - if !reflect.DeepEqual(value, columns[j].valuesOut[i]) { - if columns[j].scanType == scanTypeRawBytes { - t.Errorf("row %d, column %d: %v != %v", i, j, string(value.(sql.RawBytes)), string(columns[j].valuesOut[i].(sql.RawBytes))) - } else { - t.Errorf("row %d, column %d: %v != %v", i, j, value, columns[j].valuesOut[i]) - } + for j := range values { + value := reflect.ValueOf(values[j]).Elem().Interface() + if !reflect.DeepEqual(value, columns[j].valuesOut[i]) { + if columns[j].scanType == scanTypeRawBytes { + t.Errorf("row %d, column %d: %v != %v", i, j, string(value.(sql.RawBytes)), string(columns[j].valuesOut[i].(sql.RawBytes))) + } else { + t.Errorf("row %d, column %d: %v != %v", i, j, value, columns[j].valuesOut[i]) } } - i++ - } - if i != 3 { - t.Errorf("expected 3 rows, got %d", i) } + i++ + } + if i != 3 { + t.Errorf("expected 3 rows, got %d", i) + } - if err := rows.Close(); err != nil { - t.Errorf("error closing rows: %s", err) - } - }) - } + if err := rows.Close(); err != nil { + t.Errorf("error closing rows: %s", err) + } + }) } func TestValuerWithValueReceiverGivenNilValue(t *testing.T) { diff --git a/dsn.go b/dsn.go index 75c8c2489..4b71aaab0 100644 --- a/dsn.go +++ b/dsn.go @@ -46,22 +46,23 @@ type Config struct { ServerPubKey string // Server public key name pubKey *rsa.PublicKey // Server public key TLSConfig string // TLS configuration name - tls *tls.Config // TLS configuration + TLS *tls.Config // TLS configuration, its priority is higher than TLSConfig Timeout time.Duration // Dial timeout ReadTimeout time.Duration // I/O read timeout WriteTimeout time.Duration // I/O write timeout - AllowAllFiles bool // Allow all files to be used with LOAD DATA LOCAL INFILE - AllowCleartextPasswords bool // Allows the cleartext client side plugin - AllowNativePasswords bool // Allows the native password authentication method - AllowOldPasswords bool // Allows the old insecure password method - CheckConnLiveness bool // Check connections for liveness before using them - ClientFoundRows bool // Return number of matching rows instead of rows changed - ColumnsWithAlias bool // Prepend table alias to column names - InterpolateParams bool // Interpolate placeholders into query string - MultiStatements bool // Allow multiple statements in one query - ParseTime bool // Parse time values to time.Time - RejectReadOnly bool // Reject read-only connections + AllowAllFiles bool // Allow all files to be used with LOAD DATA LOCAL INFILE + AllowCleartextPasswords bool // Allows the cleartext client side plugin + AllowFallbackToPlaintext bool // Allows fallback to unencrypted connection if server does not support TLS + AllowNativePasswords bool // Allows the native password authentication method + AllowOldPasswords bool // Allows the old insecure password method + CheckConnLiveness bool // Check connections for liveness before using them + ClientFoundRows bool // Return number of matching rows instead of rows changed + ColumnsWithAlias bool // Prepend table alias to column names + InterpolateParams bool // Interpolate placeholders into query string + MultiStatements bool // Allow multiple statements in one query + ParseTime bool // Parse time values to time.Time + RejectReadOnly bool // Reject read-only connections } // NewConfig creates a new Config and sets default values. @@ -77,8 +78,8 @@ func NewConfig() *Config { func (cfg *Config) Clone() *Config { cp := *cfg - if cp.tls != nil { - cp.tls = cfg.tls.Clone() + if cp.TLS != nil { + cp.TLS = cfg.TLS.Clone() } if len(cp.Params) > 0 { cp.Params = make(map[string]string, len(cfg.Params)) @@ -119,24 +120,29 @@ func (cfg *Config) normalize() error { cfg.Addr = ensureHavePort(cfg.Addr) } - switch cfg.TLSConfig { - case "false", "": - // don't set anything - case "true": - cfg.tls = &tls.Config{} - case "skip-verify", "preferred": - cfg.tls = &tls.Config{InsecureSkipVerify: true} - default: - cfg.tls = getTLSConfigClone(cfg.TLSConfig) - if cfg.tls == nil { - return errors.New("invalid value / unknown config name: " + cfg.TLSConfig) + if cfg.TLS == nil { + switch cfg.TLSConfig { + case "false", "": + // don't set anything + case "true": + cfg.TLS = &tls.Config{} + case "skip-verify": + cfg.TLS = &tls.Config{InsecureSkipVerify: true} + case "preferred": + cfg.TLS = &tls.Config{InsecureSkipVerify: true} + cfg.AllowFallbackToPlaintext = true + default: + cfg.TLS = getTLSConfigClone(cfg.TLSConfig) + if cfg.TLS == nil { + return errors.New("invalid value / unknown config name: " + cfg.TLSConfig) + } } } - if cfg.tls != nil && cfg.tls.ServerName == "" && !cfg.tls.InsecureSkipVerify { + if cfg.TLS != nil && cfg.TLS.ServerName == "" && !cfg.TLS.InsecureSkipVerify { host, _, err := net.SplitHostPort(cfg.Addr) if err == nil { - cfg.tls.ServerName = host + cfg.TLS.ServerName = host } } @@ -204,6 +210,10 @@ func (cfg *Config) FormatDSN() string { writeDSNParam(&buf, &hasParam, "allowCleartextPasswords", "true") } + if cfg.AllowFallbackToPlaintext { + writeDSNParam(&buf, &hasParam, "allowFallbackToPlaintext", "true") + } + if !cfg.AllowNativePasswords { writeDSNParam(&buf, &hasParam, "allowNativePasswords", "false") } @@ -375,7 +385,7 @@ func parseDSNParams(cfg *Config, params string) (err error) { // cfg params switch value := param[1]; param[0] { - // Disable INFILE whitelist / enable all files + // Disable INFILE allowlist / enable all files case "allowAllFiles": var isBool bool cfg.AllowAllFiles, isBool = readBool(value) @@ -391,6 +401,14 @@ func parseDSNParams(cfg *Config, params string) (err error) { return errors.New("invalid bool value: " + value) } + // Allow fallback to unencrypted connection if server does not support TLS + case "allowFallbackToPlaintext": + var isBool bool + cfg.AllowFallbackToPlaintext, isBool = readBool(value) + if !isBool { + return errors.New("invalid bool value: " + value) + } + // Use native password authentication case "allowNativePasswords": var isBool bool @@ -426,7 +444,6 @@ func parseDSNParams(cfg *Config, params string) (err error) { // Collation case "collation": cfg.Collation = value - break case "columnsWithAlias": var isBool bool diff --git a/dsn_test.go b/dsn_test.go index 89815b341..41a6a29fa 100644 --- a/dsn_test.go +++ b/dsn_test.go @@ -42,8 +42,8 @@ var testDSNs = []struct { "user:password@/dbname?loc=UTC&timeout=30s&readTimeout=1s&writeTimeout=1s&allowAllFiles=1&clientFoundRows=true&allowOldPasswords=TRUE&collation=utf8mb4_unicode_ci&maxAllowedPacket=16777216&tls=false&allowCleartextPasswords=true&parseTime=true&rejectReadOnly=true", &Config{User: "user", Passwd: "password", Net: "tcp", Addr: "127.0.0.1:3306", DBName: "dbname", Collation: "utf8mb4_unicode_ci", Loc: time.UTC, TLSConfig: "false", AllowCleartextPasswords: true, AllowNativePasswords: true, Timeout: 30 * time.Second, ReadTimeout: time.Second, WriteTimeout: time.Second, AllowAllFiles: true, AllowOldPasswords: true, CheckConnLiveness: true, ClientFoundRows: true, MaxAllowedPacket: 16777216, ParseTime: true, RejectReadOnly: true}, }, { - "user:password@/dbname?allowNativePasswords=false&checkConnLiveness=false&maxAllowedPacket=0", - &Config{User: "user", Passwd: "password", Net: "tcp", Addr: "127.0.0.1:3306", DBName: "dbname", Collation: "utf8mb4_general_ci", Loc: time.UTC, MaxAllowedPacket: 0, AllowNativePasswords: false, CheckConnLiveness: false}, + "user:password@/dbname?allowNativePasswords=false&checkConnLiveness=false&maxAllowedPacket=0&allowFallbackToPlaintext=true", + &Config{User: "user", Passwd: "password", Net: "tcp", Addr: "127.0.0.1:3306", DBName: "dbname", Collation: "utf8mb4_general_ci", Loc: time.UTC, MaxAllowedPacket: 0, AllowFallbackToPlaintext: true, AllowNativePasswords: false, CheckConnLiveness: false}, }, { "user:p@ss(word)@tcp([de:ad:be:ef::ca:fe]:80)/dbname?loc=Local", &Config{User: "user", Passwd: "p@ss(word)", Net: "tcp", Addr: "[de:ad:be:ef::ca:fe]:80", DBName: "dbname", Collation: "utf8mb4_general_ci", Loc: time.Local, MaxAllowedPacket: defaultMaxAllowedPacket, AllowNativePasswords: true, CheckConnLiveness: true}, @@ -82,7 +82,7 @@ func TestDSNParser(t *testing.T) { } // pointer not static - cfg.tls = nil + cfg.TLS = nil if !reflect.DeepEqual(cfg, tst.out) { t.Errorf("%d. ParseDSN(%q) mismatch:\ngot %+v\nwant %+v", i, tst.in, cfg, tst.out) @@ -92,13 +92,15 @@ func TestDSNParser(t *testing.T) { func TestDSNParserInvalid(t *testing.T) { var invalidDSNs = []string{ - "@net(addr/", // no closing brace - "@tcp(/", // no closing brace - "tcp(/", // no closing brace - "(/", // no closing brace - "net(addr)//", // unescaped - "User:pass@tcp(1.2.3.4:3306)", // no trailing slash - "net()/", // unknown default addr + "@net(addr/", // no closing brace + "@tcp(/", // no closing brace + "tcp(/", // no closing brace + "(/", // no closing brace + "net(addr)//", // unescaped + "User:pass@tcp(1.2.3.4:3306)", // no trailing slash + "net()/", // unknown default addr + "user:pass@tcp(127.0.0.1:3306)/db/name", // invalid dbname + "user:password@/dbname?allowFallbackToPlaintext=PREFERRED", // wrong bool flag //"/dbname?arg=/some/unescaped/path", } @@ -117,7 +119,7 @@ func TestDSNReformat(t *testing.T) { t.Error(err.Error()) continue } - cfg1.tls = nil // pointer not static + cfg1.TLS = nil // pointer not static res1 := fmt.Sprintf("%+v", cfg1) dsn2 := cfg1.FormatDSN() @@ -126,7 +128,7 @@ func TestDSNReformat(t *testing.T) { t.Error(err.Error()) continue } - cfg2.tls = nil // pointer not static + cfg2.TLS = nil // pointer not static res2 := fmt.Sprintf("%+v", cfg2) if res1 != res2 { @@ -202,7 +204,7 @@ func TestDSNWithCustomTLS(t *testing.T) { if err != nil { t.Error(err.Error()) - } else if cfg.tls.ServerName != name { + } else if cfg.TLS.ServerName != name { t.Errorf("did not get the correct TLS ServerName (%s) parsing DSN (%s).", name, tst) } @@ -213,7 +215,7 @@ func TestDSNWithCustomTLS(t *testing.T) { if err != nil { t.Error(err.Error()) - } else if cfg.tls.ServerName != name { + } else if cfg.TLS.ServerName != name { t.Errorf("did not get the correct ServerName (%s) parsing DSN (%s).", name, tst) } else if tlsCfg.ServerName != "" { t.Errorf("tlsCfg was mutated ServerName (%s) should be empty parsing DSN (%s).", name, tst) @@ -228,11 +230,11 @@ func TestDSNTLSConfig(t *testing.T) { if err != nil { t.Error(err.Error()) } - if cfg.tls == nil { + if cfg.TLS == nil { t.Error("cfg.tls should not be nil") } - if cfg.tls.ServerName != expectedServerName { - t.Errorf("cfg.tls.ServerName should be %q, got %q (host with port)", expectedServerName, cfg.tls.ServerName) + if cfg.TLS.ServerName != expectedServerName { + t.Errorf("cfg.tls.ServerName should be %q, got %q (host with port)", expectedServerName, cfg.TLS.ServerName) } dsn = "tcp(example.com)/?tls=true" @@ -240,11 +242,11 @@ func TestDSNTLSConfig(t *testing.T) { if err != nil { t.Error(err.Error()) } - if cfg.tls == nil { + if cfg.TLS == nil { t.Error("cfg.tls should not be nil") } - if cfg.tls.ServerName != expectedServerName { - t.Errorf("cfg.tls.ServerName should be %q, got %q (host without port)", expectedServerName, cfg.tls.ServerName) + if cfg.TLS.ServerName != expectedServerName { + t.Errorf("cfg.tls.ServerName should be %q, got %q (host without port)", expectedServerName, cfg.TLS.ServerName) } } @@ -261,7 +263,7 @@ func TestDSNWithCustomTLSQueryEscape(t *testing.T) { if err != nil { t.Error(err.Error()) - } else if cfg.tls.ServerName != name { + } else if cfg.TLS.ServerName != name { t.Errorf("did not get the correct TLS ServerName (%s) parsing DSN (%s).", name, dsn) } } @@ -334,12 +336,12 @@ func TestCloneConfig(t *testing.T) { t.Errorf("Config.Clone did not create a separate config struct") } - if cfg2.tls.ServerName != expectedServerName { - t.Errorf("cfg.tls.ServerName should be %q, got %q (host with port)", expectedServerName, cfg.tls.ServerName) + if cfg2.TLS.ServerName != expectedServerName { + t.Errorf("cfg.tls.ServerName should be %q, got %q (host with port)", expectedServerName, cfg.TLS.ServerName) } - cfg2.tls.ServerName = "example2.com" - if cfg.tls.ServerName == cfg2.tls.ServerName { + cfg2.TLS.ServerName = "example2.com" + if cfg.TLS.ServerName == cfg2.TLS.ServerName { t.Errorf("changed cfg.tls.Server name should not propagate to original Config") } @@ -383,20 +385,20 @@ func TestNormalizeTLSConfig(t *testing.T) { cfg.normalize() - if cfg.tls == nil { + if cfg.TLS == nil { if tc.want != nil { t.Fatal("wanted a tls config but got nil instead") } return } - if cfg.tls.ServerName != tc.want.ServerName { + if cfg.TLS.ServerName != tc.want.ServerName { t.Errorf("tls.ServerName doesn't match (want: '%s', got: '%s')", - tc.want.ServerName, cfg.tls.ServerName) + tc.want.ServerName, cfg.TLS.ServerName) } - if cfg.tls.InsecureSkipVerify != tc.want.InsecureSkipVerify { + if cfg.TLS.InsecureSkipVerify != tc.want.InsecureSkipVerify { t.Errorf("tls.InsecureSkipVerify doesn't match (want: %T, got :%T)", - tc.want.InsecureSkipVerify, cfg.tls.InsecureSkipVerify) + tc.want.InsecureSkipVerify, cfg.TLS.InsecureSkipVerify) } }) } diff --git a/errors.go b/errors.go index 760782ff2..ff9a8f088 100644 --- a/errors.go +++ b/errors.go @@ -27,7 +27,7 @@ var ( ErrOldProtocol = errors.New("MySQL server does not support required protocol 41+") ErrPktSync = errors.New("commands out of sync. You can't run this command now") ErrPktSyncMul = errors.New("commands out of sync. Did you run multiple statements at once?") - ErrPktTooLarge = errors.New("packet for query is too large. Try adjusting the 'max_allowed_packet' variable on the server") + ErrPktTooLarge = errors.New("packet for query is too large. Try adjusting the `Config.MaxAllowedPacket`") ErrBusyBuffer = errors.New("busy buffer") // errBadConnNoWrite is used for connection errors where nothing was sent to the database yet. @@ -56,10 +56,22 @@ func SetLogger(logger Logger) error { // MySQLError is an error type which represents a single MySQL error type MySQLError struct { - Number uint16 - Message string + Number uint16 + SQLState [5]byte + Message string } func (me *MySQLError) Error() string { + if me.SQLState != [5]byte{} { + return fmt.Sprintf("Error %d (%s): %s", me.Number, me.SQLState, me.Message) + } + return fmt.Sprintf("Error %d: %s", me.Number, me.Message) } + +func (me *MySQLError) Is(err error) bool { + if merr, ok := err.(*MySQLError); ok { + return merr.Number == me.Number + } + return false +} diff --git a/errors_test.go b/errors_test.go index 96f9126d6..43213f98e 100644 --- a/errors_test.go +++ b/errors_test.go @@ -10,6 +10,7 @@ package mysql import ( "bytes" + "errors" "log" "testing" ) @@ -40,3 +41,21 @@ func TestErrorsStrictIgnoreNotes(t *testing.T) { dbt.mustExec("DROP TABLE IF EXISTS does_not_exist") }) } + +func TestMySQLErrIs(t *testing.T) { + infraErr := &MySQLError{Number: 1234, Message: "the server is on fire"} + otherInfraErr := &MySQLError{Number: 1234, Message: "the datacenter is flooded"} + if !errors.Is(infraErr, otherInfraErr) { + t.Errorf("expected errors to be the same: %+v %+v", infraErr, otherInfraErr) + } + + differentCodeErr := &MySQLError{Number: 5678, Message: "the server is on fire"} + if errors.Is(infraErr, differentCodeErr) { + t.Fatalf("expected errors to be different: %+v %+v", infraErr, differentCodeErr) + } + + nonMysqlErr := errors.New("not a mysql error") + if errors.Is(infraErr, nonMysqlErr) { + t.Fatalf("expected errors to be different: %+v %+v", infraErr, nonMysqlErr) + } +} diff --git a/fields.go b/fields.go index e1e2ece4b..e0654a83d 100644 --- a/fields.go +++ b/fields.go @@ -41,6 +41,9 @@ func (mf *mysqlField) typeDatabaseName() string { case fieldTypeJSON: return "JSON" case fieldTypeLong: + if mf.flags&flagUnsigned != 0 { + return "UNSIGNED INT" + } return "INT" case fieldTypeLongBLOB: if mf.charSet != collations[binaryCollation] { @@ -48,6 +51,9 @@ func (mf *mysqlField) typeDatabaseName() string { } return "LONGBLOB" case fieldTypeLongLong: + if mf.flags&flagUnsigned != 0 { + return "UNSIGNED BIGINT" + } return "BIGINT" case fieldTypeMediumBLOB: if mf.charSet != collations[binaryCollation] { @@ -63,6 +69,9 @@ func (mf *mysqlField) typeDatabaseName() string { case fieldTypeSet: return "SET" case fieldTypeShort: + if mf.flags&flagUnsigned != 0 { + return "UNSIGNED SMALLINT" + } return "SMALLINT" case fieldTypeString: if mf.charSet == collations[binaryCollation] { @@ -74,6 +83,9 @@ func (mf *mysqlField) typeDatabaseName() string { case fieldTypeTimestamp: return "TIMESTAMP" case fieldTypeTiny: + if mf.flags&flagUnsigned != 0 { + return "UNSIGNED TINYINT" + } return "TINYINT" case fieldTypeTinyBLOB: if mf.charSet != collations[binaryCollation] { @@ -106,7 +118,7 @@ var ( scanTypeInt64 = reflect.TypeOf(int64(0)) scanTypeNullFloat = reflect.TypeOf(sql.NullFloat64{}) scanTypeNullInt = reflect.TypeOf(sql.NullInt64{}) - scanTypeNullTime = reflect.TypeOf(NullTime{}) + scanTypeNullTime = reflect.TypeOf(sql.NullTime{}) scanTypeUint8 = reflect.TypeOf(uint8(0)) scanTypeUint16 = reflect.TypeOf(uint16(0)) scanTypeUint32 = reflect.TypeOf(uint32(0)) diff --git a/fuzz.go b/fuzz.go new file mode 100644 index 000000000..3a4ec25a9 --- /dev/null +++ b/fuzz.go @@ -0,0 +1,25 @@ +// Go MySQL Driver - A MySQL-Driver for Go's database/sql package. +// +// Copyright 2020 The Go-MySQL-Driver Authors. All rights reserved. +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this file, +// You can obtain one at http://mozilla.org/MPL/2.0/. + +//go:build gofuzz +// +build gofuzz + +package mysql + +import ( + "database/sql" +) + +func Fuzz(data []byte) int { + db, err := sql.Open("mysql", string(data)) + if err != nil { + return 0 + } + db.Close() + return 1 +} diff --git a/go.mod b/go.mod index fffbf6a90..251110478 100644 --- a/go.mod +++ b/go.mod @@ -1,3 +1,3 @@ module github.com/go-sql-driver/mysql -go 1.10 +go 1.13 diff --git a/infile.go b/infile.go index 273cb0ba5..3279dcffd 100644 --- a/infile.go +++ b/infile.go @@ -23,17 +23,16 @@ var ( readerRegisterLock sync.RWMutex ) -// RegisterLocalFile adds the given file to the file whitelist, +// RegisterLocalFile adds the given file to the file allowlist, // so that it can be used by "LOAD DATA LOCAL INFILE ". // Alternatively you can allow the use of all local files with // the DSN parameter 'allowAllFiles=true' // -// filePath := "/home/gopher/data.csv" -// mysql.RegisterLocalFile(filePath) -// err := db.Exec("LOAD DATA LOCAL INFILE '" + filePath + "' INTO TABLE foo") -// if err != nil { -// ... -// +// filePath := "/home/gopher/data.csv" +// mysql.RegisterLocalFile(filePath) +// err := db.Exec("LOAD DATA LOCAL INFILE '" + filePath + "' INTO TABLE foo") +// if err != nil { +// ... func RegisterLocalFile(filePath string) { fileRegisterLock.Lock() // lazy map init @@ -45,7 +44,7 @@ func RegisterLocalFile(filePath string) { fileRegisterLock.Unlock() } -// DeregisterLocalFile removes the given filepath from the whitelist. +// DeregisterLocalFile removes the given filepath from the allowlist. func DeregisterLocalFile(filePath string) { fileRegisterLock.Lock() delete(fileRegister, strings.Trim(filePath, `"`)) @@ -58,15 +57,14 @@ func DeregisterLocalFile(filePath string) { // If the handler returns a io.ReadCloser Close() is called when the // request is finished. // -// mysql.RegisterReaderHandler("data", func() io.Reader { -// var csvReader io.Reader // Some Reader that returns CSV data -// ... // Open Reader here -// return csvReader -// }) -// err := db.Exec("LOAD DATA LOCAL INFILE 'Reader::data' INTO TABLE foo") -// if err != nil { -// ... -// +// mysql.RegisterReaderHandler("data", func() io.Reader { +// var csvReader io.Reader // Some Reader that returns CSV data +// ... // Open Reader here +// return csvReader +// }) +// err := db.Exec("LOAD DATA LOCAL INFILE 'Reader::data' INTO TABLE foo") +// if err != nil { +// ... func RegisterReaderHandler(name string, handler func() io.Reader) { readerRegisterLock.Lock() // lazy map init @@ -93,10 +91,12 @@ func deferredClose(err *error, closer io.Closer) { } } +const defaultPacketSize = 16 * 1024 // 16KB is small enough for disk readahead and large enough for TCP + func (mc *mysqlConn) handleInFileRequest(name string) (err error) { var rdr io.Reader var data []byte - packetSize := 16 * 1024 // 16KB is small enough for disk readahead and large enough for TCP + packetSize := defaultPacketSize if mc.maxWriteSize < packetSize { packetSize = mc.maxWriteSize } diff --git a/nulltime.go b/nulltime.go index afa8a89e9..36c8a42c5 100644 --- a/nulltime.go +++ b/nulltime.go @@ -9,11 +9,32 @@ package mysql import ( + "database/sql" "database/sql/driver" "fmt" "time" ) +// NullTime represents a time.Time that may be NULL. +// NullTime implements the Scanner interface so +// it can be used as a scan destination: +// +// var nt NullTime +// err := db.QueryRow("SELECT time FROM foo WHERE id=?", id).Scan(&nt) +// ... +// if nt.Valid { +// // use nt.Time +// } else { +// // NULL value +// } +// +// # This NullTime implementation is not driver-specific +// +// Deprecated: NullTime doesn't honor the loc DSN parameter. +// NullTime.Scan interprets a time as UTC, not the loc DSN parameter. +// Use sql.NullTime instead. +type NullTime sql.NullTime + // Scan implements the Scanner interface. // The value type must be time.Time or string / []byte (formatted time-string), // otherwise Scan fails. @@ -28,11 +49,11 @@ func (nt *NullTime) Scan(value interface{}) (err error) { nt.Time, nt.Valid = v, true return case []byte: - nt.Time, err = parseDateTime(string(v), time.UTC) + nt.Time, err = parseDateTime(v, time.UTC) nt.Valid = (err == nil) return case string: - nt.Time, err = parseDateTime(v, time.UTC) + nt.Time, err = parseDateTime([]byte(v), time.UTC) nt.Valid = (err == nil) return } diff --git a/nulltime_go113.go b/nulltime_go113.go deleted file mode 100644 index c392594dd..000000000 --- a/nulltime_go113.go +++ /dev/null @@ -1,31 +0,0 @@ -// Go MySQL Driver - A MySQL-Driver for Go's database/sql package -// -// Copyright 2013 The Go-MySQL-Driver Authors. All rights reserved. -// -// This Source Code Form is subject to the terms of the Mozilla Public -// License, v. 2.0. If a copy of the MPL was not distributed with this file, -// You can obtain one at http://mozilla.org/MPL/2.0/. - -// +build go1.13 - -package mysql - -import ( - "database/sql" -) - -// NullTime represents a time.Time that may be NULL. -// NullTime implements the Scanner interface so -// it can be used as a scan destination: -// -// var nt NullTime -// err := db.QueryRow("SELECT time FROM foo WHERE id=?", id).Scan(&nt) -// ... -// if nt.Valid { -// // use nt.Time -// } else { -// // NULL value -// } -// -// This NullTime implementation is not driver-specific -type NullTime sql.NullTime diff --git a/nulltime_legacy.go b/nulltime_legacy.go deleted file mode 100644 index 86d159d44..000000000 --- a/nulltime_legacy.go +++ /dev/null @@ -1,34 +0,0 @@ -// Go MySQL Driver - A MySQL-Driver for Go's database/sql package -// -// Copyright 2013 The Go-MySQL-Driver Authors. All rights reserved. -// -// This Source Code Form is subject to the terms of the Mozilla Public -// License, v. 2.0. If a copy of the MPL was not distributed with this file, -// You can obtain one at http://mozilla.org/MPL/2.0/. - -// +build !go1.13 - -package mysql - -import ( - "time" -) - -// NullTime represents a time.Time that may be NULL. -// NullTime implements the Scanner interface so -// it can be used as a scan destination: -// -// var nt NullTime -// err := db.QueryRow("SELECT time FROM foo WHERE id=?", id).Scan(&nt) -// ... -// if nt.Valid { -// // use nt.Time -// } else { -// // NULL value -// } -// -// This NullTime implementation is not driver-specific -type NullTime struct { - Time time.Time - Valid bool // Valid is true if Time is not NULL -} diff --git a/packets.go b/packets.go index 82ad7a200..ee05c95a8 100644 --- a/packets.go +++ b/packets.go @@ -13,6 +13,7 @@ import ( "crypto/tls" "database/sql/driver" "encoding/binary" + "encoding/json" "errors" "fmt" "io" @@ -109,14 +110,13 @@ func (mc *mysqlConn) writePacket(data []byte) error { conn = mc.rawConn } var err error - // If this connection has a ReadTimeout which we've been setting on - // reads, reset it to its default value before we attempt a non-blocking - // read, otherwise the scheduler will just time us out before we can read - if mc.cfg.ReadTimeout != 0 { - err = conn.SetReadDeadline(time.Time{}) - } - if err == nil && mc.cfg.CheckConnLiveness { - err = connCheck(conn) + if mc.cfg.CheckConnLiveness { + if mc.cfg.ReadTimeout != 0 { + err = conn.SetReadDeadline(time.Now().Add(mc.cfg.ReadTimeout)) + } + if err == nil { + err = connCheck(conn) + } } if err != nil { errLog.Print("closing bad idle connection: ", err) @@ -222,9 +222,9 @@ func (mc *mysqlConn) readHandshakePacket() (data []byte, plugin string, err erro if mc.flags&clientProtocol41 == 0 { return nil, "", ErrOldProtocol } - if mc.flags&clientSSL == 0 && mc.cfg.tls != nil { - if mc.cfg.TLSConfig == "preferred" { - mc.cfg.tls = nil + if mc.flags&clientSSL == 0 && mc.cfg.TLS != nil { + if mc.cfg.AllowFallbackToPlaintext { + mc.cfg.TLS = nil } else { return nil, "", ErrNoTLS } @@ -292,7 +292,7 @@ func (mc *mysqlConn) writeHandshakeResponsePacket(authResp []byte, plugin string } // To enable TLS / SSL - if mc.cfg.tls != nil { + if mc.cfg.TLS != nil { clientFlags |= clientSSL } @@ -348,16 +348,22 @@ func (mc *mysqlConn) writeHandshakeResponsePacket(authResp []byte, plugin string return errors.New("unknown collation") } + // Filler [23 bytes] (all 0x00) + pos := 13 + for ; pos < 13+23; pos++ { + data[pos] = 0 + } + // SSL Connection Request Packet // http://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::SSLRequest - if mc.cfg.tls != nil { + if mc.cfg.TLS != nil { // Send TLS / SSL request packet if err := mc.writePacket(data[:(4+4+1+23)+4]); err != nil { return err } // Switch to TLS - tlsConn := tls.Client(mc.netConn, mc.cfg.tls) + tlsConn := tls.Client(mc.netConn, mc.cfg.TLS) if err := tlsConn.Handshake(); err != nil { return err } @@ -366,12 +372,6 @@ func (mc *mysqlConn) writeHandshakeResponsePacket(authResp []byte, plugin string mc.buf.nc = tlsConn } - // Filler [23 bytes] (all 0x00) - pos := 13 - for ; pos < 13+23; pos++ { - data[pos] = 0 - } - // User [null terminated string] if len(mc.cfg.User) > 0 { pos += copy(data[pos:], mc.cfg.User) @@ -587,19 +587,20 @@ func (mc *mysqlConn) handleErrorPacket(data []byte) error { return driver.ErrBadConn } + me := &MySQLError{Number: errno} + pos := 3 // SQL State [optional: # + 5bytes string] if data[3] == 0x23 { - //sqlstate := string(data[4 : 4+5]) + copy(me.SQLState[:], data[4:4+5]) pos = 9 } // Error Message [string] - return &MySQLError{ - Number: errno, - Message: string(data[pos:]), - } + me.Message = string(data[pos:]) + + return me } func readStatus(b []byte) statusFlag { @@ -760,40 +761,40 @@ func (rows *textRows) readRow(dest []driver.Value) error { } // RowSet Packet - var n int - var isNull bool - pos := 0 + var ( + n int + isNull bool + pos int = 0 + ) for i := range dest { // Read bytes and convert to string dest[i], isNull, n, err = readLengthEncodedString(data[pos:]) pos += n - if err == nil { - if !isNull { - if !mc.parseTime { - continue - } else { - switch rows.rs.columns[i].fieldType { - case fieldTypeTimestamp, fieldTypeDateTime, - fieldTypeDate, fieldTypeNewDate: - dest[i], err = parseDateTime( - string(dest[i].([]byte)), - mc.cfg.Loc, - ) - if err == nil { - continue - } - default: - continue - } - } - } else { - dest[i] = nil - continue + if err != nil { + return err + } + + if isNull { + dest[i] = nil + continue + } + + if !mc.parseTime { + continue + } + + // Parse time field + switch rows.rs.columns[i].fieldType { + case fieldTypeTimestamp, + fieldTypeDateTime, + fieldTypeDate, + fieldTypeNewDate: + if dest[i], err = parseDateTime(dest[i].([]byte), mc.cfg.Loc); err != nil { + return err } } - return err // err != nil } return nil @@ -1003,6 +1004,9 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error { continue } + if v, ok := arg.(json.RawMessage); ok { + arg = []byte(v) + } // cache types and values switch v := arg.(type) { case int64: @@ -1112,7 +1116,10 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error { if v.IsZero() { b = append(b, "0000-00-00"...) } else { - b = v.In(mc.cfg.Loc).AppendFormat(b, timeFormat) + b, err = appendDateTime(b, v.In(mc.cfg.Loc)) + if err != nil { + return err + } } paramValues = appendLengthEncodedInteger(paramValues, diff --git a/statement.go b/statement.go index f7e370939..10ece8bd6 100644 --- a/statement.go +++ b/statement.go @@ -10,6 +10,7 @@ package mysql import ( "database/sql/driver" + "encoding/json" "fmt" "io" "reflect" @@ -22,7 +23,7 @@ type mysqlStmt struct { } func (stmt *mysqlStmt) Close() error { - if stmt.mc == nil || stmt.mc.closed.IsSet() { + if stmt.mc == nil || stmt.mc.closed.Load() { // driver.Stmt.Close can be called more than once, thus this function // has to be idempotent. // See also Issue #450 and golang/go#16019. @@ -43,8 +44,13 @@ func (stmt *mysqlStmt) ColumnConverter(idx int) driver.ValueConverter { return converter{} } +func (stmt *mysqlStmt) CheckNamedValue(nv *driver.NamedValue) (err error) { + nv.Value, err = converter{}.ConvertValue(nv.Value) + return +} + func (stmt *mysqlStmt) Exec(args []driver.Value) (driver.Result, error) { - if stmt.mc.closed.IsSet() { + if stmt.mc.closed.Load() { errLog.Print(ErrInvalidConn) return nil, driver.ErrBadConn } @@ -92,7 +98,7 @@ func (stmt *mysqlStmt) Query(args []driver.Value) (driver.Rows, error) { } func (stmt *mysqlStmt) query(args []driver.Value) (*binaryRows, error) { - if stmt.mc.closed.IsSet() { + if stmt.mc.closed.Load() { errLog.Print(ErrInvalidConn) return nil, driver.ErrBadConn } @@ -129,6 +135,8 @@ func (stmt *mysqlStmt) query(args []driver.Value) (*binaryRows, error) { return rows, err } +var jsonType = reflect.TypeOf(json.RawMessage{}) + type converter struct{} // ConvertValue mirrors the reference/default converter in database/sql/driver @@ -146,12 +154,17 @@ func (c converter) ConvertValue(v interface{}) (driver.Value, error) { if err != nil { return nil, err } - if !driver.IsValue(sv) { - return nil, fmt.Errorf("non-Value type %T returned from Value", sv) + if driver.IsValue(sv) { + return sv, nil } - return sv, nil + // A value returned from the Valuer interface can be "a type handled by + // a database driver's NamedValueChecker interface" so we should accept + // uint64 here as well. + if u, ok := sv.(uint64); ok { + return u, nil + } + return nil, fmt.Errorf("non-Value type %T returned from Value", sv) } - rv := reflect.ValueOf(v) switch rv.Kind() { case reflect.Ptr: @@ -170,11 +183,14 @@ func (c converter) ConvertValue(v interface{}) (driver.Value, error) { case reflect.Bool: return rv.Bool(), nil case reflect.Slice: - ek := rv.Type().Elem().Kind() - if ek == reflect.Uint8 { + switch t := rv.Type(); { + case t == jsonType: + return v, nil + case t.Elem().Kind() == reflect.Uint8: return rv.Bytes(), nil + default: + return nil, fmt.Errorf("unsupported type %T, a slice of %s", v, t.Elem().Kind()) } - return nil, fmt.Errorf("unsupported type %T, a slice of %s", v, ek) case reflect.String: return rv.String(), nil } diff --git a/statement_test.go b/statement_test.go index 4b9914f8e..2563ece55 100644 --- a/statement_test.go +++ b/statement_test.go @@ -10,6 +10,8 @@ package mysql import ( "bytes" + "database/sql/driver" + "encoding/json" "testing" ) @@ -34,7 +36,7 @@ func TestConvertDerivedByteSlice(t *testing.T) { t.Fatal("Byte slice not convertible", err) } - if bytes.Compare(output.([]byte), []byte("value")) != 0 { + if !bytes.Equal(output.([]byte), []byte("value")) { t.Fatalf("Byte slice not converted, got %#v %T", output, output) } } @@ -95,6 +97,14 @@ func TestConvertSignedIntegers(t *testing.T) { } } +type myUint64 struct { + value uint64 +} + +func (u myUint64) Value() (driver.Value, error) { + return u.value, nil +} + func TestConvertUnsignedIntegers(t *testing.T) { values := []interface{}{ uint8(42), @@ -102,6 +112,7 @@ func TestConvertUnsignedIntegers(t *testing.T) { uint32(42), uint64(42), uint(42), + myUint64{uint64(42)}, } for _, value := range values { @@ -124,3 +135,17 @@ func TestConvertUnsignedIntegers(t *testing.T) { t.Fatalf("uint64 high-bit converted, got %#v %T", output, output) } } + +func TestConvertJSON(t *testing.T) { + raw := json.RawMessage("{}") + + out, err := converter{}.ConvertValue(raw) + + if err != nil { + t.Fatal("json.RawMessage was failed in convert", err) + } + + if _, ok := out.(json.RawMessage); !ok { + t.Fatalf("json.RawMessage converted, got %#v %T", out, out) + } +} diff --git a/transaction.go b/transaction.go index 417d72793..4a4b61001 100644 --- a/transaction.go +++ b/transaction.go @@ -13,7 +13,7 @@ type mysqlTx struct { } func (tx *mysqlTx) Commit() (err error) { - if tx.mc == nil || tx.mc.closed.IsSet() { + if tx.mc == nil || tx.mc.closed.Load() { return ErrInvalidConn } err = tx.mc.exec("COMMIT") @@ -22,7 +22,7 @@ func (tx *mysqlTx) Commit() (err error) { } func (tx *mysqlTx) Rollback() (err error) { - if tx.mc == nil || tx.mc.closed.IsSet() { + if tx.mc == nil || tx.mc.closed.Load() { return ErrInvalidConn } err = tx.mc.exec("ROLLBACK") diff --git a/utils.go b/utils.go index 9552e80b5..15dbd8d16 100644 --- a/utils.go +++ b/utils.go @@ -35,26 +35,25 @@ var ( // Note: The provided tls.Config is exclusively owned by the driver after // registering it. // -// rootCertPool := x509.NewCertPool() -// pem, err := ioutil.ReadFile("/path/ca-cert.pem") -// if err != nil { -// log.Fatal(err) -// } -// if ok := rootCertPool.AppendCertsFromPEM(pem); !ok { -// log.Fatal("Failed to append PEM.") -// } -// clientCert := make([]tls.Certificate, 0, 1) -// certs, err := tls.LoadX509KeyPair("/path/client-cert.pem", "/path/client-key.pem") -// if err != nil { -// log.Fatal(err) -// } -// clientCert = append(clientCert, certs) -// mysql.RegisterTLSConfig("custom", &tls.Config{ -// RootCAs: rootCertPool, -// Certificates: clientCert, -// }) -// db, err := sql.Open("mysql", "user@tcp(localhost:3306)/test?tls=custom") -// +// rootCertPool := x509.NewCertPool() +// pem, err := ioutil.ReadFile("/path/ca-cert.pem") +// if err != nil { +// log.Fatal(err) +// } +// if ok := rootCertPool.AppendCertsFromPEM(pem); !ok { +// log.Fatal("Failed to append PEM.") +// } +// clientCert := make([]tls.Certificate, 0, 1) +// certs, err := tls.LoadX509KeyPair("/path/client-cert.pem", "/path/client-key.pem") +// if err != nil { +// log.Fatal(err) +// } +// clientCert = append(clientCert, certs) +// mysql.RegisterTLSConfig("custom", &tls.Config{ +// RootCAs: rootCertPool, +// Certificates: clientCert, +// }) +// db, err := sql.Open("mysql", "user@tcp(localhost:3306)/test?tls=custom") func RegisterTLSConfig(key string, config *tls.Config) error { if _, isBool := readBool(key); isBool || strings.ToLower(key) == "skip-verify" || strings.ToLower(key) == "preferred" { return fmt.Errorf("key '%s' is reserved", key) @@ -106,27 +105,126 @@ func readBool(input string) (value bool, valid bool) { * Time related utils * ******************************************************************************/ -func parseDateTime(str string, loc *time.Location) (t time.Time, err error) { - base := "0000-00-00 00:00:00.0000000" - switch len(str) { +func parseDateTime(b []byte, loc *time.Location) (time.Time, error) { + const base = "0000-00-00 00:00:00.000000" + switch len(b) { case 10, 19, 21, 22, 23, 24, 25, 26: // up to "YYYY-MM-DD HH:MM:SS.MMMMMM" - if str == base[:len(str)] { - return + if string(b) == base[:len(b)] { + return time.Time{}, nil + } + + year, err := parseByteYear(b) + if err != nil { + return time.Time{}, err + } + if b[4] != '-' { + return time.Time{}, fmt.Errorf("bad value for field: `%c`", b[4]) + } + + m, err := parseByte2Digits(b[5], b[6]) + if err != nil { + return time.Time{}, err + } + month := time.Month(m) + + if b[7] != '-' { + return time.Time{}, fmt.Errorf("bad value for field: `%c`", b[7]) + } + + day, err := parseByte2Digits(b[8], b[9]) + if err != nil { + return time.Time{}, err + } + if len(b) == 10 { + return time.Date(year, month, day, 0, 0, 0, 0, loc), nil + } + + if b[10] != ' ' { + return time.Time{}, fmt.Errorf("bad value for field: `%c`", b[10]) + } + + hour, err := parseByte2Digits(b[11], b[12]) + if err != nil { + return time.Time{}, err + } + if b[13] != ':' { + return time.Time{}, fmt.Errorf("bad value for field: `%c`", b[13]) + } + + min, err := parseByte2Digits(b[14], b[15]) + if err != nil { + return time.Time{}, err + } + if b[16] != ':' { + return time.Time{}, fmt.Errorf("bad value for field: `%c`", b[16]) + } + + sec, err := parseByte2Digits(b[17], b[18]) + if err != nil { + return time.Time{}, err } - t, err = time.Parse(timeFormat[:len(str)], str) + if len(b) == 19 { + return time.Date(year, month, day, hour, min, sec, 0, loc), nil + } + + if b[19] != '.' { + return time.Time{}, fmt.Errorf("bad value for field: `%c`", b[19]) + } + nsec, err := parseByteNanoSec(b[20:]) + if err != nil { + return time.Time{}, err + } + return time.Date(year, month, day, hour, min, sec, nsec, loc), nil default: - err = fmt.Errorf("invalid time string: %s", str) - return + return time.Time{}, fmt.Errorf("invalid time bytes: %s", b) } +} - // Adjust location - if err == nil && loc != time.UTC { - y, mo, d := t.Date() - h, mi, s := t.Clock() - t, err = time.Date(y, mo, d, h, mi, s, t.Nanosecond(), loc), nil +func parseByteYear(b []byte) (int, error) { + year, n := 0, 1000 + for i := 0; i < 4; i++ { + v, err := bToi(b[i]) + if err != nil { + return 0, err + } + year += v * n + n /= 10 } + return year, nil +} - return +func parseByte2Digits(b1, b2 byte) (int, error) { + d1, err := bToi(b1) + if err != nil { + return 0, err + } + d2, err := bToi(b2) + if err != nil { + return 0, err + } + return d1*10 + d2, nil +} + +func parseByteNanoSec(b []byte) (int, error) { + ns, digit := 0, 100000 // max is 6-digits + for i := 0; i < len(b); i++ { + v, err := bToi(b[i]) + if err != nil { + return 0, err + } + ns += v * digit + digit /= 10 + } + // nanoseconds has 10-digits. (needs to scale digits) + // 10 - 6 = 4, so we have to multiple 1000. + return ns * 1000, nil +} + +func bToi(b byte) (int, error) { + if b < '0' || b > '9' { + return 0, errors.New("not [0-9]") + } + return int(b - '0'), nil } func parseBinaryDateTime(num uint64, data []byte, loc *time.Location) (driver.Value, error) { @@ -167,6 +265,64 @@ func parseBinaryDateTime(num uint64, data []byte, loc *time.Location) (driver.Va return nil, fmt.Errorf("invalid DATETIME packet length %d", num) } +func appendDateTime(buf []byte, t time.Time) ([]byte, error) { + year, month, day := t.Date() + hour, min, sec := t.Clock() + nsec := t.Nanosecond() + + if year < 1 || year > 9999 { + return buf, errors.New("year is not in the range [1, 9999]: " + strconv.Itoa(year)) // use errors.New instead of fmt.Errorf to avoid year escape to heap + } + year100 := year / 100 + year1 := year % 100 + + var localBuf [len("2006-01-02T15:04:05.999999999")]byte // does not escape + localBuf[0], localBuf[1], localBuf[2], localBuf[3] = digits10[year100], digits01[year100], digits10[year1], digits01[year1] + localBuf[4] = '-' + localBuf[5], localBuf[6] = digits10[month], digits01[month] + localBuf[7] = '-' + localBuf[8], localBuf[9] = digits10[day], digits01[day] + + if hour == 0 && min == 0 && sec == 0 && nsec == 0 { + return append(buf, localBuf[:10]...), nil + } + + localBuf[10] = ' ' + localBuf[11], localBuf[12] = digits10[hour], digits01[hour] + localBuf[13] = ':' + localBuf[14], localBuf[15] = digits10[min], digits01[min] + localBuf[16] = ':' + localBuf[17], localBuf[18] = digits10[sec], digits01[sec] + + if nsec == 0 { + return append(buf, localBuf[:19]...), nil + } + nsec100000000 := nsec / 100000000 + nsec1000000 := (nsec / 1000000) % 100 + nsec10000 := (nsec / 10000) % 100 + nsec100 := (nsec / 100) % 100 + nsec1 := nsec % 100 + localBuf[19] = '.' + + // milli second + localBuf[20], localBuf[21], localBuf[22] = + digits01[nsec100000000], digits10[nsec1000000], digits01[nsec1000000] + // micro second + localBuf[23], localBuf[24], localBuf[25] = + digits10[nsec10000], digits01[nsec10000], digits10[nsec100] + // nano second + localBuf[26], localBuf[27], localBuf[28] = + digits01[nsec100], digits10[nsec1], digits01[nsec1] + + // trim trailing zeros + n := len(localBuf) + for n > 0 && localBuf[n-1] == '0' { + n-- + } + + return append(buf, localBuf[:n]...), nil +} + // zeroDateTime is used in formatBinaryDateTime to avoid an allocation // if the DATE or DATETIME has the zero value. // It must never be changed. @@ -375,7 +531,7 @@ func stringToInt(b []byte) int { return val } -// returns the string read as a bytes slice, wheter the value is NULL, +// returns the string read as a bytes slice, whether the value is NULL, // the number of bytes read and an error, in case the string is longer than // the input slice func readLengthEncodedString(b []byte) ([]byte, bool, int, error) { @@ -485,32 +641,32 @@ func escapeBytesBackslash(buf, v []byte) []byte { for _, c := range v { switch c { case '\x00': - buf[pos] = '\\' buf[pos+1] = '0' + buf[pos] = '\\' pos += 2 case '\n': - buf[pos] = '\\' buf[pos+1] = 'n' + buf[pos] = '\\' pos += 2 case '\r': - buf[pos] = '\\' buf[pos+1] = 'r' + buf[pos] = '\\' pos += 2 case '\x1a': - buf[pos] = '\\' buf[pos+1] = 'Z' + buf[pos] = '\\' pos += 2 case '\'': - buf[pos] = '\\' buf[pos+1] = '\'' + buf[pos] = '\\' pos += 2 case '"': - buf[pos] = '\\' buf[pos+1] = '"' + buf[pos] = '\\' pos += 2 case '\\': - buf[pos] = '\\' buf[pos+1] = '\\' + buf[pos] = '\\' pos += 2 default: buf[pos] = c @@ -530,32 +686,32 @@ func escapeStringBackslash(buf []byte, v string) []byte { c := v[i] switch c { case '\x00': - buf[pos] = '\\' buf[pos+1] = '0' + buf[pos] = '\\' pos += 2 case '\n': - buf[pos] = '\\' buf[pos+1] = 'n' + buf[pos] = '\\' pos += 2 case '\r': - buf[pos] = '\\' buf[pos+1] = 'r' + buf[pos] = '\\' pos += 2 case '\x1a': - buf[pos] = '\\' buf[pos+1] = 'Z' + buf[pos] = '\\' pos += 2 case '\'': - buf[pos] = '\\' buf[pos+1] = '\'' + buf[pos] = '\\' pos += 2 case '"': - buf[pos] = '\\' buf[pos+1] = '"' + buf[pos] = '\\' pos += 2 case '\\': - buf[pos] = '\\' buf[pos+1] = '\\' + buf[pos] = '\\' pos += 2 default: buf[pos] = c @@ -577,8 +733,8 @@ func escapeBytesQuotes(buf, v []byte) []byte { for _, c := range v { if c == '\'' { - buf[pos] = '\'' buf[pos+1] = '\'' + buf[pos] = '\'' pos += 2 } else { buf[pos] = c @@ -597,8 +753,8 @@ func escapeStringQuotes(buf []byte, v string) []byte { for i := 0; i < len(v); i++ { c := v[i] if c == '\'' { - buf[pos] = '\'' buf[pos+1] = '\'' + buf[pos] = '\'' pos += 2 } else { buf[pos] = c @@ -623,39 +779,16 @@ type noCopy struct{} // Lock is a no-op used by -copylocks checker from `go vet`. func (*noCopy) Lock() {} -// atomicBool is a wrapper around uint32 for usage as a boolean value with -// atomic access. -type atomicBool struct { - _noCopy noCopy - value uint32 -} - -// IsSet returns whether the current boolean value is true -func (ab *atomicBool) IsSet() bool { - return atomic.LoadUint32(&ab.value) > 0 -} - -// Set sets the value of the bool regardless of the previous value -func (ab *atomicBool) Set(value bool) { - if value { - atomic.StoreUint32(&ab.value, 1) - } else { - atomic.StoreUint32(&ab.value, 0) - } -} - -// TrySet sets the value of the bool and returns whether the value changed -func (ab *atomicBool) TrySet(value bool) bool { - if value { - return atomic.SwapUint32(&ab.value, 1) == 0 - } - return atomic.SwapUint32(&ab.value, 0) > 0 -} +// Unlock is a no-op used by -copylocks checker from `go vet`. +// noCopy should implement sync.Locker from Go 1.11 +// https://github.com/golang/go/commit/c2eba53e7f80df21d51285879d51ab81bcfbf6bc +// https://github.com/golang/go/issues/26165 +func (*noCopy) Unlock() {} // atomicError is a wrapper for atomically accessed error values type atomicError struct { - _noCopy noCopy - value atomic.Value + _ noCopy + value atomic.Value } // Set sets the error value regardless of the previous value. diff --git a/utils_test.go b/utils_test.go index 10a60c2d0..4e5fc3cb7 100644 --- a/utils_test.go +++ b/utils_test.go @@ -14,6 +14,7 @@ import ( "database/sql/driver" "encoding/binary" "testing" + "time" ) func TestLengthEncodedInteger(t *testing.T) { @@ -172,64 +173,6 @@ func TestEscapeQuotes(t *testing.T) { expect("foo\"bar", "foo\"bar") // not affected } -func TestAtomicBool(t *testing.T) { - var ab atomicBool - if ab.IsSet() { - t.Fatal("Expected value to be false") - } - - ab.Set(true) - if ab.value != 1 { - t.Fatal("Set(true) did not set value to 1") - } - if !ab.IsSet() { - t.Fatal("Expected value to be true") - } - - ab.Set(true) - if !ab.IsSet() { - t.Fatal("Expected value to be true") - } - - ab.Set(false) - if ab.value != 0 { - t.Fatal("Set(false) did not set value to 0") - } - if ab.IsSet() { - t.Fatal("Expected value to be false") - } - - ab.Set(false) - if ab.IsSet() { - t.Fatal("Expected value to be false") - } - if ab.TrySet(false) { - t.Fatal("Expected TrySet(false) to fail") - } - if !ab.TrySet(true) { - t.Fatal("Expected TrySet(true) to succeed") - } - if !ab.IsSet() { - t.Fatal("Expected value to be true") - } - - ab.Set(true) - if !ab.IsSet() { - t.Fatal("Expected value to be true") - } - if ab.TrySet(true) { - t.Fatal("Expected TrySet(true) to fail") - } - if !ab.TrySet(false) { - t.Fatal("Expected TrySet(false) to succeed") - } - if ab.IsSet() { - t.Fatal("Expected value to be false") - } - - ab._noCopy.Lock() // we've "tested" it ¯\_(ツ)_/¯ -} - func TestAtomicError(t *testing.T) { var ae atomicError if ae.Value() != nil { @@ -291,3 +234,244 @@ func TestIsolationLevelMapping(t *testing.T) { t.Fatalf("Expected error to be %q, got %q", expectedErr, err) } } + +func TestAppendDateTime(t *testing.T) { + tests := []struct { + t time.Time + str string + }{ + { + t: time.Date(1234, 5, 6, 0, 0, 0, 0, time.UTC), + str: "1234-05-06", + }, + { + t: time.Date(4567, 12, 31, 12, 0, 0, 0, time.UTC), + str: "4567-12-31 12:00:00", + }, + { + t: time.Date(2020, 5, 30, 12, 34, 0, 0, time.UTC), + str: "2020-05-30 12:34:00", + }, + { + t: time.Date(2020, 5, 30, 12, 34, 56, 0, time.UTC), + str: "2020-05-30 12:34:56", + }, + { + t: time.Date(2020, 5, 30, 22, 33, 44, 123000000, time.UTC), + str: "2020-05-30 22:33:44.123", + }, + { + t: time.Date(2020, 5, 30, 22, 33, 44, 123456000, time.UTC), + str: "2020-05-30 22:33:44.123456", + }, + { + t: time.Date(2020, 5, 30, 22, 33, 44, 123456789, time.UTC), + str: "2020-05-30 22:33:44.123456789", + }, + { + t: time.Date(9999, 12, 31, 23, 59, 59, 999999999, time.UTC), + str: "9999-12-31 23:59:59.999999999", + }, + { + t: time.Date(1, 1, 1, 0, 0, 0, 0, time.UTC), + str: "0001-01-01", + }, + } + for _, v := range tests { + buf := make([]byte, 0, 32) + buf, _ = appendDateTime(buf, v.t) + if str := string(buf); str != v.str { + t.Errorf("appendDateTime(%v), have: %s, want: %s", v.t, str, v.str) + } + } + + // year out of range + { + v := time.Date(0, 1, 1, 0, 0, 0, 0, time.UTC) + buf := make([]byte, 0, 32) + _, err := appendDateTime(buf, v) + if err == nil { + t.Error("want an error") + return + } + } + { + v := time.Date(10000, 1, 1, 0, 0, 0, 0, time.UTC) + buf := make([]byte, 0, 32) + _, err := appendDateTime(buf, v) + if err == nil { + t.Error("want an error") + return + } + } +} + +func TestParseDateTime(t *testing.T) { + cases := []struct { + name string + str string + }{ + { + name: "parse date", + str: "2020-05-13", + }, + { + name: "parse null date", + str: sDate0, + }, + { + name: "parse datetime", + str: "2020-05-13 21:30:45", + }, + { + name: "parse null datetime", + str: sDateTime0, + }, + { + name: "parse datetime nanosec 1-digit", + str: "2020-05-25 23:22:01.1", + }, + { + name: "parse datetime nanosec 2-digits", + str: "2020-05-25 23:22:01.15", + }, + { + name: "parse datetime nanosec 3-digits", + str: "2020-05-25 23:22:01.159", + }, + { + name: "parse datetime nanosec 4-digits", + str: "2020-05-25 23:22:01.1594", + }, + { + name: "parse datetime nanosec 5-digits", + str: "2020-05-25 23:22:01.15949", + }, + { + name: "parse datetime nanosec 6-digits", + str: "2020-05-25 23:22:01.159491", + }, + } + + for _, loc := range []*time.Location{ + time.UTC, + time.FixedZone("test", 8*60*60), + } { + for _, cc := range cases { + t.Run(cc.name+"-"+loc.String(), func(t *testing.T) { + var want time.Time + if cc.str != sDate0 && cc.str != sDateTime0 { + var err error + want, err = time.ParseInLocation(timeFormat[:len(cc.str)], cc.str, loc) + if err != nil { + t.Fatal(err) + } + } + got, err := parseDateTime([]byte(cc.str), loc) + if err != nil { + t.Fatal(err) + } + + if !want.Equal(got) { + t.Fatalf("want: %v, but got %v", want, got) + } + }) + } + } +} + +func TestInvalidDateTime(t *testing.T) { + cases := []struct { + name string + str string + want time.Time + }{ + { + name: "parse datetime without day", + str: "0000-00-00 21:30:45", + want: time.Date(0, 0, 0, 21, 30, 45, 0, time.UTC), + }, + } + + for _, cc := range cases { + t.Run(cc.name, func(t *testing.T) { + got, err := parseDateTime([]byte(cc.str), time.UTC) + if err != nil { + t.Fatal(err) + } + + if !cc.want.Equal(got) { + t.Fatalf("want: %v, but got %v", cc.want, got) + } + }) + } +} + +func TestParseDateTimeFail(t *testing.T) { + cases := []struct { + name string + str string + wantErr string + }{ + { + name: "parse invalid time", + str: "hello", + wantErr: "invalid time bytes: hello", + }, + { + name: "parse year", + str: "000!-00-00 00:00:00.000000", + wantErr: "not [0-9]", + }, + { + name: "parse month", + str: "0000-!0-00 00:00:00.000000", + wantErr: "not [0-9]", + }, + { + name: `parse "-" after parsed year`, + str: "0000:00-00 00:00:00.000000", + wantErr: "bad value for field: `:`", + }, + { + name: `parse "-" after parsed month`, + str: "0000-00:00 00:00:00.000000", + wantErr: "bad value for field: `:`", + }, + { + name: `parse " " after parsed date`, + str: "0000-00-00+00:00:00.000000", + wantErr: "bad value for field: `+`", + }, + { + name: `parse ":" after parsed date`, + str: "0000-00-00 00-00:00.000000", + wantErr: "bad value for field: `-`", + }, + { + name: `parse ":" after parsed hour`, + str: "0000-00-00 00:00-00.000000", + wantErr: "bad value for field: `-`", + }, + { + name: `parse "." after parsed sec`, + str: "0000-00-00 00:00:00?000000", + wantErr: "bad value for field: `?`", + }, + } + + for _, cc := range cases { + t.Run(cc.name, func(t *testing.T) { + got, err := parseDateTime([]byte(cc.str), time.UTC) + if err == nil { + t.Fatal("want error") + } + if cc.wantErr != err.Error() { + t.Fatalf("want `%s`, but got `%s`", cc.wantErr, err) + } + if !got.IsZero() { + t.Fatal("want zero time") + } + }) + } +}