diff --git a/.buildkite/engineer b/.buildkite/engineer index 880db1967231..71c1211a83b4 100755 --- a/.buildkite/engineer +++ b/.buildkite/engineer @@ -54,7 +54,7 @@ fi # Check if the system has engineer installed, if not, use a local copy. if ! type "engineer" &> /dev/null; then # Setup Prisma engine build & test tool (engineer). - curl --fail -sSL "https://prisma-engineer.s3-eu-west-1.amazonaws.com/1.62/latest/$OS/engineer.gz" --output engineer.gz + curl --fail -sSL "https://prisma-engineer.s3-eu-west-1.amazonaws.com/1.63/latest/$OS/engineer.gz" --output engineer.gz gzip -d engineer.gz chmod +x engineer diff --git a/.envrc b/.envrc index 48b1254c1700..5488da9e10e7 100644 --- a/.envrc +++ b/.envrc @@ -23,7 +23,7 @@ export QE_LOG_LEVEL=debug # Set it to "trace" to enable query-graph debugging lo # export FMT_SQL=1 # Uncomment it to enable logging formatted SQL queries ### Uncomment to run driver adapters tests. See query-engine-driver-adapters.yml workflow for how tests run in CI. -# export EXTERNAL_TEST_EXECUTOR="$(pwd)/query-engine/driver-adapters/js/connector-test-kit-executor/script/start_node.sh" +# export EXTERNAL_TEST_EXECUTOR="napi" # export DRIVER_ADAPTER=pg # Set to pg, neon or planetscale # export PRISMA_DISABLE_QUAINT_EXECUTORS=1 # Disable quaint executors for driver adapters # export DRIVER_ADAPTER_URL_OVERRIDE ="postgres://USER:PASSWORD@DATABASExxxx" # Override the database url for the driver adapter tests diff --git a/.github/workflows/qe-wasm-check.yml b/.github/workflows/qe-wasm-check.yml new file mode 100644 index 000000000000..f67d2d247b27 --- /dev/null +++ b/.github/workflows/qe-wasm-check.yml @@ -0,0 +1,27 @@ +name: WASM engine compile check +on: + push: + branches: + - main + pull_request: + paths-ignore: + - '.github/**' + - '!.github/workflows/qe-wasm-check.yml' + - '.buildkite/**' + - '*.md' + - 'LICENSE' + - 'CODEOWNERS' + - 'renovate.json' + +jobs: + build: + name: 'Compilation check for query-engine-wasm' + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + - uses: dtolnay/rust-toolchain@stable + - name: Install wasm-pack + run: cargo install wasm-pack + - name: Build wasm query engine + run: ./build.sh + working-directory: ./query-engine/query-engine-wasm diff --git a/.github/workflows/query-engine-driver-adapters.yml b/.github/workflows/query-engine-driver-adapters.yml index 3de0238aa0e7..f6eec5ffc102 100644 --- a/.github/workflows/query-engine-driver-adapters.yml +++ b/.github/workflows/query-engine-driver-adapters.yml @@ -25,12 +25,22 @@ jobs: fail-fast: false matrix: adapter: - - name: 'pg' + - name: 'pg (napi)' setup_task: 'dev-pg-postgres13' - - name: 'neon:ws' + - name: 'neon:ws (napi)' setup_task: 'dev-neon-ws-postgres13' - - name: 'libsql' + - name: 'libsql (napi)' setup_task: 'dev-libsql-sqlite' + # TODO: uncomment when WASM engine is functional + # - name: 'pg (wasm)' + # setup_task: 'dev-pg-postgres13-wasm' + # needs_wasm_pack: true + # - name: 'neon:ws (wasm)' + # setup_task: 'dev-neon-ws-postgres13-wasm' + # needs_wasm_pack: true + # - name: 'libsql (wasm)' + # setup_task: 'dev-libsql-sqlite-wasm' + # needs_wasm_pack: true node_version: ['18'] env: LOG_LEVEL: 'info' # Set to "debug" to trace the query engine and node process running the driver adapter @@ -85,9 +95,13 @@ jobs: echo "DRIVER_ADAPTERS_BRANCH=$branch" >> "$GITHUB_ENV" fi - - run: make ${{ matrix.adapter.setup_task }} - - uses: dtolnay/rust-toolchain@stable + - name: 'Install wasm-pack' + if: ${{ matrix.adapter.needs_wasm_pack }} + run: cargo install wasm-pack + + - run: make ${{ matrix.adapter.setup_task }} + - name: 'Run tests' run: cargo test --package query-engine-tests -- --test-threads=1 diff --git a/.github/workflows/query-engine.yml b/.github/workflows/query-engine.yml index 762c3da4a50a..32dc854a2617 100644 --- a/.github/workflows/query-engine.yml +++ b/.github/workflows/query-engine.yml @@ -25,10 +25,6 @@ jobs: fail-fast: false matrix: database: - - name: 'vitess_5_7' - single_threaded: true - connector: 'vitess' - version: '5.7' - name: 'vitess_8_0' single_threaded: true connector: 'vitess' @@ -41,6 +37,10 @@ jobs: single_threaded: false connector: 'sqlserver' version: '2022' + - name: 'sqlite' + single_threaded: false + connector: 'sqlite' + version: '3' - name: 'mongodb_4_2' single_threaded: true connector: 'mongodb' diff --git a/.github/workflows/schema-engine.yml b/.github/workflows/schema-engine.yml index 03d23317bbd0..36f55368bf58 100644 --- a/.github/workflows/schema-engine.yml +++ b/.github/workflows/schema-engine.yml @@ -94,11 +94,6 @@ jobs: url: 'postgresql://prisma@localhost:26257' - name: sqlite url: sqlite - - name: vitess_5_7 - url: 'mysql://root:prisma@localhost:33577/test' - shadow_database_url: 'mysql://root:prisma@localhost:33578/shadow' - is_vitess: true - single_threaded: true - name: vitess_8_0 url: 'mysql://root:prisma@localhost:33807/test' shadow_database_url: 'mysql://root:prisma@localhost:33808/shadow' diff --git a/.test_database_urls/vitess_5_7 b/.test_database_urls/vitess_5_7 deleted file mode 100644 index 2259628658ac..000000000000 --- a/.test_database_urls/vitess_5_7 +++ /dev/null @@ -1,2 +0,0 @@ -export TEST_DATABASE_URL="mysql://root:prisma@localhost:33577/test" -export TEST_SHADOW_DATABASE_URL="mysql://root:prisma@localhost:33578/shadow" \ No newline at end of file diff --git a/Cargo.lock b/Cargo.lock index 573e31eababd..cd2fe6c6b5c0 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3571,6 +3571,7 @@ dependencies = [ "connection-string", "either", "futures", + "getrandom 0.2.10", "hex", "indoc 0.3.6", "lru-cache", @@ -3679,6 +3680,7 @@ dependencies = [ "once_cell", "opentelemetry", "petgraph 0.4.13", + "pin-project", "prisma-models", "psl", "query-connector", @@ -3694,6 +3696,7 @@ dependencies = [ "tracing-subscriber", "user-facing-errors", "uuid", + "wasm-bindgen-futures", ] [[package]] @@ -3821,9 +3824,14 @@ dependencies = [ "log", "prisma-models", "psl", + "quaint", + "query-connector", + "query-core", + "request-handlers", "serde", "serde-wasm-bindgen", "serde_json", + "sql-query-connector", "thiserror", "tokio", "tracing", @@ -6010,9 +6018,9 @@ checksum = "9c8d87e72b64a3b4db28d11ce29237c246188f4f51057d65a7eab63b7987e423" [[package]] name = "wasm-bindgen" -version = "0.2.87" +version = "0.2.88" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7706a72ab36d8cb1f80ffbf0e071533974a60d0a308d01a5d0375bf60499a342" +checksum = "7daec296f25a1bae309c0cd5c29c4b260e510e6d813c286b19eaadf409d40fce" dependencies = [ "cfg-if", "wasm-bindgen-macro", @@ -6020,9 +6028,9 @@ dependencies = [ [[package]] name = "wasm-bindgen-backend" -version = "0.2.87" +version = "0.2.88" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5ef2b6d3c510e9625e5fe6f509ab07d66a760f0885d858736483c32ed7809abd" +checksum = "e397f4664c0e4e428e8313a469aaa58310d302159845980fd23b0f22a847f217" dependencies = [ "bumpalo", "log", @@ -6047,9 +6055,9 @@ dependencies = [ [[package]] name = "wasm-bindgen-macro" -version = "0.2.87" +version = "0.2.88" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "dee495e55982a3bd48105a7b947fd2a9b4a8ae3010041b9e0faab3f9cd028f1d" +checksum = "5961017b3b08ad5f3fe39f1e79877f8ee7c23c5e5fd5eb80de95abc41f1f16b2" dependencies = [ "quote", "wasm-bindgen-macro-support", @@ -6057,9 +6065,9 @@ dependencies = [ [[package]] name = "wasm-bindgen-macro-support" -version = "0.2.87" +version = "0.2.88" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "54681b18a46765f095758388f2d0cf16eb8d4169b639ab575a8f5693af210c7b" +checksum = "c5353b8dab669f5e10f5bd76df26a9360c748f054f862ff5f3f8aae0c7fb3907" dependencies = [ "proc-macro2", "quote", @@ -6070,9 +6078,9 @@ dependencies = [ [[package]] name = "wasm-bindgen-shared" -version = "0.2.87" +version = "0.2.88" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ca6ad05a4870b2bf5fe995117d3728437bd27d7cd5f06f13c17443ef369775a1" +checksum = "0d046c5d029ba91a1ed14da14dca44b68bf2f124cfbaf741c54151fdb3e0750b" [[package]] name = "wasm-logger" diff --git a/Cargo.toml b/Cargo.toml index 4a3cd1450caf..b32a1a85cf18 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -68,6 +68,7 @@ features = [ "pooled", "postgresql", "sqlite", + "native", ] [profile.dev.package.backtrace] diff --git a/Makefile b/Makefile index e00c122e2713..94b3d539f971 100644 --- a/Makefile +++ b/Makefile @@ -49,8 +49,11 @@ ifndef DRIVER_ADAPTER cargo test --package query-engine-tests else @echo "Executing query engine tests with $(DRIVER_ADAPTER) driver adapter"; \ - # Add your actual command for the "test-driver-adapter" task here - $(MAKE) test-driver-adapter-$(DRIVER_ADAPTER); + if [ "$(ENGINE)" = "wasm" ]; then \ + $(MAKE) test-driver-adapter-$(DRIVER_ADAPTER)-wasm; \ + else \ + $(MAKE) test-driver-adapter-$(DRIVER_ADAPTER); \ + fi endif test-qe-verbose: @@ -91,6 +94,12 @@ test-libsql-sqlite: dev-libsql-sqlite test-qe-st test-driver-adapter-libsql: test-libsql-sqlite +dev-libsql-sqlite-wasm: build-qe-wasm build-connector-kit-js + cp $(CONFIG_PATH)/libsql-sqlite-wasm $(CONFIG_FILE) + +test-libsql-sqlite-wasm: dev-libsql-sqlite-wasm test-qe-st +test-driver-adapter-libsql-sqlite-wasm: test-libsql-sqlite-wasm + start-postgres9: docker compose -f docker-compose.yml up --wait -d --remove-orphans postgres9 @@ -121,14 +130,20 @@ start-postgres13: dev-postgres13: start-postgres13 cp $(CONFIG_PATH)/postgres13 $(CONFIG_FILE) -start-pg-postgres13: build-qe-napi build-connector-kit-js start-postgres13 +start-pg-postgres13: start-postgres13 -dev-pg-postgres13: start-pg-postgres13 +dev-pg-postgres13: start-pg-postgres13 build-qe-napi build-connector-kit-js cp $(CONFIG_PATH)/pg-postgres13 $(CONFIG_FILE) test-pg-postgres13: dev-pg-postgres13 test-qe-st +dev-pg-postgres13-wasm: start-pg-postgres13 build-qe-wasm build-connector-kit-js + cp $(CONFIG_PATH)/pg-postgres13-wasm $(CONFIG_FILE) + +test-pg-postgres13-wasm: dev-pg-postgres13-wasm test-qe-st + test-driver-adapter-pg: test-pg-postgres13 +test-driver-adapter-pg-wasm: test-pg-postgres13-wasm start-neon-postgres13: docker compose -f docker-compose.yml up --wait -d --remove-orphans neon-postgres13 @@ -138,7 +153,13 @@ dev-neon-ws-postgres13: start-neon-postgres13 build-qe-napi build-connector-kit- test-neon-ws-postgres13: dev-neon-ws-postgres13 test-qe-st +dev-neon-ws-postgres13-wasm: start-neon-postgres13 build-qe-wasm build-connector-kit-js + cp $(CONFIG_PATH)/neon-ws-postgres13-wasm $(CONFIG_FILE) + +test-neon-ws-postgres13-wasm: dev-neon-ws-postgres13-wasm test-qe-st + test-driver-adapter-neon: test-neon-ws-postgres13 +test-driver-adapter-neon-wasm: test-neon-ws-postgres13-wasm start-postgres14: docker compose -f docker-compose.yml up --wait -d --remove-orphans postgres14 @@ -256,12 +277,6 @@ dev-mongodb_5: start-mongodb_5 dev-mongodb_4_2: start-mongodb_4_2 cp $(CONFIG_PATH)/mongodb42 $(CONFIG_FILE) -start-vitess_5_7: - docker compose -f docker-compose.yml up --wait -d --remove-orphans vitess-test-5_7 vitess-shadow-5_7 - -dev-vitess_5_7: start-vitess_5_7 - cp $(CONFIG_PATH)/vitess_5_7 $(CONFIG_FILE) - start-vitess_8_0: docker compose -f docker-compose.yml up --wait -d --remove-orphans vitess-test-8_0 vitess-shadow-8_0 @@ -276,7 +291,13 @@ dev-planetscale-vitess8: start-planetscale-vitess8 build-qe-napi build-connector test-planetscale-vitess8: dev-planetscale-vitess8 test-qe-st +dev-planetscale-vitess8-wasm: start-planetscale-vitess8 build-qe-wasm build-connector-kit-js + cp $(CONFIG_PATH)/planetscale-vitess8-wasm $(CONFIG_FILE) + +test-planetscale-vitess8-wasm: dev-planetscale-vitess8-wasm test-qe-st + test-driver-adapter-planetscale: test-planetscale-vitess8 +test-driver-adapter-planetscale-wasm: test-planetscale-vitess8-wasm ###################### # Local dev commands # @@ -285,6 +306,9 @@ test-driver-adapter-planetscale: test-planetscale-vitess8 build-qe-napi: cargo build --package query-engine-node-api +build-qe-wasm: + cd query-engine/query-engine-wasm && ./build.sh + build-connector-kit-js: build-driver-adapters cd query-engine/driver-adapters && pnpm i && pnpm build diff --git a/docker-compose.yml b/docker-compose.yml index a8b48748abc4..b8fe3e1e0fa0 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -222,26 +222,6 @@ services: - databases tmpfs: /var/lib/mariadb - vitess-test-5_7: - image: vitess/vttestserver:mysql57@sha256:23863a518b34330109c502ac61a396008f5f023e96263bcb2bb1b0f7f7d5dc7f - restart: unless-stopped - ports: - - 33577:33577 - environment: - PORT: 33574 - KEYSPACES: 'test' - NUM_SHARDS: '1' - MYSQL_BIND_HOST: '0.0.0.0' - FOREIGN_KEY_MODE: 'disallow' - ENABLE_ONLINE_DDL: false - MYSQL_MAX_CONNECTIONS: 100000 - TABLET_REFRESH_INTERVAL: '500ms' - healthcheck: - test: ['CMD', 'mysqladmin', 'ping', '-h127.0.0.1', '-P33577'] - interval: 5s - timeout: 2s - retries: 20 - vitess-test-8_0: image: vitess/vttestserver:mysql80@sha256:8bec2644d83cb322eb2cdd596d33c0f858243ba6ade9164c95dfcc519643094e restart: unless-stopped @@ -262,26 +242,6 @@ services: timeout: 2s retries: 20 - vitess-shadow-5_7: - image: vitess/vttestserver:mysql57@sha256:23863a518b34330109c502ac61a396008f5f023e96263bcb2bb1b0f7f7d5dc7f - restart: unless-stopped - ports: - - 33578:33577 - environment: - PORT: 33574 - KEYSPACES: 'shadow' - NUM_SHARDS: '1' - MYSQL_BIND_HOST: '0.0.0.0' - FOREIGN_KEY_MODE: 'disallow' - ENABLE_ONLINE_DDL: false - MYSQL_MAX_CONNECTIONS: 100000 - TABLET_REFRESH_INTERVAL: '500ms' - healthcheck: - test: ['CMD', 'mysqladmin', 'ping', '-h127.0.0.1', '-P33577'] - interval: 5s - timeout: 2s - retries: 20 - vitess-shadow-8_0: image: vitess/vttestserver:mysql80@sha256:8bec2644d83cb322eb2cdd596d33c0f858243ba6ade9164c95dfcc519643094e restart: unless-stopped diff --git a/flake.lock b/flake.lock index c2750d0435ed..b887051dac9b 100644 --- a/flake.lock +++ b/flake.lock @@ -2,23 +2,16 @@ "nodes": { "crane": { "inputs": { - "flake-compat": "flake-compat", - "flake-utils": [ - "flake-utils" - ], "nixpkgs": [ "nixpkgs" - ], - "rust-overlay": [ - "rust-overlay" ] }, "locked": { - "lastModified": 1696384830, - "narHash": "sha256-j8ZsVqzmj5sOm5MW9cqwQJUZELFFwOislDmqDDEMl6k=", + "lastModified": 1699548976, + "narHash": "sha256-xnpxms0koM8mQpxIup9JnT0F7GrKdvv0QvtxvRuOYR4=", "owner": "ipetkov", "repo": "crane", - "rev": "f2143cd27f8bd09ee4f0121336c65015a2a0a19c", + "rev": "6849911446e18e520970cc6b7a691e64ee90d649", "type": "github" }, "original": { @@ -27,22 +20,6 @@ "type": "github" } }, - "flake-compat": { - "flake": false, - "locked": { - "lastModified": 1696267196, - "narHash": "sha256-AAQ/2sD+0D18bb8hKuEEVpHUYD1GmO2Uh/taFamn6XQ=", - "owner": "edolstra", - "repo": "flake-compat", - "rev": "4f910c9827911b1ec2bf26b5a062cd09f8d89f85", - "type": "github" - }, - "original": { - "owner": "edolstra", - "repo": "flake-compat", - "type": "github" - } - }, "flake-parts": { "inputs": { "nixpkgs-lib": [ @@ -50,11 +27,11 @@ ] }, "locked": { - "lastModified": 1696343447, - "narHash": "sha256-B2xAZKLkkeRFG5XcHHSXXcP7To9Xzr59KXeZiRf4vdQ=", + "lastModified": 1698882062, + "narHash": "sha256-HkhafUayIqxXyHH1X8d9RDl1M2CkFgZLjKD3MzabiEo=", "owner": "hercules-ci", "repo": "flake-parts", - "rev": "c9afaba3dfa4085dbd2ccb38dfade5141e33d9d4", + "rev": "8c9fa2545007b49a5db5f650ae91f227672c3877", "type": "github" }, "original": { @@ -105,11 +82,11 @@ }, "nixpkgs": { "locked": { - "lastModified": 1696193975, - "narHash": "sha256-mnQjUcYgp9Guu3RNVAB2Srr1TqKcPpRXmJf4LJk6KRY=", + "lastModified": 1699963925, + "narHash": "sha256-LE7OV/SwkIBsCpAlIPiFhch/J+jBDGEZjNfdnzCnCrY=", "owner": "NixOS", "repo": "nixpkgs", - "rev": "fdd898f8f79e8d2f99ed2ab6b3751811ef683242", + "rev": "bf744fe90419885eefced41b3e5ae442d732712d", "type": "github" }, "original": { @@ -139,11 +116,11 @@ ] }, "locked": { - "lastModified": 1696558324, - "narHash": "sha256-TnnP4LGwDB8ZGE7h2n4nA9Faee8xPkMdNcyrzJ57cbw=", + "lastModified": 1700187354, + "narHash": "sha256-RRIVKv+tiI1yn1PqZiVGQ9YlQGZ+/9iEkA4rst1QiNk=", "owner": "oxalica", "repo": "rust-overlay", - "rev": "fdb37574a04df04aaa8cf7708f94a9309caebe2b", + "rev": "e3ebc177291f5de627d6dfbac817b4a661b15d1c", "type": "github" }, "original": { diff --git a/flake.nix b/flake.nix index 67f4042d8c68..e62a09803d3d 100644 --- a/flake.nix +++ b/flake.nix @@ -3,8 +3,6 @@ crane = { url = "github:ipetkov/crane"; inputs.nixpkgs.follows = "nixpkgs"; - inputs.rust-overlay.follows = "rust-overlay"; - inputs.flake-utils.follows = "flake-utils"; }; flake-utils = { url = "github:numtide/flake-utils"; diff --git a/libs/user-facing-errors/Cargo.toml b/libs/user-facing-errors/Cargo.toml index 9900892209c6..3049a19712b1 100644 --- a/libs/user-facing-errors/Cargo.toml +++ b/libs/user-facing-errors/Cargo.toml @@ -11,7 +11,7 @@ backtrace = "0.3.40" tracing = "0.1" indoc.workspace = true itertools = "0.10" -quaint = { workspace = true, optional = true } +quaint = { path = "../../quaint", optional = true } [features] default = [] diff --git a/prisma-schema-wasm/Cargo.toml b/prisma-schema-wasm/Cargo.toml index 248c726c9ba4..51638e55b1c1 100644 --- a/prisma-schema-wasm/Cargo.toml +++ b/prisma-schema-wasm/Cargo.toml @@ -7,6 +7,6 @@ edition = "2021" crate-type = ["cdylib"] [dependencies] -wasm-bindgen = "=0.2.87" +wasm-bindgen = "=0.2.88" wasm-logger = { version = "0.2.0", optional = true } prisma-fmt = { path = "../prisma-fmt" } diff --git a/quaint/Cargo.toml b/quaint/Cargo.toml index b699518d0910..52a7edf72aca 100644 --- a/quaint/Cargo.toml +++ b/quaint/Cargo.toml @@ -23,20 +23,28 @@ resolver = "2" features = ["docs", "all"] [features] -default = [] +default = ["mysql", "postgresql", "mssql", "sqlite"] docs = [] # Expose the underlying database drivers when a connector is enabled. This is a # way to access database-specific methods when you need extra control. expose-drivers = [] -all = ["mssql", "mysql", "pooled", "postgresql", "sqlite"] +native = [ + "postgresql-native", + "mysql-native", + "mssql-native", + "sqlite-native", +] + +all = ["native", "pooled"] vendored-openssl = [ "postgres-native-tls/vendored-openssl", "mysql_async/vendored-openssl", ] -postgresql = [ +postgresql-native = [ + "postgresql", "native-tls", "tokio-postgres", "postgres-types", @@ -47,11 +55,24 @@ postgresql = [ "lru-cache", "byteorder", ] +postgresql = [] + +mssql-native = [ + "mssql", + "tiberius", + "tokio-util", + "tokio/time", + "tokio/net", +] +mssql = [] + +mysql-native = ["mysql", "mysql_async", "tokio/time", "lru-cache"] +mysql = ["chrono/std"] -mssql = ["tiberius", "tokio-util", "tokio/time", "tokio/net", "either"] -mysql = ["mysql_async", "tokio/time", "lru-cache"] pooled = ["mobc"] -sqlite = ["rusqlite", "tokio/sync"] +sqlite-native = ["sqlite", "rusqlite/bundled", "tokio/sync"] +sqlite = [] + fmt-sql = ["sqlformat"] [dependencies] @@ -67,7 +88,7 @@ futures = "0.3" url = "2.1" hex = "0.4" -either = { version = "1.6", optional = true } +either = { version = "1.6" } base64 = { version = "0.12.3" } chrono = { version = "0.4", default-features = false, features = ["serde"] } lru-cache = { version = "0.1", optional = true } @@ -88,7 +109,11 @@ paste = "1.0" serde = { version = "1.0", features = ["derive"] } quaint-test-macros = { path = "quaint-test-macros" } quaint-test-setup = { path = "quaint-test-setup" } -tokio = { version = "1.0", features = ["rt-multi-thread", "macros", "time"] } +tokio = { version = "1.0", features = ["macros", "time"] } + +[target.'cfg(target_arch = "wasm32")'.dependencies.getrandom] +version = "0.2" +features = ["js"] [dependencies.byteorder] default-features = false @@ -102,7 +127,7 @@ branch = "vendored-openssl" [dependencies.rusqlite] version = "0.29" -features = ["chrono", "bundled", "column_decltype"] +features = ["chrono", "column_decltype"] optional = true [target.'cfg(not(any(target_os = "macos", target_os = "ios")))'.dependencies.tiberius] diff --git a/quaint/README.md b/quaint/README.md index 92033db269b1..03108d9090d3 100644 --- a/quaint/README.md +++ b/quaint/README.md @@ -16,9 +16,13 @@ Quaint is an abstraction over certain SQL databases. It provides: ### Feature flags - `mysql`: Support for MySQL databases. + - On non-WebAssembly targets, choose `mysql-native` instead. - `postgresql`: Support for PostgreSQL databases. + - On non-WebAssembly targets, choose `postgresql-native` instead. - `sqlite`: Support for SQLite databases. + - On non-WebAssembly targets, choose `sqlite-native` instead. - `mssql`: Support for Microsoft SQL Server databases. + - On non-WebAssembly targets, choose `mssql-native` instead. - `pooled`: A connection pool in `pooled::Quaint`. - `vendored-openssl`: Statically links against a vendored OpenSSL library on non-Windows or non-Apple platforms. diff --git a/quaint/src/connector.rs b/quaint/src/connector.rs index de8bc64d22bb..dddb3c953ad7 100644 --- a/quaint/src/connector.rs +++ b/quaint/src/connector.rs @@ -10,37 +10,49 @@ //! querying interface. mod connection_info; + pub mod metrics; mod queryable; mod result_set; -#[cfg(any(feature = "mssql", feature = "postgresql", feature = "mysql"))] +#[cfg(any(feature = "mssql-native", feature = "postgresql-native", feature = "mysql-native"))] mod timeout; mod transaction; mod type_identifier; -#[cfg(feature = "mssql")] -pub(crate) mod mssql; -#[cfg(feature = "mysql")] -pub(crate) mod mysql; -#[cfg(feature = "postgresql")] -pub(crate) mod postgres; -#[cfg(feature = "sqlite")] -pub(crate) mod sqlite; - -#[cfg(feature = "mysql")] -pub use self::mysql::*; -#[cfg(feature = "postgresql")] -pub use self::postgres::*; pub use self::result_set::*; pub use connection_info::*; -#[cfg(feature = "mssql")] -pub use mssql::*; pub use queryable::*; -#[cfg(feature = "sqlite")] -pub use sqlite::*; pub use transaction::*; -#[cfg(any(feature = "sqlite", feature = "mysql", feature = "postgresql"))] +#[cfg(any(feature = "mssql-native", feature = "postgresql-native", feature = "mysql-native"))] #[allow(unused_imports)] pub(crate) use type_identifier::*; pub use self::metrics::query; + +#[cfg(feature = "postgresql")] +pub(crate) mod postgres; +#[cfg(feature = "postgresql-native")] +pub use postgres::native::*; +#[cfg(feature = "postgresql")] +pub use postgres::*; + +#[cfg(feature = "mysql")] +pub(crate) mod mysql; +#[cfg(feature = "mysql-native")] +pub use mysql::native::*; +#[cfg(feature = "mysql")] +pub use mysql::*; + +#[cfg(feature = "sqlite")] +pub(crate) mod sqlite; +#[cfg(feature = "sqlite-native")] +pub use sqlite::native::*; +#[cfg(feature = "sqlite")] +pub use sqlite::*; + +#[cfg(feature = "mssql")] +pub(crate) mod mssql; +#[cfg(feature = "mssql-native")] +pub use mssql::native::*; +#[cfg(feature = "mssql")] +pub use mssql::*; diff --git a/quaint/src/connector/mssql.rs b/quaint/src/connector/mssql.rs index cef092edb9d7..5a493ba17b24 100644 --- a/quaint/src/connector/mssql.rs +++ b/quaint/src/connector/mssql.rs @@ -1,614 +1,8 @@ -mod conversion; -mod error; +//! Wasm-compatible definitions for the MSSQL connector. +//! This module is only available with the `mssql` feature. +pub(crate) mod url; -use super::{IsolationLevel, Transaction, TransactionOptions}; -use crate::{ - ast::{Query, Value}, - connector::{metrics, queryable::*, DefaultTransaction, ResultSet}, - error::{Error, ErrorKind}, - visitor::{self, Visitor}, -}; -use async_trait::async_trait; -use connection_string::JdbcString; -use futures::lock::Mutex; -use std::{ - convert::TryFrom, - fmt, - future::Future, - str::FromStr, - sync::atomic::{AtomicBool, Ordering}, - time::Duration, -}; -use tiberius::*; -use tokio::net::TcpStream; -use tokio_util::compat::{Compat, TokioAsyncWriteCompatExt}; +pub use self::url::*; -/// The underlying SQL Server driver. Only available with the `expose-drivers` Cargo feature. -#[cfg(feature = "expose-drivers")] -pub use tiberius; - -/// Wraps a connection url and exposes the parsing logic used by Quaint, -/// including default values. -#[derive(Debug, Clone)] -pub struct MssqlUrl { - connection_string: String, - query_params: MssqlQueryParams, -} - -/// TLS mode when connecting to SQL Server. -#[derive(Debug, Clone, Copy)] -pub enum EncryptMode { - /// All traffic is encrypted. - On, - /// Only the login credentials are encrypted. - Off, - /// Nothing is encrypted. - DangerPlainText, -} - -impl fmt::Display for EncryptMode { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - match self { - Self::On => write!(f, "true"), - Self::Off => write!(f, "false"), - Self::DangerPlainText => write!(f, "DANGER_PLAINTEXT"), - } - } -} - -impl FromStr for EncryptMode { - type Err = Error; - - fn from_str(s: &str) -> crate::Result { - let mode = match s.parse::() { - Ok(true) => Self::On, - _ if s == "DANGER_PLAINTEXT" => Self::DangerPlainText, - _ => Self::Off, - }; - - Ok(mode) - } -} - -#[derive(Debug, Clone)] -pub(crate) struct MssqlQueryParams { - encrypt: EncryptMode, - port: Option, - host: Option, - user: Option, - password: Option, - database: String, - schema: String, - trust_server_certificate: bool, - trust_server_certificate_ca: Option, - connection_limit: Option, - socket_timeout: Option, - connect_timeout: Option, - pool_timeout: Option, - transaction_isolation_level: Option, - max_connection_lifetime: Option, - max_idle_connection_lifetime: Option, -} - -static SQL_SERVER_DEFAULT_ISOLATION: IsolationLevel = IsolationLevel::ReadCommitted; - -#[async_trait] -impl TransactionCapable for Mssql { - async fn start_transaction<'a>( - &'a self, - isolation: Option, - ) -> crate::Result> { - // Isolation levels in SQL Server are set on the connection and live until they're changed. - // Always explicitly setting the isolation level each time a tx is started (either to the given value - // or by using the default/connection string value) prevents transactions started on connections from - // the pool to have unexpected isolation levels set. - let isolation = isolation - .or(self.url.query_params.transaction_isolation_level) - .or(Some(SQL_SERVER_DEFAULT_ISOLATION)); - - let opts = TransactionOptions::new(isolation, self.requires_isolation_first()); - - Ok(Box::new( - DefaultTransaction::new(self, self.begin_statement(), opts).await?, - )) - } -} - -impl MssqlUrl { - /// Maximum number of connections the pool can have (if used together with - /// pooled Quaint). - pub fn connection_limit(&self) -> Option { - self.query_params.connection_limit() - } - - /// A duration how long one query can take. - pub fn socket_timeout(&self) -> Option { - self.query_params.socket_timeout() - } - - /// A duration how long we can try to connect to the database. - pub fn connect_timeout(&self) -> Option { - self.query_params.connect_timeout() - } - - /// A pool check_out timeout. - pub fn pool_timeout(&self) -> Option { - self.query_params.pool_timeout() - } - - /// The isolation level of a transaction. - fn transaction_isolation_level(&self) -> Option { - self.query_params.transaction_isolation_level - } - - /// Name of the database. - pub fn dbname(&self) -> &str { - self.query_params.database() - } - - /// The prefix which to use when querying database. - pub fn schema(&self) -> &str { - self.query_params.schema() - } - - /// Database hostname. - pub fn host(&self) -> &str { - self.query_params.host() - } - - /// The username to use when connecting to the database. - pub fn username(&self) -> Option<&str> { - self.query_params.user() - } - - /// The password to use when connecting to the database. - pub fn password(&self) -> Option<&str> { - self.query_params.password() - } - - /// The TLS mode to use when connecting to the database. - pub fn encrypt(&self) -> EncryptMode { - self.query_params.encrypt() - } - - /// If true, we allow invalid certificates (self-signed, or otherwise - /// dangerous) when connecting. Should be true only for development and - /// testing. - pub fn trust_server_certificate(&self) -> bool { - self.query_params.trust_server_certificate() - } - - /// Path to a custom server certificate file. - pub fn trust_server_certificate_ca(&self) -> Option<&str> { - self.query_params.trust_server_certificate_ca() - } - - /// Database port. - pub fn port(&self) -> u16 { - self.query_params.port() - } - - /// The JDBC connection string - pub fn connection_string(&self) -> &str { - &self.connection_string - } - - /// The maximum connection lifetime - pub fn max_connection_lifetime(&self) -> Option { - self.query_params.max_connection_lifetime() - } - - /// The maximum idle connection lifetime - pub fn max_idle_connection_lifetime(&self) -> Option { - self.query_params.max_idle_connection_lifetime() - } -} - -impl MssqlQueryParams { - fn port(&self) -> u16 { - self.port.unwrap_or(1433) - } - - fn host(&self) -> &str { - self.host.as_deref().unwrap_or("localhost") - } - - fn user(&self) -> Option<&str> { - self.user.as_deref() - } - - fn password(&self) -> Option<&str> { - self.password.as_deref() - } - - fn encrypt(&self) -> EncryptMode { - self.encrypt - } - - fn trust_server_certificate(&self) -> bool { - self.trust_server_certificate - } - - fn trust_server_certificate_ca(&self) -> Option<&str> { - self.trust_server_certificate_ca.as_deref() - } - - fn database(&self) -> &str { - &self.database - } - - fn schema(&self) -> &str { - &self.schema - } - - fn socket_timeout(&self) -> Option { - self.socket_timeout - } - - fn connect_timeout(&self) -> Option { - self.connect_timeout - } - - fn connection_limit(&self) -> Option { - self.connection_limit - } - - fn pool_timeout(&self) -> Option { - self.pool_timeout - } - - fn max_connection_lifetime(&self) -> Option { - self.max_connection_lifetime - } - - fn max_idle_connection_lifetime(&self) -> Option { - self.max_idle_connection_lifetime - } -} - -/// A connector interface for the SQL Server database. -#[derive(Debug)] -pub struct Mssql { - client: Mutex>>, - url: MssqlUrl, - socket_timeout: Option, - is_healthy: AtomicBool, -} - -impl Mssql { - /// Creates a new connection to SQL Server. - pub async fn new(url: MssqlUrl) -> crate::Result { - let config = Config::from_jdbc_string(&url.connection_string)?; - let tcp = TcpStream::connect_named(&config).await?; - let socket_timeout = url.socket_timeout(); - - let connecting = async { - match Client::connect(config, tcp.compat_write()).await { - Ok(client) => Ok(client), - Err(tiberius::error::Error::Routing { host, port }) => { - let mut config = Config::from_jdbc_string(&url.connection_string)?; - config.host(host); - config.port(port); - - let tcp = TcpStream::connect_named(&config).await?; - Client::connect(config, tcp.compat_write()).await - } - Err(e) => Err(e), - } - }; - - let client = super::timeout::connect(url.connect_timeout(), connecting).await?; - - let this = Self { - client: Mutex::new(client), - url, - socket_timeout, - is_healthy: AtomicBool::new(true), - }; - - if let Some(isolation) = this.url.transaction_isolation_level() { - this.raw_cmd(&format!("SET TRANSACTION ISOLATION LEVEL {isolation}")) - .await?; - }; - - Ok(this) - } - - /// The underlying Tiberius client. Only available with the `expose-drivers` Cargo feature. - /// This is a lower level API when you need to get into database specific features. - #[cfg(feature = "expose-drivers")] - pub fn client(&self) -> &Mutex>> { - &self.client - } - - async fn perform_io(&self, fut: F) -> crate::Result - where - F: Future>, - { - match super::timeout::socket(self.socket_timeout, fut).await { - Err(e) if e.is_closed() => { - self.is_healthy.store(false, Ordering::SeqCst); - Err(e) - } - res => res, - } - } -} - -#[async_trait] -impl Queryable for Mssql { - async fn query(&self, q: Query<'_>) -> crate::Result { - let (sql, params) = visitor::Mssql::build(q)?; - self.query_raw(&sql, ¶ms[..]).await - } - - async fn query_raw(&self, sql: &str, params: &[Value<'_>]) -> crate::Result { - metrics::query("mssql.query_raw", sql, params, move || async move { - let mut client = self.client.lock().await; - - let mut query = tiberius::Query::new(sql); - - for param in params { - query.bind(param); - } - - let mut results = self.perform_io(query.query(&mut client)).await?.into_results().await?; - - match results.pop() { - Some(rows) => { - let mut columns_set = false; - let mut columns = Vec::new(); - let mut result_rows = Vec::with_capacity(rows.len()); - - for row in rows.into_iter() { - if !columns_set { - columns = row.columns().iter().map(|c| c.name().to_string()).collect(); - columns_set = true; - } - - let mut values: Vec> = Vec::with_capacity(row.len()); - - for val in row.into_iter() { - values.push(Value::try_from(val)?); - } - - result_rows.push(values); - } - - Ok(ResultSet::new(columns, result_rows)) - } - None => Ok(ResultSet::new(Vec::new(), Vec::new())), - } - }) - .await - } - - async fn query_raw_typed(&self, sql: &str, params: &[Value<'_>]) -> crate::Result { - self.query_raw(sql, params).await - } - - async fn execute(&self, q: Query<'_>) -> crate::Result { - let (sql, params) = visitor::Mssql::build(q)?; - self.execute_raw(&sql, ¶ms[..]).await - } - - async fn execute_raw(&self, sql: &str, params: &[Value<'_>]) -> crate::Result { - metrics::query("mssql.execute_raw", sql, params, move || async move { - let mut query = tiberius::Query::new(sql); - - for param in params { - query.bind(param); - } - - let mut client = self.client.lock().await; - let changes = self.perform_io(query.execute(&mut client)).await?.total(); - - Ok(changes) - }) - .await - } - - async fn execute_raw_typed(&self, sql: &str, params: &[Value<'_>]) -> crate::Result { - self.execute_raw(sql, params).await - } - - async fn raw_cmd(&self, cmd: &str) -> crate::Result<()> { - metrics::query("mssql.raw_cmd", cmd, &[], move || async move { - let mut client = self.client.lock().await; - self.perform_io(client.simple_query(cmd)).await?.into_results().await?; - Ok(()) - }) - .await - } - - async fn version(&self) -> crate::Result> { - let query = r#"SELECT @@VERSION AS version"#; - let rows = self.query_raw(query, &[]).await?; - - let version_string = rows - .get(0) - .and_then(|row| row.get("version").and_then(|version| version.to_string())); - - Ok(version_string) - } - - fn is_healthy(&self) -> bool { - self.is_healthy.load(Ordering::SeqCst) - } - - async fn set_tx_isolation_level(&self, isolation_level: IsolationLevel) -> crate::Result<()> { - self.raw_cmd(&format!("SET TRANSACTION ISOLATION LEVEL {isolation_level}")) - .await?; - - Ok(()) - } - - fn begin_statement(&self) -> &'static str { - "BEGIN TRAN" - } - - fn requires_isolation_first(&self) -> bool { - true - } -} - -impl MssqlUrl { - pub fn new(jdbc_connection_string: &str) -> crate::Result { - let query_params = Self::parse_query_params(jdbc_connection_string)?; - let connection_string = Self::with_jdbc_prefix(jdbc_connection_string); - - Ok(Self { - connection_string, - query_params, - }) - } - - fn with_jdbc_prefix(input: &str) -> String { - if input.starts_with("jdbc:sqlserver") { - input.into() - } else { - format!("jdbc:{input}") - } - } - - fn parse_query_params(input: &str) -> crate::Result { - let mut conn = JdbcString::from_str(&Self::with_jdbc_prefix(input))?; - - let host = conn.server_name().map(|server_name| match conn.instance_name() { - Some(instance_name) => format!(r#"{server_name}\{instance_name}"#), - None => server_name.to_string(), - }); - - let port = conn.port(); - let props = conn.properties_mut(); - let user = props.remove("user"); - let password = props.remove("password"); - let database = props.remove("database").unwrap_or_else(|| String::from("master")); - let schema = props.remove("schema").unwrap_or_else(|| String::from("dbo")); - - let connection_limit = props - .remove("connectionlimit") - .or_else(|| props.remove("connection_limit")) - .map(|param| param.parse()) - .transpose()?; - - let transaction_isolation_level = props - .remove("isolationlevel") - .or_else(|| props.remove("isolation_level")) - .map(|level| { - IsolationLevel::from_str(&level).map_err(|_| { - let kind = ErrorKind::database_url_is_invalid(format!("Invalid isolation level `{level}`")); - Error::builder(kind).build() - }) - }) - .transpose()?; - - let mut connect_timeout = props - .remove("logintimeout") - .or_else(|| props.remove("login_timeout")) - .or_else(|| props.remove("connecttimeout")) - .or_else(|| props.remove("connect_timeout")) - .or_else(|| props.remove("connectiontimeout")) - .or_else(|| props.remove("connection_timeout")) - .map(|param| param.parse().map(Duration::from_secs)) - .transpose()?; - - match connect_timeout { - None => connect_timeout = Some(Duration::from_secs(5)), - Some(dur) if dur.as_secs() == 0 => connect_timeout = None, - _ => (), - } - - let mut pool_timeout = props - .remove("pooltimeout") - .or_else(|| props.remove("pool_timeout")) - .map(|param| param.parse().map(Duration::from_secs)) - .transpose()?; - - match pool_timeout { - None => pool_timeout = Some(Duration::from_secs(10)), - Some(dur) if dur.as_secs() == 0 => pool_timeout = None, - _ => (), - } - - let socket_timeout = props - .remove("sockettimeout") - .or_else(|| props.remove("socket_timeout")) - .map(|param| param.parse().map(Duration::from_secs)) - .transpose()?; - - let encrypt = props - .remove("encrypt") - .map(|param| EncryptMode::from_str(¶m)) - .transpose()? - .unwrap_or(EncryptMode::On); - - let trust_server_certificate = props - .remove("trustservercertificate") - .or_else(|| props.remove("trust_server_certificate")) - .map(|param| param.parse()) - .transpose()? - .unwrap_or(false); - - let trust_server_certificate_ca: Option = props - .remove("trustservercertificateca") - .or_else(|| props.remove("trust_server_certificate_ca")); - - let mut max_connection_lifetime = props - .remove("max_connection_lifetime") - .map(|param| param.parse().map(Duration::from_secs)) - .transpose()?; - - match max_connection_lifetime { - Some(dur) if dur.as_secs() == 0 => max_connection_lifetime = None, - _ => (), - } - - let mut max_idle_connection_lifetime = props - .remove("max_idle_connection_lifetime") - .map(|param| param.parse().map(Duration::from_secs)) - .transpose()?; - - match max_idle_connection_lifetime { - None => max_idle_connection_lifetime = Some(Duration::from_secs(300)), - Some(dur) if dur.as_secs() == 0 => max_idle_connection_lifetime = None, - _ => (), - } - - Ok(MssqlQueryParams { - encrypt, - port, - host, - user, - password, - database, - schema, - trust_server_certificate, - trust_server_certificate_ca, - connection_limit, - socket_timeout, - connect_timeout, - pool_timeout, - transaction_isolation_level, - max_connection_lifetime, - max_idle_connection_lifetime, - }) - } -} - -#[cfg(test)] -mod tests { - use crate::tests::test_api::mssql::CONN_STR; - use crate::{error::*, single::Quaint}; - - #[tokio::test] - async fn should_map_wrong_credentials_error() { - let url = CONN_STR.replace("user=SA", "user=WRONG"); - - let res = Quaint::new(url.as_str()).await; - assert!(res.is_err()); - - let err = res.unwrap_err(); - assert!(matches!(err.kind(), ErrorKind::AuthenticationFailed { user } if user == &Name::available("WRONG"))); - } -} +#[cfg(feature = "mssql-native")] +pub(crate) mod native; diff --git a/quaint/src/connector/mssql/conversion.rs b/quaint/src/connector/mssql/native/conversion.rs similarity index 100% rename from quaint/src/connector/mssql/conversion.rs rename to quaint/src/connector/mssql/native/conversion.rs diff --git a/quaint/src/connector/mssql/error.rs b/quaint/src/connector/mssql/native/error.rs similarity index 100% rename from quaint/src/connector/mssql/error.rs rename to quaint/src/connector/mssql/native/error.rs diff --git a/quaint/src/connector/mssql/native/mod.rs b/quaint/src/connector/mssql/native/mod.rs new file mode 100644 index 000000000000..d22aa7a15dd6 --- /dev/null +++ b/quaint/src/connector/mssql/native/mod.rs @@ -0,0 +1,239 @@ +//! Definitions for the MSSQL connector. +//! This module is not compatible with wasm32-* targets. +//! This module is only available with the `mssql-native` feature. +mod conversion; +mod error; + +pub(crate) use crate::connector::mssql::MssqlUrl; +use crate::connector::{timeout, IsolationLevel, Transaction, TransactionOptions}; + +use crate::{ + ast::{Query, Value}, + connector::{metrics, queryable::*, DefaultTransaction, ResultSet}, + visitor::{self, Visitor}, +}; +use async_trait::async_trait; +use futures::lock::Mutex; +use std::{ + convert::TryFrom, + future::Future, + sync::atomic::{AtomicBool, Ordering}, + time::Duration, +}; +use tiberius::*; +use tokio::net::TcpStream; +use tokio_util::compat::{Compat, TokioAsyncWriteCompatExt}; + +/// The underlying SQL Server driver. Only available with the `expose-drivers` Cargo feature. +#[cfg(feature = "expose-drivers")] +pub use tiberius; + +static SQL_SERVER_DEFAULT_ISOLATION: IsolationLevel = IsolationLevel::ReadCommitted; + +#[async_trait] +impl TransactionCapable for Mssql { + async fn start_transaction<'a>( + &'a self, + isolation: Option, + ) -> crate::Result> { + // Isolation levels in SQL Server are set on the connection and live until they're changed. + // Always explicitly setting the isolation level each time a tx is started (either to the given value + // or by using the default/connection string value) prevents transactions started on connections from + // the pool to have unexpected isolation levels set. + let isolation = isolation + .or(self.url.query_params.transaction_isolation_level) + .or(Some(SQL_SERVER_DEFAULT_ISOLATION)); + + let opts = TransactionOptions::new(isolation, self.requires_isolation_first()); + + Ok(Box::new( + DefaultTransaction::new(self, self.begin_statement(), opts).await?, + )) + } +} + +/// A connector interface for the SQL Server database. +#[derive(Debug)] +pub struct Mssql { + client: Mutex>>, + url: MssqlUrl, + socket_timeout: Option, + is_healthy: AtomicBool, +} + +impl Mssql { + /// Creates a new connection to SQL Server. + pub async fn new(url: MssqlUrl) -> crate::Result { + let config = Config::from_jdbc_string(&url.connection_string)?; + let tcp = TcpStream::connect_named(&config).await?; + let socket_timeout = url.socket_timeout(); + + let connecting = async { + match Client::connect(config, tcp.compat_write()).await { + Ok(client) => Ok(client), + Err(tiberius::error::Error::Routing { host, port }) => { + let mut config = Config::from_jdbc_string(&url.connection_string)?; + config.host(host); + config.port(port); + + let tcp = TcpStream::connect_named(&config).await?; + Client::connect(config, tcp.compat_write()).await + } + Err(e) => Err(e), + } + }; + + let client = timeout::connect(url.connect_timeout(), connecting).await?; + + let this = Self { + client: Mutex::new(client), + url, + socket_timeout, + is_healthy: AtomicBool::new(true), + }; + + if let Some(isolation) = this.url.transaction_isolation_level() { + this.raw_cmd(&format!("SET TRANSACTION ISOLATION LEVEL {isolation}")) + .await?; + }; + + Ok(this) + } + + /// The underlying Tiberius client. Only available with the `expose-drivers` Cargo feature. + /// This is a lower level API when you need to get into database specific features. + #[cfg(feature = "expose-drivers")] + pub fn client(&self) -> &Mutex>> { + &self.client + } + + async fn perform_io(&self, fut: F) -> crate::Result + where + F: Future>, + { + match timeout::socket(self.socket_timeout, fut).await { + Err(e) if e.is_closed() => { + self.is_healthy.store(false, Ordering::SeqCst); + Err(e) + } + res => res, + } + } +} + +#[async_trait] +impl Queryable for Mssql { + async fn query(&self, q: Query<'_>) -> crate::Result { + let (sql, params) = visitor::Mssql::build(q)?; + self.query_raw(&sql, ¶ms[..]).await + } + + async fn query_raw(&self, sql: &str, params: &[Value<'_>]) -> crate::Result { + metrics::query("mssql.query_raw", sql, params, move || async move { + let mut client = self.client.lock().await; + + let mut query = tiberius::Query::new(sql); + + for param in params { + query.bind(param); + } + + let mut results = self.perform_io(query.query(&mut client)).await?.into_results().await?; + + match results.pop() { + Some(rows) => { + let mut columns_set = false; + let mut columns = Vec::new(); + let mut result_rows = Vec::with_capacity(rows.len()); + + for row in rows.into_iter() { + if !columns_set { + columns = row.columns().iter().map(|c| c.name().to_string()).collect(); + columns_set = true; + } + + let mut values: Vec> = Vec::with_capacity(row.len()); + + for val in row.into_iter() { + values.push(Value::try_from(val)?); + } + + result_rows.push(values); + } + + Ok(ResultSet::new(columns, result_rows)) + } + None => Ok(ResultSet::new(Vec::new(), Vec::new())), + } + }) + .await + } + + async fn query_raw_typed(&self, sql: &str, params: &[Value<'_>]) -> crate::Result { + self.query_raw(sql, params).await + } + + async fn execute(&self, q: Query<'_>) -> crate::Result { + let (sql, params) = visitor::Mssql::build(q)?; + self.execute_raw(&sql, ¶ms[..]).await + } + + async fn execute_raw(&self, sql: &str, params: &[Value<'_>]) -> crate::Result { + metrics::query("mssql.execute_raw", sql, params, move || async move { + let mut query = tiberius::Query::new(sql); + + for param in params { + query.bind(param); + } + + let mut client = self.client.lock().await; + let changes = self.perform_io(query.execute(&mut client)).await?.total(); + + Ok(changes) + }) + .await + } + + async fn execute_raw_typed(&self, sql: &str, params: &[Value<'_>]) -> crate::Result { + self.execute_raw(sql, params).await + } + + async fn raw_cmd(&self, cmd: &str) -> crate::Result<()> { + metrics::query("mssql.raw_cmd", cmd, &[], move || async move { + let mut client = self.client.lock().await; + self.perform_io(client.simple_query(cmd)).await?.into_results().await?; + Ok(()) + }) + .await + } + + async fn version(&self) -> crate::Result> { + let query = r#"SELECT @@VERSION AS version"#; + let rows = self.query_raw(query, &[]).await?; + + let version_string = rows + .get(0) + .and_then(|row| row.get("version").and_then(|version| version.to_string())); + + Ok(version_string) + } + + fn is_healthy(&self) -> bool { + self.is_healthy.load(Ordering::SeqCst) + } + + async fn set_tx_isolation_level(&self, isolation_level: IsolationLevel) -> crate::Result<()> { + self.raw_cmd(&format!("SET TRANSACTION ISOLATION LEVEL {isolation_level}")) + .await?; + + Ok(()) + } + + fn begin_statement(&self) -> &'static str { + "BEGIN TRAN" + } + + fn requires_isolation_first(&self) -> bool { + true + } +} diff --git a/quaint/src/connector/mssql/url.rs b/quaint/src/connector/mssql/url.rs new file mode 100644 index 000000000000..42cc0868f9bf --- /dev/null +++ b/quaint/src/connector/mssql/url.rs @@ -0,0 +1,384 @@ +#![cfg_attr(target_arch = "wasm32", allow(dead_code))] + +use crate::{ + connector::IsolationLevel, + error::{Error, ErrorKind}, +}; +use connection_string::JdbcString; +use std::{fmt, str::FromStr, time::Duration}; + +/// Wraps a connection url and exposes the parsing logic used by Quaint, +/// including default values. +#[derive(Debug, Clone)] +pub struct MssqlUrl { + pub(crate) connection_string: String, + pub(crate) query_params: MssqlQueryParams, +} + +/// TLS mode when connecting to SQL Server. +#[derive(Debug, Clone, Copy)] +pub enum EncryptMode { + /// All traffic is encrypted. + On, + /// Only the login credentials are encrypted. + Off, + /// Nothing is encrypted. + DangerPlainText, +} + +impl fmt::Display for EncryptMode { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + Self::On => write!(f, "true"), + Self::Off => write!(f, "false"), + Self::DangerPlainText => write!(f, "DANGER_PLAINTEXT"), + } + } +} + +impl FromStr for EncryptMode { + type Err = Error; + + fn from_str(s: &str) -> crate::Result { + let mode = match s.parse::() { + Ok(true) => Self::On, + _ if s == "DANGER_PLAINTEXT" => Self::DangerPlainText, + _ => Self::Off, + }; + + Ok(mode) + } +} + +#[derive(Debug, Clone)] +pub(crate) struct MssqlQueryParams { + pub(crate) encrypt: EncryptMode, + pub(crate) port: Option, + pub(crate) host: Option, + pub(crate) user: Option, + pub(crate) password: Option, + pub(crate) database: String, + pub(crate) schema: String, + pub(crate) trust_server_certificate: bool, + pub(crate) trust_server_certificate_ca: Option, + pub(crate) connection_limit: Option, + pub(crate) socket_timeout: Option, + pub(crate) connect_timeout: Option, + pub(crate) pool_timeout: Option, + pub(crate) transaction_isolation_level: Option, + pub(crate) max_connection_lifetime: Option, + pub(crate) max_idle_connection_lifetime: Option, +} + +impl MssqlUrl { + /// Maximum number of connections the pool can have (if used together with + /// pooled Quaint). + pub fn connection_limit(&self) -> Option { + self.query_params.connection_limit() + } + + /// A duration how long one query can take. + pub fn socket_timeout(&self) -> Option { + self.query_params.socket_timeout() + } + + /// A duration how long we can try to connect to the database. + pub fn connect_timeout(&self) -> Option { + self.query_params.connect_timeout() + } + + /// A pool check_out timeout. + pub fn pool_timeout(&self) -> Option { + self.query_params.pool_timeout() + } + + /// The isolation level of a transaction. + pub(crate) fn transaction_isolation_level(&self) -> Option { + self.query_params.transaction_isolation_level + } + + /// Name of the database. + pub fn dbname(&self) -> &str { + self.query_params.database() + } + + /// The prefix which to use when querying database. + pub fn schema(&self) -> &str { + self.query_params.schema() + } + + /// Database hostname. + pub fn host(&self) -> &str { + self.query_params.host() + } + + /// The username to use when connecting to the database. + pub fn username(&self) -> Option<&str> { + self.query_params.user() + } + + /// The password to use when connecting to the database. + pub fn password(&self) -> Option<&str> { + self.query_params.password() + } + + /// The TLS mode to use when connecting to the database. + pub fn encrypt(&self) -> EncryptMode { + self.query_params.encrypt() + } + + /// If true, we allow invalid certificates (self-signed, or otherwise + /// dangerous) when connecting. Should be true only for development and + /// testing. + pub fn trust_server_certificate(&self) -> bool { + self.query_params.trust_server_certificate() + } + + /// Path to a custom server certificate file. + pub fn trust_server_certificate_ca(&self) -> Option<&str> { + self.query_params.trust_server_certificate_ca() + } + + /// Database port. + pub fn port(&self) -> u16 { + self.query_params.port() + } + + /// The JDBC connection string + pub fn connection_string(&self) -> &str { + &self.connection_string + } + + /// The maximum connection lifetime + pub fn max_connection_lifetime(&self) -> Option { + self.query_params.max_connection_lifetime() + } + + /// The maximum idle connection lifetime + pub fn max_idle_connection_lifetime(&self) -> Option { + self.query_params.max_idle_connection_lifetime() + } +} + +impl MssqlQueryParams { + fn port(&self) -> u16 { + self.port.unwrap_or(1433) + } + + fn host(&self) -> &str { + self.host.as_deref().unwrap_or("localhost") + } + + fn user(&self) -> Option<&str> { + self.user.as_deref() + } + + fn password(&self) -> Option<&str> { + self.password.as_deref() + } + + fn encrypt(&self) -> EncryptMode { + self.encrypt + } + + fn trust_server_certificate(&self) -> bool { + self.trust_server_certificate + } + + fn trust_server_certificate_ca(&self) -> Option<&str> { + self.trust_server_certificate_ca.as_deref() + } + + fn database(&self) -> &str { + &self.database + } + + fn schema(&self) -> &str { + &self.schema + } + + fn socket_timeout(&self) -> Option { + self.socket_timeout + } + + fn connect_timeout(&self) -> Option { + self.connect_timeout + } + + fn connection_limit(&self) -> Option { + self.connection_limit + } + + fn pool_timeout(&self) -> Option { + self.pool_timeout + } + + fn max_connection_lifetime(&self) -> Option { + self.max_connection_lifetime + } + + fn max_idle_connection_lifetime(&self) -> Option { + self.max_idle_connection_lifetime + } +} + +impl MssqlUrl { + pub fn new(jdbc_connection_string: &str) -> crate::Result { + let query_params = Self::parse_query_params(jdbc_connection_string)?; + let connection_string = Self::with_jdbc_prefix(jdbc_connection_string); + + Ok(Self { + connection_string, + query_params, + }) + } + + fn with_jdbc_prefix(input: &str) -> String { + if input.starts_with("jdbc:sqlserver") { + input.into() + } else { + format!("jdbc:{input}") + } + } + + fn parse_query_params(input: &str) -> crate::Result { + let mut conn = JdbcString::from_str(&Self::with_jdbc_prefix(input))?; + + let host = conn.server_name().map(|server_name| match conn.instance_name() { + Some(instance_name) => format!(r#"{server_name}\{instance_name}"#), + None => server_name.to_string(), + }); + + let port = conn.port(); + let props = conn.properties_mut(); + let user = props.remove("user"); + let password = props.remove("password"); + let database = props.remove("database").unwrap_or_else(|| String::from("master")); + let schema = props.remove("schema").unwrap_or_else(|| String::from("dbo")); + + let connection_limit = props + .remove("connectionlimit") + .or_else(|| props.remove("connection_limit")) + .map(|param| param.parse()) + .transpose()?; + + let transaction_isolation_level = props + .remove("isolationlevel") + .or_else(|| props.remove("isolation_level")) + .map(|level| { + IsolationLevel::from_str(&level).map_err(|_| { + let kind = ErrorKind::database_url_is_invalid(format!("Invalid isolation level `{level}`")); + Error::builder(kind).build() + }) + }) + .transpose()?; + + let mut connect_timeout = props + .remove("logintimeout") + .or_else(|| props.remove("login_timeout")) + .or_else(|| props.remove("connecttimeout")) + .or_else(|| props.remove("connect_timeout")) + .or_else(|| props.remove("connectiontimeout")) + .or_else(|| props.remove("connection_timeout")) + .map(|param| param.parse().map(Duration::from_secs)) + .transpose()?; + + match connect_timeout { + None => connect_timeout = Some(Duration::from_secs(5)), + Some(dur) if dur.as_secs() == 0 => connect_timeout = None, + _ => (), + } + + let mut pool_timeout = props + .remove("pooltimeout") + .or_else(|| props.remove("pool_timeout")) + .map(|param| param.parse().map(Duration::from_secs)) + .transpose()?; + + match pool_timeout { + None => pool_timeout = Some(Duration::from_secs(10)), + Some(dur) if dur.as_secs() == 0 => pool_timeout = None, + _ => (), + } + + let socket_timeout = props + .remove("sockettimeout") + .or_else(|| props.remove("socket_timeout")) + .map(|param| param.parse().map(Duration::from_secs)) + .transpose()?; + + let encrypt = props + .remove("encrypt") + .map(|param| EncryptMode::from_str(¶m)) + .transpose()? + .unwrap_or(EncryptMode::On); + + let trust_server_certificate = props + .remove("trustservercertificate") + .or_else(|| props.remove("trust_server_certificate")) + .map(|param| param.parse()) + .transpose()? + .unwrap_or(false); + + let trust_server_certificate_ca: Option = props + .remove("trustservercertificateca") + .or_else(|| props.remove("trust_server_certificate_ca")); + + let mut max_connection_lifetime = props + .remove("max_connection_lifetime") + .map(|param| param.parse().map(Duration::from_secs)) + .transpose()?; + + match max_connection_lifetime { + Some(dur) if dur.as_secs() == 0 => max_connection_lifetime = None, + _ => (), + } + + let mut max_idle_connection_lifetime = props + .remove("max_idle_connection_lifetime") + .map(|param| param.parse().map(Duration::from_secs)) + .transpose()?; + + match max_idle_connection_lifetime { + None => max_idle_connection_lifetime = Some(Duration::from_secs(300)), + Some(dur) if dur.as_secs() == 0 => max_idle_connection_lifetime = None, + _ => (), + } + + Ok(MssqlQueryParams { + encrypt, + port, + host, + user, + password, + database, + schema, + trust_server_certificate, + trust_server_certificate_ca, + connection_limit, + socket_timeout, + connect_timeout, + pool_timeout, + transaction_isolation_level, + max_connection_lifetime, + max_idle_connection_lifetime, + }) + } +} + +#[cfg(test)] +mod tests { + use crate::tests::test_api::mssql::CONN_STR; + use crate::{error::*, single::Quaint}; + + #[tokio::test] + async fn should_map_wrong_credentials_error() { + let url = CONN_STR.replace("user=SA", "user=WRONG"); + + let res = Quaint::new(url.as_str()).await; + assert!(res.is_err()); + + let err = res.unwrap_err(); + assert!(matches!(err.kind(), ErrorKind::AuthenticationFailed { user } if user == &Name::available("WRONG"))); + } +} diff --git a/quaint/src/connector/mysql.rs b/quaint/src/connector/mysql.rs index 4b6f27a583da..77bb6e0d1b8a 100644 --- a/quaint/src/connector/mysql.rs +++ b/quaint/src/connector/mysql.rs @@ -1,669 +1,10 @@ -mod conversion; -mod error; - -use crate::{ - ast::{Query, Value}, - connector::{metrics, queryable::*, ResultSet}, - error::{Error, ErrorKind}, - visitor::{self, Visitor}, -}; -use async_trait::async_trait; -use lru_cache::LruCache; -use mysql_async::{ - self as my, - prelude::{Query as _, Queryable as _}, -}; -use percent_encoding::percent_decode; -use std::{ - borrow::Cow, - future::Future, - path::{Path, PathBuf}, - sync::atomic::{AtomicBool, Ordering}, - time::Duration, -}; -use tokio::sync::Mutex; -use url::{Host, Url}; +//! Wasm-compatible definitions for the MySQL connector. +//! This module is only available with the `mysql` feature. +pub(crate) mod error; +pub(crate) mod url; +pub use self::url::*; pub use error::MysqlError; -/// The underlying MySQL driver. Only available with the `expose-drivers` -/// Cargo feature. -#[cfg(feature = "expose-drivers")] -pub use mysql_async; - -use super::IsolationLevel; - -/// A connector interface for the MySQL database. -#[derive(Debug)] -pub struct Mysql { - pub(crate) conn: Mutex, - pub(crate) url: MysqlUrl, - socket_timeout: Option, - is_healthy: AtomicBool, - statement_cache: Mutex>, -} - -/// Wraps a connection url and exposes the parsing logic used by quaint, including default values. -#[derive(Debug, Clone)] -pub struct MysqlUrl { - url: Url, - query_params: MysqlUrlQueryParams, -} - -impl MysqlUrl { - /// Parse `Url` to `MysqlUrl`. Returns error for mistyped connection - /// parameters. - pub fn new(url: Url) -> Result { - let query_params = Self::parse_query_params(&url)?; - - Ok(Self { url, query_params }) - } - - /// The bare `Url` to the database. - pub fn url(&self) -> &Url { - &self.url - } - - /// The percent-decoded database username. - pub fn username(&self) -> Cow { - match percent_decode(self.url.username().as_bytes()).decode_utf8() { - Ok(username) => username, - Err(_) => { - tracing::warn!("Couldn't decode username to UTF-8, using the non-decoded version."); - - self.url.username().into() - } - } - } - - /// The percent-decoded database password. - pub fn password(&self) -> Option> { - match self - .url - .password() - .and_then(|pw| percent_decode(pw.as_bytes()).decode_utf8().ok()) - { - Some(password) => Some(password), - None => self.url.password().map(|s| s.into()), - } - } - - /// Name of the database connected. Defaults to `mysql`. - pub fn dbname(&self) -> &str { - match self.url.path_segments() { - Some(mut segments) => segments.next().unwrap_or("mysql"), - None => "mysql", - } - } - - /// The database host. If `socket` and `host` are not set, defaults to `localhost`. - pub fn host(&self) -> &str { - match (self.url.host(), self.url.host_str()) { - (Some(Host::Ipv6(_)), Some(host)) => { - // The `url` crate may return an IPv6 address in brackets, which must be stripped. - if host.starts_with('[') && host.ends_with(']') { - &host[1..host.len() - 1] - } else { - host - } - } - (_, Some(host)) => host, - _ => "localhost", - } - } - - /// If set, connected to the database through a Unix socket. - pub fn socket(&self) -> &Option { - &self.query_params.socket - } - - /// The database port, defaults to `3306`. - pub fn port(&self) -> u16 { - self.url.port().unwrap_or(3306) - } - - /// The connection timeout. - pub fn connect_timeout(&self) -> Option { - self.query_params.connect_timeout - } - - /// The pool check_out timeout - pub fn pool_timeout(&self) -> Option { - self.query_params.pool_timeout - } - - /// The socket timeout - pub fn socket_timeout(&self) -> Option { - self.query_params.socket_timeout - } - - /// Prefer socket connection - pub fn prefer_socket(&self) -> Option { - self.query_params.prefer_socket - } - - /// The maximum connection lifetime - pub fn max_connection_lifetime(&self) -> Option { - self.query_params.max_connection_lifetime - } - - /// The maximum idle connection lifetime - pub fn max_idle_connection_lifetime(&self) -> Option { - self.query_params.max_idle_connection_lifetime - } - - fn statement_cache_size(&self) -> usize { - self.query_params.statement_cache_size - } - - pub(crate) fn cache(&self) -> LruCache { - LruCache::new(self.query_params.statement_cache_size) - } - - fn parse_query_params(url: &Url) -> Result { - let mut ssl_opts = my::SslOpts::default(); - ssl_opts = ssl_opts.with_danger_accept_invalid_certs(true); - - let mut connection_limit = None; - let mut use_ssl = false; - let mut socket = None; - let mut socket_timeout = None; - let mut connect_timeout = Some(Duration::from_secs(5)); - let mut pool_timeout = Some(Duration::from_secs(10)); - let mut max_connection_lifetime = None; - let mut max_idle_connection_lifetime = Some(Duration::from_secs(300)); - let mut prefer_socket = None; - let mut statement_cache_size = 100; - let mut identity: Option<(Option, Option)> = None; - - for (k, v) in url.query_pairs() { - match k.as_ref() { - "connection_limit" => { - let as_int: usize = v - .parse() - .map_err(|_| Error::builder(ErrorKind::InvalidConnectionArguments).build())?; - - connection_limit = Some(as_int); - } - "statement_cache_size" => { - statement_cache_size = v - .parse() - .map_err(|_| Error::builder(ErrorKind::InvalidConnectionArguments).build())?; - } - "sslcert" => { - use_ssl = true; - ssl_opts = ssl_opts.with_root_cert_path(Some(Path::new(&*v).to_path_buf())); - } - "sslidentity" => { - use_ssl = true; - - identity = match identity { - Some((_, pw)) => Some((Some(Path::new(&*v).to_path_buf()), pw)), - None => Some((Some(Path::new(&*v).to_path_buf()), None)), - }; - } - "sslpassword" => { - use_ssl = true; - - identity = match identity { - Some((path, _)) => Some((path, Some(v.to_string()))), - None => Some((None, Some(v.to_string()))), - }; - } - "socket" => { - socket = Some(v.replace(['(', ')'], "")); - } - "socket_timeout" => { - let as_int = v - .parse() - .map_err(|_| Error::builder(ErrorKind::InvalidConnectionArguments).build())?; - socket_timeout = Some(Duration::from_secs(as_int)); - } - "prefer_socket" => { - let as_bool = v - .parse::() - .map_err(|_| Error::builder(ErrorKind::InvalidConnectionArguments).build())?; - prefer_socket = Some(as_bool) - } - "connect_timeout" => { - let as_int = v - .parse::() - .map_err(|_| Error::builder(ErrorKind::InvalidConnectionArguments).build())?; - - connect_timeout = match as_int { - 0 => None, - _ => Some(Duration::from_secs(as_int)), - }; - } - "pool_timeout" => { - let as_int = v - .parse::() - .map_err(|_| Error::builder(ErrorKind::InvalidConnectionArguments).build())?; - - pool_timeout = match as_int { - 0 => None, - _ => Some(Duration::from_secs(as_int)), - }; - } - "sslaccept" => { - use_ssl = true; - match v.as_ref() { - "strict" => { - ssl_opts = ssl_opts.with_danger_accept_invalid_certs(false); - } - "accept_invalid_certs" => {} - _ => { - tracing::debug!( - message = "Unsupported SSL accept mode, defaulting to `accept_invalid_certs`", - mode = &*v - ); - } - }; - } - "max_connection_lifetime" => { - let as_int = v - .parse() - .map_err(|_| Error::builder(ErrorKind::InvalidConnectionArguments).build())?; - - if as_int == 0 { - max_connection_lifetime = None; - } else { - max_connection_lifetime = Some(Duration::from_secs(as_int)); - } - } - "max_idle_connection_lifetime" => { - let as_int = v - .parse() - .map_err(|_| Error::builder(ErrorKind::InvalidConnectionArguments).build())?; - - if as_int == 0 { - max_idle_connection_lifetime = None; - } else { - max_idle_connection_lifetime = Some(Duration::from_secs(as_int)); - } - } - _ => { - tracing::trace!(message = "Discarding connection string param", param = &*k); - } - }; - } - - ssl_opts = match identity { - Some((Some(path), Some(pw))) => { - let identity = mysql_async::ClientIdentity::new(path).with_password(pw); - ssl_opts.with_client_identity(Some(identity)) - } - Some((Some(path), None)) => { - let identity = mysql_async::ClientIdentity::new(path); - ssl_opts.with_client_identity(Some(identity)) - } - _ => ssl_opts, - }; - - Ok(MysqlUrlQueryParams { - ssl_opts, - connection_limit, - use_ssl, - socket, - socket_timeout, - connect_timeout, - pool_timeout, - max_connection_lifetime, - max_idle_connection_lifetime, - prefer_socket, - statement_cache_size, - }) - } - - #[cfg(feature = "pooled")] - pub(crate) fn connection_limit(&self) -> Option { - self.query_params.connection_limit - } - - pub(crate) fn to_opts_builder(&self) -> my::OptsBuilder { - let mut config = my::OptsBuilder::default() - .stmt_cache_size(Some(0)) - .user(Some(self.username())) - .pass(self.password()) - .db_name(Some(self.dbname())); - - match self.socket() { - Some(ref socket) => { - config = config.socket(Some(socket)); - } - None => { - config = config.ip_or_hostname(self.host()).tcp_port(self.port()); - } - } - - config = config.conn_ttl(Some(Duration::from_secs(5))); - - if self.query_params.use_ssl { - config = config.ssl_opts(Some(self.query_params.ssl_opts.clone())); - } - - if self.query_params.prefer_socket.is_some() { - config = config.prefer_socket(self.query_params.prefer_socket); - } - - config - } -} - -#[derive(Debug, Clone)] -pub(crate) struct MysqlUrlQueryParams { - ssl_opts: my::SslOpts, - connection_limit: Option, - use_ssl: bool, - socket: Option, - socket_timeout: Option, - connect_timeout: Option, - pool_timeout: Option, - max_connection_lifetime: Option, - max_idle_connection_lifetime: Option, - prefer_socket: Option, - statement_cache_size: usize, -} - -impl Mysql { - /// Create a new MySQL connection using `OptsBuilder` from the `mysql` crate. - pub async fn new(url: MysqlUrl) -> crate::Result { - let conn = super::timeout::connect(url.connect_timeout(), my::Conn::new(url.to_opts_builder())).await?; - - Ok(Self { - socket_timeout: url.query_params.socket_timeout, - conn: Mutex::new(conn), - statement_cache: Mutex::new(url.cache()), - url, - is_healthy: AtomicBool::new(true), - }) - } - - /// The underlying mysql_async::Conn. Only available with the - /// `expose-drivers` Cargo feature. This is a lower level API when you need - /// to get into database specific features. - #[cfg(feature = "expose-drivers")] - pub fn conn(&self) -> &Mutex { - &self.conn - } - - async fn perform_io(&self, op: U) -> crate::Result - where - F: Future>, - U: FnOnce() -> F, - { - match super::timeout::socket(self.socket_timeout, op()).await { - Err(e) if e.is_closed() => { - self.is_healthy.store(false, Ordering::SeqCst); - Err(e) - } - res => Ok(res?), - } - } - - async fn prepared(&self, sql: &str, op: U) -> crate::Result - where - F: Future>, - U: Fn(my::Statement) -> F, - { - if self.url.statement_cache_size() == 0 { - self.perform_io(|| async move { - let stmt = { - let mut conn = self.conn.lock().await; - conn.prep(sql).await? - }; - - let res = op(stmt.clone()).await; - - { - let mut conn = self.conn.lock().await; - conn.close(stmt).await?; - } - - res - }) - .await - } else { - self.perform_io(|| async move { - let stmt = self.fetch_cached(sql).await?; - op(stmt).await - }) - .await - } - } - - async fn fetch_cached(&self, sql: &str) -> crate::Result { - let mut cache = self.statement_cache.lock().await; - let capacity = cache.capacity(); - let stored = cache.len(); - - match cache.get_mut(sql) { - Some(stmt) => { - tracing::trace!( - message = "CACHE HIT!", - query = sql, - capacity = capacity, - stored = stored, - ); - - Ok(stmt.clone()) // arc'd - } - None => { - tracing::trace!( - message = "CACHE MISS!", - query = sql, - capacity = capacity, - stored = stored, - ); - - let mut conn = self.conn.lock().await; - if cache.capacity() == cache.len() { - if let Some((_, stmt)) = cache.remove_lru() { - conn.close(stmt).await?; - } - } - - let stmt = conn.prep(sql).await?; - cache.insert(sql.to_string(), stmt.clone()); - - Ok(stmt) - } - } - } -} - -impl_default_TransactionCapable!(Mysql); - -#[async_trait] -impl Queryable for Mysql { - async fn query(&self, q: Query<'_>) -> crate::Result { - let (sql, params) = visitor::Mysql::build(q)?; - self.query_raw(&sql, ¶ms).await - } - - async fn query_raw(&self, sql: &str, params: &[Value<'_>]) -> crate::Result { - metrics::query("mysql.query_raw", sql, params, move || async move { - self.prepared(sql, |stmt| async move { - let mut conn = self.conn.lock().await; - let rows: Vec = conn.exec(&stmt, conversion::conv_params(params)?).await?; - let columns = stmt.columns().iter().map(|s| s.name_str().into_owned()).collect(); - - let last_id = conn.last_insert_id(); - let mut result_set = ResultSet::new(columns, Vec::new()); - - for mut row in rows { - result_set.rows.push(row.take_result_row()?); - } - - if let Some(id) = last_id { - result_set.set_last_insert_id(id); - }; - - Ok(result_set) - }) - .await - }) - .await - } - - async fn query_raw_typed(&self, sql: &str, params: &[Value<'_>]) -> crate::Result { - self.query_raw(sql, params).await - } - - async fn execute(&self, q: Query<'_>) -> crate::Result { - let (sql, params) = visitor::Mysql::build(q)?; - self.execute_raw(&sql, ¶ms).await - } - - async fn execute_raw(&self, sql: &str, params: &[Value<'_>]) -> crate::Result { - metrics::query("mysql.execute_raw", sql, params, move || async move { - self.prepared(sql, |stmt| async move { - let mut conn = self.conn.lock().await; - conn.exec_drop(stmt, conversion::conv_params(params)?).await?; - - Ok(conn.affected_rows()) - }) - .await - }) - .await - } - - async fn execute_raw_typed(&self, sql: &str, params: &[Value<'_>]) -> crate::Result { - self.execute_raw(sql, params).await - } - - async fn raw_cmd(&self, cmd: &str) -> crate::Result<()> { - metrics::query("mysql.raw_cmd", cmd, &[], move || async move { - self.perform_io(|| async move { - let mut conn = self.conn.lock().await; - let mut result = cmd.run(&mut *conn).await?; - - loop { - result.map(drop).await?; - - if result.is_empty() { - result.map(drop).await?; - break; - } - } - - Ok(()) - }) - .await - }) - .await - } - - async fn version(&self) -> crate::Result> { - let query = r#"SELECT @@GLOBAL.version version"#; - let rows = super::timeout::socket(self.socket_timeout, self.query_raw(query, &[])).await?; - - let version_string = rows - .get(0) - .and_then(|row| row.get("version").and_then(|version| version.typed.to_string())); - - Ok(version_string) - } - - fn is_healthy(&self) -> bool { - self.is_healthy.load(Ordering::SeqCst) - } - - async fn set_tx_isolation_level(&self, isolation_level: IsolationLevel) -> crate::Result<()> { - if matches!(isolation_level, IsolationLevel::Snapshot) { - return Err(Error::builder(ErrorKind::invalid_isolation_level(&isolation_level)).build()); - } - - self.raw_cmd(&format!("SET TRANSACTION ISOLATION LEVEL {isolation_level}")) - .await?; - - Ok(()) - } - - fn requires_isolation_first(&self) -> bool { - true - } -} - -#[cfg(test)] -mod tests { - use super::MysqlUrl; - use crate::tests::test_api::mysql::CONN_STR; - use crate::{error::*, single::Quaint}; - use url::Url; - - #[test] - fn should_parse_socket_url() { - let url = MysqlUrl::new(Url::parse("mysql://root@localhost/dbname?socket=(/tmp/mysql.sock)").unwrap()).unwrap(); - assert_eq!("dbname", url.dbname()); - assert_eq!(&Some(String::from("/tmp/mysql.sock")), url.socket()); - } - - #[test] - fn should_parse_prefer_socket() { - let url = - MysqlUrl::new(Url::parse("mysql://root:root@localhost:3307/testdb?prefer_socket=false").unwrap()).unwrap(); - assert!(!url.prefer_socket().unwrap()); - } - - #[test] - fn should_parse_sslaccept() { - let url = - MysqlUrl::new(Url::parse("mysql://root:root@localhost:3307/testdb?sslaccept=strict").unwrap()).unwrap(); - assert!(url.query_params.use_ssl); - assert!(!url.query_params.ssl_opts.skip_domain_validation()); - assert!(!url.query_params.ssl_opts.accept_invalid_certs()); - } - - #[test] - fn should_parse_ipv6_host() { - let url = MysqlUrl::new(Url::parse("mysql://[2001:db8:1234::ffff]:5432/testdb").unwrap()).unwrap(); - assert_eq!("2001:db8:1234::ffff", url.host()); - } - - #[test] - fn should_allow_changing_of_cache_size() { - let url = MysqlUrl::new(Url::parse("mysql:///root:root@localhost:3307/foo?statement_cache_size=420").unwrap()) - .unwrap(); - assert_eq!(420, url.cache().capacity()); - } - - #[test] - fn should_have_default_cache_size() { - let url = MysqlUrl::new(Url::parse("mysql:///root:root@localhost:3307/foo").unwrap()).unwrap(); - assert_eq!(100, url.cache().capacity()); - } - - #[tokio::test] - async fn should_map_nonexisting_database_error() { - let mut url = Url::parse(&CONN_STR).unwrap(); - url.set_username("root").unwrap(); - url.set_path("/this_does_not_exist"); - - let url = url.as_str().to_string(); - let res = Quaint::new(&url).await; - - let err = res.unwrap_err(); - - match err.kind() { - ErrorKind::DatabaseDoesNotExist { db_name } => { - assert_eq!(Some("1049"), err.original_code()); - assert_eq!(Some("Unknown database \'this_does_not_exist\'"), err.original_message()); - assert_eq!(&Name::available("this_does_not_exist"), db_name) - } - e => panic!("Expected `DatabaseDoesNotExist`, got {:?}", e), - } - } - - #[tokio::test] - async fn should_map_wrong_credentials_error() { - let mut url = Url::parse(&CONN_STR).unwrap(); - url.set_username("WRONG").unwrap(); - - let res = Quaint::new(url.as_str()).await; - assert!(res.is_err()); - - let err = res.unwrap_err(); - assert!(matches!(err.kind(), ErrorKind::AuthenticationFailed { user } if user == &Name::available("WRONG"))); - } -} +#[cfg(feature = "mysql-native")] +pub(crate) mod native; diff --git a/quaint/src/connector/mysql/error.rs b/quaint/src/connector/mysql/error.rs index dd7c3d3bfa66..7b4813bf0223 100644 --- a/quaint/src/connector/mysql/error.rs +++ b/quaint/src/connector/mysql/error.rs @@ -1,22 +1,23 @@ use crate::error::{DatabaseConstraint, Error, ErrorKind}; -use mysql_async as my; +use thiserror::Error; + +// This is a partial copy of the `mysql_async::Error` using only the enum variant used by Prisma. +// This avoids pulling in `mysql_async`, which would break Wasm compilation. +#[derive(Debug, Error)] +enum MysqlAsyncError { + #[error("Server error: `{}'", _0)] + Server(#[source] MysqlError), +} +/// This type represents MySql server error. +#[derive(Debug, Error, Clone, Eq, PartialEq)] +#[error("ERROR {} ({}): {}", state, code, message)] pub struct MysqlError { pub code: u16, pub message: String, pub state: String, } -impl From<&my::ServerError> for MysqlError { - fn from(value: &my::ServerError) -> Self { - MysqlError { - code: value.code, - message: value.message.to_owned(), - state: value.state.to_owned(), - } - } -} - impl From for Error { fn from(error: MysqlError) -> Self { let code = error.code; @@ -232,7 +233,7 @@ impl From for Error { } _ => { let kind = ErrorKind::QueryError( - my::Error::Server(my::ServerError { + MysqlAsyncError::Server(MysqlError { message: error.message.clone(), code, state: error.state.clone(), @@ -249,24 +250,3 @@ impl From for Error { } } } - -impl From for Error { - fn from(e: my::Error) -> Error { - match e { - my::Error::Io(my::IoError::Tls(err)) => Error::builder(ErrorKind::TlsError { - message: err.to_string(), - }) - .build(), - my::Error::Io(my::IoError::Io(err)) if err.kind() == std::io::ErrorKind::UnexpectedEof => { - Error::builder(ErrorKind::ConnectionClosed).build() - } - my::Error::Io(io_error) => Error::builder(ErrorKind::ConnectionError(io_error.into())).build(), - my::Error::Driver(e) => Error::builder(ErrorKind::QueryError(e.into())).build(), - my::Error::Server(ref server_error) => { - let mysql_error: MysqlError = server_error.into(); - mysql_error.into() - } - e => Error::builder(ErrorKind::QueryError(e.into())).build(), - } - } -} diff --git a/quaint/src/connector/mysql/conversion.rs b/quaint/src/connector/mysql/native/conversion.rs similarity index 100% rename from quaint/src/connector/mysql/conversion.rs rename to quaint/src/connector/mysql/native/conversion.rs diff --git a/quaint/src/connector/mysql/native/error.rs b/quaint/src/connector/mysql/native/error.rs new file mode 100644 index 000000000000..89c21fb706f6 --- /dev/null +++ b/quaint/src/connector/mysql/native/error.rs @@ -0,0 +1,36 @@ +use crate::{ + connector::mysql::error::MysqlError, + error::{Error, ErrorKind}, +}; +use mysql_async as my; + +impl From<&my::ServerError> for MysqlError { + fn from(value: &my::ServerError) -> Self { + MysqlError { + code: value.code, + message: value.message.to_owned(), + state: value.state.to_owned(), + } + } +} + +impl From for Error { + fn from(e: my::Error) -> Error { + match e { + my::Error::Io(my::IoError::Tls(err)) => Error::builder(ErrorKind::TlsError { + message: err.to_string(), + }) + .build(), + my::Error::Io(my::IoError::Io(err)) if err.kind() == std::io::ErrorKind::UnexpectedEof => { + Error::builder(ErrorKind::ConnectionClosed).build() + } + my::Error::Io(io_error) => Error::builder(ErrorKind::ConnectionError(io_error.into())).build(), + my::Error::Driver(e) => Error::builder(ErrorKind::QueryError(e.into())).build(), + my::Error::Server(ref server_error) => { + let mysql_error: MysqlError = server_error.into(); + mysql_error.into() + } + e => Error::builder(ErrorKind::QueryError(e.into())).build(), + } + } +} diff --git a/quaint/src/connector/mysql/native/mod.rs b/quaint/src/connector/mysql/native/mod.rs new file mode 100644 index 000000000000..fdcc3a6276d1 --- /dev/null +++ b/quaint/src/connector/mysql/native/mod.rs @@ -0,0 +1,297 @@ +//! Definitions for the MySQL connector. +//! This module is not compatible with wasm32-* targets. +//! This module is only available with the `mysql-native` feature. +mod conversion; +mod error; + +pub(crate) use crate::connector::mysql::MysqlUrl; +use crate::connector::{timeout, IsolationLevel}; + +use crate::{ + ast::{Query, Value}, + connector::{metrics, queryable::*, ResultSet}, + error::{Error, ErrorKind}, + visitor::{self, Visitor}, +}; +use async_trait::async_trait; +use lru_cache::LruCache; +use mysql_async::{ + self as my, + prelude::{Query as _, Queryable as _}, +}; +use std::{ + future::Future, + sync::atomic::{AtomicBool, Ordering}, + time::Duration, +}; +use tokio::sync::Mutex; + +/// The underlying MySQL driver. Only available with the `expose-drivers` +/// Cargo feature. +#[cfg(feature = "expose-drivers")] +pub use mysql_async; + +impl MysqlUrl { + pub(crate) fn cache(&self) -> LruCache { + LruCache::new(self.query_params.statement_cache_size) + } + + pub(crate) fn to_opts_builder(&self) -> my::OptsBuilder { + let mut config = my::OptsBuilder::default() + .stmt_cache_size(Some(0)) + .user(Some(self.username())) + .pass(self.password()) + .db_name(Some(self.dbname())); + + match self.socket() { + Some(ref socket) => { + config = config.socket(Some(socket)); + } + None => { + config = config.ip_or_hostname(self.host()).tcp_port(self.port()); + } + } + + config = config.conn_ttl(Some(Duration::from_secs(5))); + + if self.query_params.use_ssl { + config = config.ssl_opts(Some(self.query_params.ssl_opts.clone())); + } + + if self.query_params.prefer_socket.is_some() { + config = config.prefer_socket(self.query_params.prefer_socket); + } + + config + } +} + +/// A connector interface for the MySQL database. +#[derive(Debug)] +pub struct Mysql { + pub(crate) conn: Mutex, + pub(crate) url: MysqlUrl, + socket_timeout: Option, + is_healthy: AtomicBool, + statement_cache: Mutex>, +} + +impl Mysql { + /// Create a new MySQL connection using `OptsBuilder` from the `mysql` crate. + pub async fn new(url: MysqlUrl) -> crate::Result { + let conn = timeout::connect(url.connect_timeout(), my::Conn::new(url.to_opts_builder())).await?; + + Ok(Self { + socket_timeout: url.query_params.socket_timeout, + conn: Mutex::new(conn), + statement_cache: Mutex::new(url.cache()), + url, + is_healthy: AtomicBool::new(true), + }) + } + + /// The underlying mysql_async::Conn. Only available with the + /// `expose-drivers` Cargo feature. This is a lower level API when you need + /// to get into database specific features. + #[cfg(feature = "expose-drivers")] + pub fn conn(&self) -> &Mutex { + &self.conn + } + + async fn perform_io(&self, op: U) -> crate::Result + where + F: Future>, + U: FnOnce() -> F, + { + match timeout::socket(self.socket_timeout, op()).await { + Err(e) if e.is_closed() => { + self.is_healthy.store(false, Ordering::SeqCst); + Err(e) + } + res => Ok(res?), + } + } + + async fn prepared(&self, sql: &str, op: U) -> crate::Result + where + F: Future>, + U: Fn(my::Statement) -> F, + { + if self.url.statement_cache_size() == 0 { + self.perform_io(|| async move { + let stmt = { + let mut conn = self.conn.lock().await; + conn.prep(sql).await? + }; + + let res = op(stmt.clone()).await; + + { + let mut conn = self.conn.lock().await; + conn.close(stmt).await?; + } + + res + }) + .await + } else { + self.perform_io(|| async move { + let stmt = self.fetch_cached(sql).await?; + op(stmt).await + }) + .await + } + } + + async fn fetch_cached(&self, sql: &str) -> crate::Result { + let mut cache = self.statement_cache.lock().await; + let capacity = cache.capacity(); + let stored = cache.len(); + + match cache.get_mut(sql) { + Some(stmt) => { + tracing::trace!( + message = "CACHE HIT!", + query = sql, + capacity = capacity, + stored = stored, + ); + + Ok(stmt.clone()) // arc'd + } + None => { + tracing::trace!( + message = "CACHE MISS!", + query = sql, + capacity = capacity, + stored = stored, + ); + + let mut conn = self.conn.lock().await; + if cache.capacity() == cache.len() { + if let Some((_, stmt)) = cache.remove_lru() { + conn.close(stmt).await?; + } + } + + let stmt = conn.prep(sql).await?; + cache.insert(sql.to_string(), stmt.clone()); + + Ok(stmt) + } + } + } +} + +impl_default_TransactionCapable!(Mysql); + +#[async_trait] +impl Queryable for Mysql { + async fn query(&self, q: Query<'_>) -> crate::Result { + let (sql, params) = visitor::Mysql::build(q)?; + self.query_raw(&sql, ¶ms).await + } + + async fn query_raw(&self, sql: &str, params: &[Value<'_>]) -> crate::Result { + metrics::query("mysql.query_raw", sql, params, move || async move { + self.prepared(sql, |stmt| async move { + let mut conn = self.conn.lock().await; + let rows: Vec = conn.exec(&stmt, conversion::conv_params(params)?).await?; + let columns = stmt.columns().iter().map(|s| s.name_str().into_owned()).collect(); + + let last_id = conn.last_insert_id(); + let mut result_set = ResultSet::new(columns, Vec::new()); + + for mut row in rows { + result_set.rows.push(row.take_result_row()?); + } + + if let Some(id) = last_id { + result_set.set_last_insert_id(id); + }; + + Ok(result_set) + }) + .await + }) + .await + } + + async fn query_raw_typed(&self, sql: &str, params: &[Value<'_>]) -> crate::Result { + self.query_raw(sql, params).await + } + + async fn execute(&self, q: Query<'_>) -> crate::Result { + let (sql, params) = visitor::Mysql::build(q)?; + self.execute_raw(&sql, ¶ms).await + } + + async fn execute_raw(&self, sql: &str, params: &[Value<'_>]) -> crate::Result { + metrics::query("mysql.execute_raw", sql, params, move || async move { + self.prepared(sql, |stmt| async move { + let mut conn = self.conn.lock().await; + conn.exec_drop(stmt, conversion::conv_params(params)?).await?; + + Ok(conn.affected_rows()) + }) + .await + }) + .await + } + + async fn execute_raw_typed(&self, sql: &str, params: &[Value<'_>]) -> crate::Result { + self.execute_raw(sql, params).await + } + + async fn raw_cmd(&self, cmd: &str) -> crate::Result<()> { + metrics::query("mysql.raw_cmd", cmd, &[], move || async move { + self.perform_io(|| async move { + let mut conn = self.conn.lock().await; + let mut result = cmd.run(&mut *conn).await?; + + loop { + result.map(drop).await?; + + if result.is_empty() { + result.map(drop).await?; + break; + } + } + + Ok(()) + }) + .await + }) + .await + } + + async fn version(&self) -> crate::Result> { + let query = r#"SELECT @@GLOBAL.version version"#; + let rows = timeout::socket(self.socket_timeout, self.query_raw(query, &[])).await?; + + let version_string = rows + .get(0) + .and_then(|row| row.get("version").and_then(|version| version.typed.to_string())); + + Ok(version_string) + } + + fn is_healthy(&self) -> bool { + self.is_healthy.load(Ordering::SeqCst) + } + + async fn set_tx_isolation_level(&self, isolation_level: IsolationLevel) -> crate::Result<()> { + if matches!(isolation_level, IsolationLevel::Snapshot) { + return Err(Error::builder(ErrorKind::invalid_isolation_level(&isolation_level)).build()); + } + + self.raw_cmd(&format!("SET TRANSACTION ISOLATION LEVEL {isolation_level}")) + .await?; + + Ok(()) + } + + fn requires_isolation_first(&self) -> bool { + true + } +} diff --git a/quaint/src/connector/mysql/url.rs b/quaint/src/connector/mysql/url.rs new file mode 100644 index 000000000000..f0756fa95833 --- /dev/null +++ b/quaint/src/connector/mysql/url.rs @@ -0,0 +1,401 @@ +#![cfg_attr(target_arch = "wasm32", allow(dead_code))] + +use crate::error::{Error, ErrorKind}; +use percent_encoding::percent_decode; +use std::{ + borrow::Cow, + path::{Path, PathBuf}, + time::Duration, +}; +use url::{Host, Url}; + +/// Wraps a connection url and exposes the parsing logic used by quaint, including default values. +#[derive(Debug, Clone)] +pub struct MysqlUrl { + url: Url, + pub(crate) query_params: MysqlUrlQueryParams, +} + +impl MysqlUrl { + /// Parse `Url` to `MysqlUrl`. Returns error for mistyped connection + /// parameters. + pub fn new(url: Url) -> Result { + let query_params = Self::parse_query_params(&url)?; + + Ok(Self { url, query_params }) + } + + /// The bare `Url` to the database. + pub fn url(&self) -> &Url { + &self.url + } + + /// The percent-decoded database username. + pub fn username(&self) -> Cow { + match percent_decode(self.url.username().as_bytes()).decode_utf8() { + Ok(username) => username, + Err(_) => { + tracing::warn!("Couldn't decode username to UTF-8, using the non-decoded version."); + + self.url.username().into() + } + } + } + + /// The percent-decoded database password. + pub fn password(&self) -> Option> { + match self + .url + .password() + .and_then(|pw| percent_decode(pw.as_bytes()).decode_utf8().ok()) + { + Some(password) => Some(password), + None => self.url.password().map(|s| s.into()), + } + } + + /// Name of the database connected. Defaults to `mysql`. + pub fn dbname(&self) -> &str { + match self.url.path_segments() { + Some(mut segments) => segments.next().unwrap_or("mysql"), + None => "mysql", + } + } + + /// The database host. If `socket` and `host` are not set, defaults to `localhost`. + pub fn host(&self) -> &str { + match (self.url.host(), self.url.host_str()) { + (Some(Host::Ipv6(_)), Some(host)) => { + // The `url` crate may return an IPv6 address in brackets, which must be stripped. + if host.starts_with('[') && host.ends_with(']') { + &host[1..host.len() - 1] + } else { + host + } + } + (_, Some(host)) => host, + _ => "localhost", + } + } + + /// If set, connected to the database through a Unix socket. + pub fn socket(&self) -> &Option { + &self.query_params.socket + } + + /// The database port, defaults to `3306`. + pub fn port(&self) -> u16 { + self.url.port().unwrap_or(3306) + } + + /// The connection timeout. + pub fn connect_timeout(&self) -> Option { + self.query_params.connect_timeout + } + + /// The pool check_out timeout + pub fn pool_timeout(&self) -> Option { + self.query_params.pool_timeout + } + + /// The socket timeout + pub fn socket_timeout(&self) -> Option { + self.query_params.socket_timeout + } + + /// Prefer socket connection + pub fn prefer_socket(&self) -> Option { + self.query_params.prefer_socket + } + + /// The maximum connection lifetime + pub fn max_connection_lifetime(&self) -> Option { + self.query_params.max_connection_lifetime + } + + /// The maximum idle connection lifetime + pub fn max_idle_connection_lifetime(&self) -> Option { + self.query_params.max_idle_connection_lifetime + } + + pub(crate) fn statement_cache_size(&self) -> usize { + self.query_params.statement_cache_size + } + + fn parse_query_params(url: &Url) -> Result { + #[cfg(feature = "mysql-native")] + let mut ssl_opts = { + let mut ssl_opts = mysql_async::SslOpts::default(); + ssl_opts = ssl_opts.with_danger_accept_invalid_certs(true); + ssl_opts + }; + + let mut connection_limit = None; + let mut use_ssl = false; + let mut socket = None; + let mut socket_timeout = None; + let mut connect_timeout = Some(Duration::from_secs(5)); + let mut pool_timeout = Some(Duration::from_secs(10)); + let mut max_connection_lifetime = None; + let mut max_idle_connection_lifetime = Some(Duration::from_secs(300)); + let mut prefer_socket = None; + let mut statement_cache_size = 100; + let mut identity: Option<(Option, Option)> = None; + + for (k, v) in url.query_pairs() { + match k.as_ref() { + "connection_limit" => { + let as_int: usize = v + .parse() + .map_err(|_| Error::builder(ErrorKind::InvalidConnectionArguments).build())?; + + connection_limit = Some(as_int); + } + "statement_cache_size" => { + statement_cache_size = v + .parse() + .map_err(|_| Error::builder(ErrorKind::InvalidConnectionArguments).build())?; + } + "sslcert" => { + use_ssl = true; + + #[cfg(feature = "mysql-native")] + { + ssl_opts = ssl_opts.with_root_cert_path(Some(Path::new(&*v).to_path_buf())); + } + } + "sslidentity" => { + use_ssl = true; + + identity = match identity { + Some((_, pw)) => Some((Some(Path::new(&*v).to_path_buf()), pw)), + None => Some((Some(Path::new(&*v).to_path_buf()), None)), + }; + } + "sslpassword" => { + use_ssl = true; + + identity = match identity { + Some((path, _)) => Some((path, Some(v.to_string()))), + None => Some((None, Some(v.to_string()))), + }; + } + "socket" => { + socket = Some(v.replace(['(', ')'], "")); + } + "socket_timeout" => { + let as_int = v + .parse() + .map_err(|_| Error::builder(ErrorKind::InvalidConnectionArguments).build())?; + socket_timeout = Some(Duration::from_secs(as_int)); + } + "prefer_socket" => { + let as_bool = v + .parse::() + .map_err(|_| Error::builder(ErrorKind::InvalidConnectionArguments).build())?; + prefer_socket = Some(as_bool) + } + "connect_timeout" => { + let as_int = v + .parse::() + .map_err(|_| Error::builder(ErrorKind::InvalidConnectionArguments).build())?; + + connect_timeout = match as_int { + 0 => None, + _ => Some(Duration::from_secs(as_int)), + }; + } + "pool_timeout" => { + let as_int = v + .parse::() + .map_err(|_| Error::builder(ErrorKind::InvalidConnectionArguments).build())?; + + pool_timeout = match as_int { + 0 => None, + _ => Some(Duration::from_secs(as_int)), + }; + } + "sslaccept" => { + use_ssl = true; + match v.as_ref() { + "strict" => { + #[cfg(feature = "mysql-native")] + { + ssl_opts = ssl_opts.with_danger_accept_invalid_certs(false); + } + } + "accept_invalid_certs" => {} + _ => { + tracing::debug!( + message = "Unsupported SSL accept mode, defaulting to `accept_invalid_certs`", + mode = &*v + ); + } + }; + } + "max_connection_lifetime" => { + let as_int = v + .parse() + .map_err(|_| Error::builder(ErrorKind::InvalidConnectionArguments).build())?; + + if as_int == 0 { + max_connection_lifetime = None; + } else { + max_connection_lifetime = Some(Duration::from_secs(as_int)); + } + } + "max_idle_connection_lifetime" => { + let as_int = v + .parse() + .map_err(|_| Error::builder(ErrorKind::InvalidConnectionArguments).build())?; + + if as_int == 0 { + max_idle_connection_lifetime = None; + } else { + max_idle_connection_lifetime = Some(Duration::from_secs(as_int)); + } + } + _ => { + tracing::trace!(message = "Discarding connection string param", param = &*k); + } + }; + } + + // Wrapping this in a block, as attributes on expressions are still experimental + // See: https://github.com/rust-lang/rust/issues/15701 + #[cfg(feature = "mysql-native")] + { + ssl_opts = match identity { + Some((Some(path), Some(pw))) => { + let identity = mysql_async::ClientIdentity::new(path).with_password(pw); + ssl_opts.with_client_identity(Some(identity)) + } + Some((Some(path), None)) => { + let identity = mysql_async::ClientIdentity::new(path); + ssl_opts.with_client_identity(Some(identity)) + } + _ => ssl_opts, + }; + } + + Ok(MysqlUrlQueryParams { + #[cfg(feature = "mysql-native")] + ssl_opts, + connection_limit, + use_ssl, + socket, + socket_timeout, + connect_timeout, + pool_timeout, + max_connection_lifetime, + max_idle_connection_lifetime, + prefer_socket, + statement_cache_size, + }) + } + + #[cfg(feature = "pooled")] + pub(crate) fn connection_limit(&self) -> Option { + self.query_params.connection_limit + } +} + +#[derive(Debug, Clone)] +pub(crate) struct MysqlUrlQueryParams { + pub(crate) connection_limit: Option, + pub(crate) use_ssl: bool, + pub(crate) socket: Option, + pub(crate) socket_timeout: Option, + pub(crate) connect_timeout: Option, + pub(crate) pool_timeout: Option, + pub(crate) max_connection_lifetime: Option, + pub(crate) max_idle_connection_lifetime: Option, + pub(crate) prefer_socket: Option, + pub(crate) statement_cache_size: usize, + + #[cfg(feature = "mysql-native")] + pub(crate) ssl_opts: mysql_async::SslOpts, +} + +#[cfg(test)] +mod tests { + use super::MysqlUrl; + use crate::tests::test_api::mysql::CONN_STR; + use crate::{error::*, single::Quaint}; + use url::Url; + + #[test] + fn should_parse_socket_url() { + let url = MysqlUrl::new(Url::parse("mysql://root@localhost/dbname?socket=(/tmp/mysql.sock)").unwrap()).unwrap(); + assert_eq!("dbname", url.dbname()); + assert_eq!(&Some(String::from("/tmp/mysql.sock")), url.socket()); + } + + #[test] + fn should_parse_prefer_socket() { + let url = + MysqlUrl::new(Url::parse("mysql://root:root@localhost:3307/testdb?prefer_socket=false").unwrap()).unwrap(); + assert!(!url.prefer_socket().unwrap()); + } + + #[test] + fn should_parse_sslaccept() { + let url = + MysqlUrl::new(Url::parse("mysql://root:root@localhost:3307/testdb?sslaccept=strict").unwrap()).unwrap(); + assert!(url.query_params.use_ssl); + assert!(!url.query_params.ssl_opts.skip_domain_validation()); + assert!(!url.query_params.ssl_opts.accept_invalid_certs()); + } + + #[test] + fn should_parse_ipv6_host() { + let url = MysqlUrl::new(Url::parse("mysql://[2001:db8:1234::ffff]:5432/testdb").unwrap()).unwrap(); + assert_eq!("2001:db8:1234::ffff", url.host()); + } + + #[test] + fn should_allow_changing_of_cache_size() { + let url = MysqlUrl::new(Url::parse("mysql:///root:root@localhost:3307/foo?statement_cache_size=420").unwrap()) + .unwrap(); + assert_eq!(420, url.cache().capacity()); + } + + #[test] + fn should_have_default_cache_size() { + let url = MysqlUrl::new(Url::parse("mysql:///root:root@localhost:3307/foo").unwrap()).unwrap(); + assert_eq!(100, url.cache().capacity()); + } + + #[tokio::test] + async fn should_map_nonexisting_database_error() { + let mut url = Url::parse(&CONN_STR).unwrap(); + url.set_username("root").unwrap(); + url.set_path("/this_does_not_exist"); + + let url = url.as_str().to_string(); + let res = Quaint::new(&url).await; + + let err = res.unwrap_err(); + + match err.kind() { + ErrorKind::DatabaseDoesNotExist { db_name } => { + assert_eq!(Some("1049"), err.original_code()); + assert_eq!(Some("Unknown database \'this_does_not_exist\'"), err.original_message()); + assert_eq!(&Name::available("this_does_not_exist"), db_name) + } + e => panic!("Expected `DatabaseDoesNotExist`, got {:?}", e), + } + } + + #[tokio::test] + async fn should_map_wrong_credentials_error() { + let mut url = Url::parse(&CONN_STR).unwrap(); + url.set_username("WRONG").unwrap(); + + let res = Quaint::new(url.as_str()).await; + assert!(res.is_err()); + + let err = res.unwrap_err(); + assert!(matches!(err.kind(), ErrorKind::AuthenticationFailed { user } if user == &Name::available("WRONG"))); + } +} diff --git a/quaint/src/connector/postgres.rs b/quaint/src/connector/postgres.rs index 766be38b27e4..befc980ce29e 100644 --- a/quaint/src/connector/postgres.rs +++ b/quaint/src/connector/postgres.rs @@ -1,1593 +1,10 @@ -mod conversion; -mod error; - -use crate::{ - ast::{Query, Value}, - connector::{metrics, queryable::*, ResultSet}, - error::{Error, ErrorKind}, - visitor::{self, Visitor}, -}; -use async_trait::async_trait; -use futures::{future::FutureExt, lock::Mutex}; -use lru_cache::LruCache; -use native_tls::{Certificate, Identity, TlsConnector}; -use percent_encoding::percent_decode; -use postgres_native_tls::MakeTlsConnector; -use std::{ - borrow::{Borrow, Cow}, - fmt::{Debug, Display}, - fs, - future::Future, - sync::atomic::{AtomicBool, Ordering}, - time::Duration, -}; -use tokio_postgres::{ - config::{ChannelBinding, SslMode}, - Client, Config, Statement, -}; -use url::{Host, Url}; +//! Wasm-compatible definitions for the PostgreSQL connector. +//! This module is only available with the `postgresql` feature. +pub(crate) mod error; +pub(crate) mod url; +pub use self::url::*; pub use error::PostgresError; -pub(crate) const DEFAULT_SCHEMA: &str = "public"; - -/// The underlying postgres driver. Only available with the `expose-drivers` -/// Cargo feature. -#[cfg(feature = "expose-drivers")] -pub use tokio_postgres; - -use super::{IsolationLevel, Transaction}; - -#[derive(Clone)] -struct Hidden(T); - -impl Debug for Hidden { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - f.write_str("") - } -} - -struct PostgresClient(Client); - -impl Debug for PostgresClient { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - f.write_str("PostgresClient") - } -} - -/// A connector interface for the PostgreSQL database. -#[derive(Debug)] -pub struct PostgreSql { - client: PostgresClient, - pg_bouncer: bool, - socket_timeout: Option, - statement_cache: Mutex>, - is_healthy: AtomicBool, -} - -#[derive(Debug, Clone, Copy, PartialEq, Eq)] -pub enum SslAcceptMode { - Strict, - AcceptInvalidCerts, -} - -#[derive(Debug, Clone)] -pub struct SslParams { - certificate_file: Option, - identity_file: Option, - identity_password: Hidden>, - ssl_accept_mode: SslAcceptMode, -} - -#[derive(Debug)] -struct SslAuth { - certificate: Hidden>, - identity: Hidden>, - ssl_accept_mode: SslAcceptMode, -} - -impl Default for SslAuth { - fn default() -> Self { - Self { - certificate: Hidden(None), - identity: Hidden(None), - ssl_accept_mode: SslAcceptMode::AcceptInvalidCerts, - } - } -} - -impl SslAuth { - fn certificate(&mut self, certificate: Certificate) -> &mut Self { - self.certificate = Hidden(Some(certificate)); - self - } - - fn identity(&mut self, identity: Identity) -> &mut Self { - self.identity = Hidden(Some(identity)); - self - } - - fn accept_mode(&mut self, mode: SslAcceptMode) -> &mut Self { - self.ssl_accept_mode = mode; - self - } -} - -impl SslParams { - async fn into_auth(self) -> crate::Result { - let mut auth = SslAuth::default(); - auth.accept_mode(self.ssl_accept_mode); - - if let Some(ref cert_file) = self.certificate_file { - let cert = fs::read(cert_file).map_err(|err| { - Error::builder(ErrorKind::TlsError { - message: format!("cert file not found ({err})"), - }) - .build() - })?; - - auth.certificate(Certificate::from_pem(&cert)?); - } - - if let Some(ref identity_file) = self.identity_file { - let db = fs::read(identity_file).map_err(|err| { - Error::builder(ErrorKind::TlsError { - message: format!("identity file not found ({err})"), - }) - .build() - })?; - let password = self.identity_password.0.as_deref().unwrap_or(""); - let identity = Identity::from_pkcs12(&db, password)?; - - auth.identity(identity); - } - - Ok(auth) - } -} - -#[derive(Debug, Clone, Copy)] -pub enum PostgresFlavour { - Postgres, - Cockroach, - Unknown, -} - -impl PostgresFlavour { - /// Returns `true` if the postgres flavour is [`Postgres`]. - /// - /// [`Postgres`]: PostgresFlavour::Postgres - fn is_postgres(&self) -> bool { - matches!(self, Self::Postgres) - } - - /// Returns `true` if the postgres flavour is [`Cockroach`]. - /// - /// [`Cockroach`]: PostgresFlavour::Cockroach - fn is_cockroach(&self) -> bool { - matches!(self, Self::Cockroach) - } - - /// Returns `true` if the postgres flavour is [`Unknown`]. - /// - /// [`Unknown`]: PostgresFlavour::Unknown - fn is_unknown(&self) -> bool { - matches!(self, Self::Unknown) - } -} - -/// Wraps a connection url and exposes the parsing logic used by Quaint, -/// including default values. -#[derive(Debug, Clone)] -pub struct PostgresUrl { - url: Url, - query_params: PostgresUrlQueryParams, - flavour: PostgresFlavour, -} - -impl PostgresUrl { - /// Parse `Url` to `PostgresUrl`. Returns error for mistyped connection - /// parameters. - pub fn new(url: Url) -> Result { - let query_params = Self::parse_query_params(&url)?; - - Ok(Self { - url, - query_params, - flavour: PostgresFlavour::Unknown, - }) - } - - /// The bare `Url` to the database. - pub fn url(&self) -> &Url { - &self.url - } - - /// The percent-decoded database username. - pub fn username(&self) -> Cow { - match percent_decode(self.url.username().as_bytes()).decode_utf8() { - Ok(username) => username, - Err(_) => { - tracing::warn!("Couldn't decode username to UTF-8, using the non-decoded version."); - - self.url.username().into() - } - } - } - - /// The database host. Taken first from the `host` query parameter, then - /// from the `host` part of the URL. For socket connections, the query - /// parameter must be used. - /// - /// If none of them are set, defaults to `localhost`. - pub fn host(&self) -> &str { - match (self.query_params.host.as_ref(), self.url.host_str(), self.url.host()) { - (Some(host), _, _) => host.as_str(), - (None, Some(""), _) => "localhost", - (None, None, _) => "localhost", - (None, Some(host), Some(Host::Ipv6(_))) => { - // The `url` crate may return an IPv6 address in brackets, which must be stripped. - if host.starts_with('[') && host.ends_with(']') { - &host[1..host.len() - 1] - } else { - host - } - } - (None, Some(host), _) => host, - } - } - - /// Name of the database connected. Defaults to `postgres`. - pub fn dbname(&self) -> &str { - match self.url.path_segments() { - Some(mut segments) => segments.next().unwrap_or("postgres"), - None => "postgres", - } - } - - /// The percent-decoded database password. - pub fn password(&self) -> Cow { - match self - .url - .password() - .and_then(|pw| percent_decode(pw.as_bytes()).decode_utf8().ok()) - { - Some(password) => password, - None => self.url.password().unwrap_or("").into(), - } - } - - /// The database port, defaults to `5432`. - pub fn port(&self) -> u16 { - self.url.port().unwrap_or(5432) - } - - /// The database schema, defaults to `public`. - pub fn schema(&self) -> &str { - self.query_params.schema.as_deref().unwrap_or(DEFAULT_SCHEMA) - } - - /// Whether the pgbouncer mode is enabled. - pub fn pg_bouncer(&self) -> bool { - self.query_params.pg_bouncer - } - - /// The connection timeout. - pub fn connect_timeout(&self) -> Option { - self.query_params.connect_timeout - } - - /// Pool check_out timeout - pub fn pool_timeout(&self) -> Option { - self.query_params.pool_timeout - } - - /// The socket timeout - pub fn socket_timeout(&self) -> Option { - self.query_params.socket_timeout - } - - /// The maximum connection lifetime - pub fn max_connection_lifetime(&self) -> Option { - self.query_params.max_connection_lifetime - } - - /// The maximum idle connection lifetime - pub fn max_idle_connection_lifetime(&self) -> Option { - self.query_params.max_idle_connection_lifetime - } - - /// The custom application name - pub fn application_name(&self) -> Option<&str> { - self.query_params.application_name.as_deref() - } - - pub fn channel_binding(&self) -> ChannelBinding { - self.query_params.channel_binding - } - - pub(crate) fn cache(&self) -> LruCache { - if self.query_params.pg_bouncer { - LruCache::new(0) - } else { - LruCache::new(self.query_params.statement_cache_size) - } - } - - pub(crate) fn options(&self) -> Option<&str> { - self.query_params.options.as_deref() - } - - /// Sets whether the URL points to a Postgres, Cockroach or Unknown database. - /// This is used to avoid a network roundtrip at connection to set the search path. - /// - /// The different behaviours are: - /// - Postgres: Always avoid a network roundtrip by setting the search path through client connection parameters. - /// - Cockroach: Avoid a network roundtrip if the schema name is deemed "safe" (i.e. no escape quoting required). Otherwise, set the search path through a database query. - /// - Unknown: Always add a network roundtrip by setting the search path through a database query. - pub fn set_flavour(&mut self, flavour: PostgresFlavour) { - self.flavour = flavour; - } - - fn parse_query_params(url: &Url) -> Result { - let mut connection_limit = None; - let mut schema = None; - let mut certificate_file = None; - let mut identity_file = None; - let mut identity_password = None; - let mut ssl_accept_mode = SslAcceptMode::AcceptInvalidCerts; - let mut ssl_mode = SslMode::Prefer; - let mut host = None; - let mut application_name = None; - let mut channel_binding = ChannelBinding::Prefer; - let mut socket_timeout = None; - let mut connect_timeout = Some(Duration::from_secs(5)); - let mut pool_timeout = Some(Duration::from_secs(10)); - let mut pg_bouncer = false; - let mut statement_cache_size = 100; - let mut max_connection_lifetime = None; - let mut max_idle_connection_lifetime = Some(Duration::from_secs(300)); - let mut options = None; - - for (k, v) in url.query_pairs() { - match k.as_ref() { - "pgbouncer" => { - pg_bouncer = v - .parse() - .map_err(|_| Error::builder(ErrorKind::InvalidConnectionArguments).build())?; - } - "sslmode" => { - match v.as_ref() { - "disable" => ssl_mode = SslMode::Disable, - "prefer" => ssl_mode = SslMode::Prefer, - "require" => ssl_mode = SslMode::Require, - _ => { - tracing::debug!(message = "Unsupported SSL mode, defaulting to `prefer`", mode = &*v); - } - }; - } - "sslcert" => { - certificate_file = Some(v.to_string()); - } - "sslidentity" => { - identity_file = Some(v.to_string()); - } - "sslpassword" => { - identity_password = Some(v.to_string()); - } - "statement_cache_size" => { - statement_cache_size = v - .parse() - .map_err(|_| Error::builder(ErrorKind::InvalidConnectionArguments).build())?; - } - "sslaccept" => { - match v.as_ref() { - "strict" => { - ssl_accept_mode = SslAcceptMode::Strict; - } - "accept_invalid_certs" => { - ssl_accept_mode = SslAcceptMode::AcceptInvalidCerts; - } - _ => { - tracing::debug!( - message = "Unsupported SSL accept mode, defaulting to `strict`", - mode = &*v - ); - - ssl_accept_mode = SslAcceptMode::Strict; - } - }; - } - "schema" => { - schema = Some(v.to_string()); - } - "connection_limit" => { - let as_int: usize = v - .parse() - .map_err(|_| Error::builder(ErrorKind::InvalidConnectionArguments).build())?; - connection_limit = Some(as_int); - } - "host" => { - host = Some(v.to_string()); - } - "socket_timeout" => { - let as_int = v - .parse() - .map_err(|_| Error::builder(ErrorKind::InvalidConnectionArguments).build())?; - socket_timeout = Some(Duration::from_secs(as_int)); - } - "connect_timeout" => { - let as_int = v - .parse() - .map_err(|_| Error::builder(ErrorKind::InvalidConnectionArguments).build())?; - - if as_int == 0 { - connect_timeout = None; - } else { - connect_timeout = Some(Duration::from_secs(as_int)); - } - } - "pool_timeout" => { - let as_int = v - .parse() - .map_err(|_| Error::builder(ErrorKind::InvalidConnectionArguments).build())?; - - if as_int == 0 { - pool_timeout = None; - } else { - pool_timeout = Some(Duration::from_secs(as_int)); - } - } - "max_connection_lifetime" => { - let as_int = v - .parse() - .map_err(|_| Error::builder(ErrorKind::InvalidConnectionArguments).build())?; - - if as_int == 0 { - max_connection_lifetime = None; - } else { - max_connection_lifetime = Some(Duration::from_secs(as_int)); - } - } - "max_idle_connection_lifetime" => { - let as_int = v - .parse() - .map_err(|_| Error::builder(ErrorKind::InvalidConnectionArguments).build())?; - - if as_int == 0 { - max_idle_connection_lifetime = None; - } else { - max_idle_connection_lifetime = Some(Duration::from_secs(as_int)); - } - } - "application_name" => { - application_name = Some(v.to_string()); - } - "channel_binding" => { - match v.as_ref() { - "disable" => channel_binding = ChannelBinding::Disable, - "prefer" => channel_binding = ChannelBinding::Prefer, - "require" => channel_binding = ChannelBinding::Require, - _ => { - tracing::debug!( - message = "Unsupported Channel Binding {channel_binding}, defaulting to `prefer`", - channel_binding = &*v - ); - } - }; - } - "options" => { - options = Some(v.to_string()); - } - _ => { - tracing::trace!(message = "Discarding connection string param", param = &*k); - } - }; - } - - Ok(PostgresUrlQueryParams { - ssl_params: SslParams { - certificate_file, - identity_file, - ssl_accept_mode, - identity_password: Hidden(identity_password), - }, - connection_limit, - schema, - ssl_mode, - host, - connect_timeout, - pool_timeout, - socket_timeout, - pg_bouncer, - statement_cache_size, - max_connection_lifetime, - max_idle_connection_lifetime, - application_name, - channel_binding, - options, - }) - } - - pub(crate) fn ssl_params(&self) -> &SslParams { - &self.query_params.ssl_params - } - - #[cfg(feature = "pooled")] - pub(crate) fn connection_limit(&self) -> Option { - self.query_params.connection_limit - } - - /// On Postgres, we set the SEARCH_PATH and client-encoding through client connection parameters to save a network roundtrip on connection. - /// We can't always do it for CockroachDB because it does not expect quotes for unsafe identifiers (https://github.com/cockroachdb/cockroach/issues/101328), which might change once the issue is fixed. - /// To circumvent that problem, we only set the SEARCH_PATH through client connection parameters for Cockroach when the identifier is safe, so that the quoting does not matter. - fn set_search_path(&self, config: &mut Config) { - // PGBouncer does not support the search_path connection parameter. - // https://www.pgbouncer.org/config.html#ignore_startup_parameters - if self.query_params.pg_bouncer { - return; - } - - if let Some(schema) = &self.query_params.schema { - if self.flavour().is_cockroach() && is_safe_identifier(schema) { - config.search_path(CockroachSearchPath(schema).to_string()); - } - - if self.flavour().is_postgres() { - config.search_path(PostgresSearchPath(schema).to_string()); - } - } - } - - pub(crate) fn to_config(&self) -> Config { - let mut config = Config::new(); - - config.user(self.username().borrow()); - config.password(self.password().borrow() as &str); - config.host(self.host()); - config.port(self.port()); - config.dbname(self.dbname()); - config.pgbouncer_mode(self.query_params.pg_bouncer); - - if let Some(options) = self.options() { - config.options(options); - } - - if let Some(application_name) = self.application_name() { - config.application_name(application_name); - } - - if let Some(connect_timeout) = self.query_params.connect_timeout { - config.connect_timeout(connect_timeout); - } - - self.set_search_path(&mut config); - - config.ssl_mode(self.query_params.ssl_mode); - - config.channel_binding(self.query_params.channel_binding); - - config - } - - pub fn flavour(&self) -> PostgresFlavour { - self.flavour - } -} - -#[derive(Debug, Clone)] -pub(crate) struct PostgresUrlQueryParams { - ssl_params: SslParams, - connection_limit: Option, - schema: Option, - ssl_mode: SslMode, - pg_bouncer: bool, - host: Option, - socket_timeout: Option, - connect_timeout: Option, - pool_timeout: Option, - statement_cache_size: usize, - max_connection_lifetime: Option, - max_idle_connection_lifetime: Option, - application_name: Option, - channel_binding: ChannelBinding, - options: Option, -} - -impl PostgreSql { - /// Create a new connection to the database. - pub async fn new(url: PostgresUrl) -> crate::Result { - let config = url.to_config(); - - let mut tls_builder = TlsConnector::builder(); - - { - let ssl_params = url.ssl_params(); - let auth = ssl_params.to_owned().into_auth().await?; - - if let Some(certificate) = auth.certificate.0 { - tls_builder.add_root_certificate(certificate); - } - - tls_builder.danger_accept_invalid_certs(auth.ssl_accept_mode == SslAcceptMode::AcceptInvalidCerts); - - if let Some(identity) = auth.identity.0 { - tls_builder.identity(identity); - } - } - - let tls = MakeTlsConnector::new(tls_builder.build()?); - let (client, conn) = super::timeout::connect(url.connect_timeout(), config.connect(tls)).await?; - - tokio::spawn(conn.map(|r| match r { - Ok(_) => (), - Err(e) => { - tracing::error!("Error in PostgreSQL connection: {:?}", e); - } - })); - - // On Postgres, we set the SEARCH_PATH and client-encoding through client connection parameters to save a network roundtrip on connection. - // We can't always do it for CockroachDB because it does not expect quotes for unsafe identifiers (https://github.com/cockroachdb/cockroach/issues/101328), which might change once the issue is fixed. - // To circumvent that problem, we only set the SEARCH_PATH through client connection parameters for Cockroach when the identifier is safe, so that the quoting does not matter. - // Finally, to ensure backward compatibility, we keep sending a database query in case the flavour is set to Unknown. - if let Some(schema) = &url.query_params.schema { - // PGBouncer does not support the search_path connection parameter. - // https://www.pgbouncer.org/config.html#ignore_startup_parameters - if url.query_params.pg_bouncer - || url.flavour().is_unknown() - || (url.flavour().is_cockroach() && !is_safe_identifier(schema)) - { - let session_variables = format!( - r##"{set_search_path}"##, - set_search_path = SetSearchPath(url.query_params.schema.as_deref()) - ); - - client.simple_query(session_variables.as_str()).await?; - } - } - - Ok(Self { - client: PostgresClient(client), - socket_timeout: url.query_params.socket_timeout, - pg_bouncer: url.query_params.pg_bouncer, - statement_cache: Mutex::new(url.cache()), - is_healthy: AtomicBool::new(true), - }) - } - - /// The underlying tokio_postgres::Client. Only available with the - /// `expose-drivers` Cargo feature. This is a lower level API when you need - /// to get into database specific features. - #[cfg(feature = "expose-drivers")] - pub fn client(&self) -> &tokio_postgres::Client { - &self.client.0 - } - - async fn fetch_cached(&self, sql: &str, params: &[Value<'_>]) -> crate::Result { - let mut cache = self.statement_cache.lock().await; - let capacity = cache.capacity(); - let stored = cache.len(); - - match cache.get_mut(sql) { - Some(stmt) => { - tracing::trace!( - message = "CACHE HIT!", - query = sql, - capacity = capacity, - stored = stored, - ); - - Ok(stmt.clone()) // arc'd - } - None => { - tracing::trace!( - message = "CACHE MISS!", - query = sql, - capacity = capacity, - stored = stored, - ); - - let param_types = conversion::params_to_types(params); - let stmt = self.perform_io(self.client.0.prepare_typed(sql, ¶m_types)).await?; - - cache.insert(sql.to_string(), stmt.clone()); - - Ok(stmt) - } - } - } - - async fn perform_io(&self, fut: F) -> crate::Result - where - F: Future>, - { - match super::timeout::socket(self.socket_timeout, fut).await { - Err(e) if e.is_closed() => { - self.is_healthy.store(false, Ordering::SeqCst); - Err(e) - } - res => res, - } - } - - fn check_bind_variables_len(&self, params: &[Value<'_>]) -> crate::Result<()> { - if params.len() > i16::MAX as usize { - // tokio_postgres would return an error here. Let's avoid calling the driver - // and return an error early. - let kind = ErrorKind::QueryInvalidInput(format!( - "too many bind variables in prepared statement, expected maximum of {}, received {}", - i16::MAX, - params.len() - )); - Err(Error::builder(kind).build()) - } else { - Ok(()) - } - } -} - -// A SearchPath connection parameter (Display-impl) for connection initialization. -struct CockroachSearchPath<'a>(&'a str); - -impl Display for CockroachSearchPath<'_> { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - f.write_str(self.0) - } -} - -// A SearchPath connection parameter (Display-impl) for connection initialization. -struct PostgresSearchPath<'a>(&'a str); - -impl Display for PostgresSearchPath<'_> { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - f.write_str("\"")?; - f.write_str(self.0)?; - f.write_str("\"")?; - - Ok(()) - } -} - -// A SetSearchPath statement (Display-impl) for connection initialization. -struct SetSearchPath<'a>(Option<&'a str>); - -impl Display for SetSearchPath<'_> { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - if let Some(schema) = self.0 { - f.write_str("SET search_path = \"")?; - f.write_str(schema)?; - f.write_str("\";\n")?; - } - - Ok(()) - } -} - -impl_default_TransactionCapable!(PostgreSql); - -#[async_trait] -impl Queryable for PostgreSql { - async fn query(&self, q: Query<'_>) -> crate::Result { - let (sql, params) = visitor::Postgres::build(q)?; - - self.query_raw(sql.as_str(), ¶ms[..]).await - } - - async fn query_raw(&self, sql: &str, params: &[Value<'_>]) -> crate::Result { - self.check_bind_variables_len(params)?; - - metrics::query("postgres.query_raw", sql, params, move || async move { - let stmt = self.fetch_cached(sql, &[]).await?; - - if stmt.params().len() != params.len() { - let kind = ErrorKind::IncorrectNumberOfParameters { - expected: stmt.params().len(), - actual: params.len(), - }; - - return Err(Error::builder(kind).build()); - } - - let rows = self - .perform_io(self.client.0.query(&stmt, conversion::conv_params(params).as_slice())) - .await?; - - let mut result = ResultSet::new(stmt.to_column_names(), Vec::new()); - - for row in rows { - result.rows.push(row.get_result_row()?); - } - - Ok(result) - }) - .await - } - - async fn query_raw_typed(&self, sql: &str, params: &[Value<'_>]) -> crate::Result { - self.check_bind_variables_len(params)?; - - metrics::query("postgres.query_raw", sql, params, move || async move { - let stmt = self.fetch_cached(sql, params).await?; - - if stmt.params().len() != params.len() { - let kind = ErrorKind::IncorrectNumberOfParameters { - expected: stmt.params().len(), - actual: params.len(), - }; - - return Err(Error::builder(kind).build()); - } - - let rows = self - .perform_io(self.client.0.query(&stmt, conversion::conv_params(params).as_slice())) - .await?; - - let mut result = ResultSet::new(stmt.to_column_names(), Vec::new()); - - for row in rows { - result.rows.push(row.get_result_row()?); - } - - Ok(result) - }) - .await - } - - async fn execute(&self, q: Query<'_>) -> crate::Result { - let (sql, params) = visitor::Postgres::build(q)?; - - self.execute_raw(sql.as_str(), ¶ms[..]).await - } - - async fn execute_raw(&self, sql: &str, params: &[Value<'_>]) -> crate::Result { - self.check_bind_variables_len(params)?; - - metrics::query("postgres.execute_raw", sql, params, move || async move { - let stmt = self.fetch_cached(sql, &[]).await?; - - if stmt.params().len() != params.len() { - let kind = ErrorKind::IncorrectNumberOfParameters { - expected: stmt.params().len(), - actual: params.len(), - }; - - return Err(Error::builder(kind).build()); - } - - let changes = self - .perform_io(self.client.0.execute(&stmt, conversion::conv_params(params).as_slice())) - .await?; - - Ok(changes) - }) - .await - } - - async fn execute_raw_typed(&self, sql: &str, params: &[Value<'_>]) -> crate::Result { - self.check_bind_variables_len(params)?; - - metrics::query("postgres.execute_raw", sql, params, move || async move { - let stmt = self.fetch_cached(sql, params).await?; - - if stmt.params().len() != params.len() { - let kind = ErrorKind::IncorrectNumberOfParameters { - expected: stmt.params().len(), - actual: params.len(), - }; - - return Err(Error::builder(kind).build()); - } - - let changes = self - .perform_io(self.client.0.execute(&stmt, conversion::conv_params(params).as_slice())) - .await?; - - Ok(changes) - }) - .await - } - - async fn raw_cmd(&self, cmd: &str) -> crate::Result<()> { - metrics::query("postgres.raw_cmd", cmd, &[], move || async move { - self.perform_io(self.client.0.simple_query(cmd)).await?; - Ok(()) - }) - .await - } - - async fn version(&self) -> crate::Result> { - let query = r#"SELECT version()"#; - let rows = self.query_raw(query, &[]).await?; - - let version_string = rows - .get(0) - .and_then(|row| row.get("version").and_then(|version| version.to_string())); - - Ok(version_string) - } - - fn is_healthy(&self) -> bool { - self.is_healthy.load(Ordering::SeqCst) - } - - async fn server_reset_query(&self, tx: &dyn Transaction) -> crate::Result<()> { - if self.pg_bouncer { - tx.raw_cmd("DEALLOCATE ALL").await - } else { - Ok(()) - } - } - - async fn set_tx_isolation_level(&self, isolation_level: IsolationLevel) -> crate::Result<()> { - if matches!(isolation_level, IsolationLevel::Snapshot) { - return Err(Error::builder(ErrorKind::invalid_isolation_level(&isolation_level)).build()); - } - - self.raw_cmd(&format!("SET TRANSACTION ISOLATION LEVEL {isolation_level}")) - .await?; - - Ok(()) - } - - fn requires_isolation_first(&self) -> bool { - false - } -} - -/// Sorted list of CockroachDB's reserved keywords. -/// Taken from https://www.cockroachlabs.com/docs/stable/keywords-and-identifiers.html#keywords -const RESERVED_KEYWORDS: [&str; 79] = [ - "all", - "analyse", - "analyze", - "and", - "any", - "array", - "as", - "asc", - "asymmetric", - "both", - "case", - "cast", - "check", - "collate", - "column", - "concurrently", - "constraint", - "create", - "current_catalog", - "current_date", - "current_role", - "current_schema", - "current_time", - "current_timestamp", - "current_user", - "default", - "deferrable", - "desc", - "distinct", - "do", - "else", - "end", - "except", - "false", - "fetch", - "for", - "foreign", - "from", - "grant", - "group", - "having", - "in", - "initially", - "intersect", - "into", - "lateral", - "leading", - "limit", - "localtime", - "localtimestamp", - "not", - "null", - "offset", - "on", - "only", - "or", - "order", - "placing", - "primary", - "references", - "returning", - "select", - "session_user", - "some", - "symmetric", - "table", - "then", - "to", - "trailing", - "true", - "union", - "unique", - "user", - "using", - "variadic", - "when", - "where", - "window", - "with", -]; - -/// Sorted list of CockroachDB's reserved type function names. -/// Taken from https://www.cockroachlabs.com/docs/stable/keywords-and-identifiers.html#keywords -const RESERVED_TYPE_FUNCTION_NAMES: [&str; 18] = [ - "authorization", - "collation", - "cross", - "full", - "ilike", - "inner", - "is", - "isnull", - "join", - "left", - "like", - "natural", - "none", - "notnull", - "outer", - "overlaps", - "right", - "similar", -]; - -/// Returns true if a Postgres identifier is considered "safe". -/// -/// In this context, "safe" means that the value of an identifier would be the same quoted and unquoted or that it's not part of reserved keywords. In other words, that it does _not_ need to be quoted. -/// -/// Spec can be found here: https://www.postgresql.org/docs/current/sql-syntax-lexical.html#SQL-SYNTAX-IDENTIFIERS -/// or here: https://www.cockroachlabs.com/docs/stable/keywords-and-identifiers.html#rules-for-identifiers -fn is_safe_identifier(ident: &str) -> bool { - if ident.is_empty() { - return false; - } - - // 1. Not equal any SQL keyword unless the keyword is accepted by the element's syntax. For example, name accepts Unreserved or Column Name keywords. - if RESERVED_KEYWORDS.binary_search(&ident).is_ok() || RESERVED_TYPE_FUNCTION_NAMES.binary_search(&ident).is_ok() { - return false; - } - - let mut chars = ident.chars(); - - let first = chars.next().unwrap(); - - // 2. SQL identifiers must begin with a letter (a-z, but also letters with diacritical marks and non-Latin letters) or an underscore (_). - if (!first.is_alphabetic() || !first.is_lowercase()) && first != '_' { - return false; - } - - for c in chars { - // 3. Subsequent characters in an identifier can be letters, underscores, digits (0-9), or dollar signs ($). - if (!c.is_alphabetic() || !c.is_lowercase()) && c != '_' && !c.is_ascii_digit() && c != '$' { - return false; - } - } - - true -} - -#[cfg(test)] -mod tests { - use super::*; - use crate::tests::test_api::postgres::CONN_STR; - use crate::tests::test_api::CRDB_CONN_STR; - use crate::{connector::Queryable, error::*, single::Quaint}; - use url::Url; - - #[test] - fn should_parse_socket_url() { - let url = PostgresUrl::new(Url::parse("postgresql:///dbname?host=/var/run/psql.sock").unwrap()).unwrap(); - assert_eq!("dbname", url.dbname()); - assert_eq!("/var/run/psql.sock", url.host()); - } - - #[test] - fn should_parse_escaped_url() { - let url = PostgresUrl::new(Url::parse("postgresql:///dbname?host=%2Fvar%2Frun%2Fpostgresql").unwrap()).unwrap(); - assert_eq!("dbname", url.dbname()); - assert_eq!("/var/run/postgresql", url.host()); - } - - #[test] - fn should_allow_changing_of_cache_size() { - let url = - PostgresUrl::new(Url::parse("postgresql:///localhost:5432/foo?statement_cache_size=420").unwrap()).unwrap(); - assert_eq!(420, url.cache().capacity()); - } - - #[test] - fn should_have_default_cache_size() { - let url = PostgresUrl::new(Url::parse("postgresql:///localhost:5432/foo").unwrap()).unwrap(); - assert_eq!(100, url.cache().capacity()); - } - - #[test] - fn should_have_application_name() { - let url = - PostgresUrl::new(Url::parse("postgresql:///localhost:5432/foo?application_name=test").unwrap()).unwrap(); - assert_eq!(Some("test"), url.application_name()); - } - - #[test] - fn should_have_channel_binding() { - let url = - PostgresUrl::new(Url::parse("postgresql:///localhost:5432/foo?channel_binding=require").unwrap()).unwrap(); - assert_eq!(ChannelBinding::Require, url.channel_binding()); - } - - #[test] - fn should_have_default_channel_binding() { - let url = - PostgresUrl::new(Url::parse("postgresql:///localhost:5432/foo?channel_binding=invalid").unwrap()).unwrap(); - assert_eq!(ChannelBinding::Prefer, url.channel_binding()); - - let url = PostgresUrl::new(Url::parse("postgresql:///localhost:5432/foo").unwrap()).unwrap(); - assert_eq!(ChannelBinding::Prefer, url.channel_binding()); - } - - #[test] - fn should_not_enable_caching_with_pgbouncer() { - let url = PostgresUrl::new(Url::parse("postgresql:///localhost:5432/foo?pgbouncer=true").unwrap()).unwrap(); - assert_eq!(0, url.cache().capacity()); - } - - #[test] - fn should_parse_default_host() { - let url = PostgresUrl::new(Url::parse("postgresql:///dbname").unwrap()).unwrap(); - assert_eq!("dbname", url.dbname()); - assert_eq!("localhost", url.host()); - } - - #[test] - fn should_parse_ipv6_host() { - let url = PostgresUrl::new(Url::parse("postgresql://[2001:db8:1234::ffff]:5432/dbname").unwrap()).unwrap(); - assert_eq!("2001:db8:1234::ffff", url.host()); - } - - #[test] - fn should_handle_options_field() { - let url = PostgresUrl::new(Url::parse("postgresql:///localhost:5432?options=--cluster%3Dmy_cluster").unwrap()) - .unwrap(); - - assert_eq!("--cluster=my_cluster", url.options().unwrap()); - } - - #[tokio::test] - async fn test_custom_search_path_pg() { - async fn test_path(schema_name: &str) -> Option { - let mut url = Url::parse(&CONN_STR).unwrap(); - url.query_pairs_mut().append_pair("schema", schema_name); - - let mut pg_url = PostgresUrl::new(url).unwrap(); - pg_url.set_flavour(PostgresFlavour::Postgres); - - let client = PostgreSql::new(pg_url).await.unwrap(); - - let result_set = client.query_raw("SHOW search_path", &[]).await.unwrap(); - let row = result_set.first().unwrap(); - - row[0].typed.to_string() - } - - // Safe - assert_eq!(test_path("hello").await.as_deref(), Some("\"hello\"")); - assert_eq!(test_path("_hello").await.as_deref(), Some("\"_hello\"")); - assert_eq!(test_path("àbracadabra").await.as_deref(), Some("\"àbracadabra\"")); - assert_eq!(test_path("h3ll0").await.as_deref(), Some("\"h3ll0\"")); - assert_eq!(test_path("héllo").await.as_deref(), Some("\"héllo\"")); - assert_eq!(test_path("héll0$").await.as_deref(), Some("\"héll0$\"")); - assert_eq!(test_path("héll_0$").await.as_deref(), Some("\"héll_0$\"")); - - // Not safe - assert_eq!(test_path("Hello").await.as_deref(), Some("\"Hello\"")); - assert_eq!(test_path("hEllo").await.as_deref(), Some("\"hEllo\"")); - assert_eq!(test_path("$hello").await.as_deref(), Some("\"$hello\"")); - assert_eq!(test_path("hello!").await.as_deref(), Some("\"hello!\"")); - assert_eq!(test_path("hello#").await.as_deref(), Some("\"hello#\"")); - assert_eq!(test_path("he llo").await.as_deref(), Some("\"he llo\"")); - assert_eq!(test_path(" hello").await.as_deref(), Some("\" hello\"")); - assert_eq!(test_path("he-llo").await.as_deref(), Some("\"he-llo\"")); - assert_eq!(test_path("hÉllo").await.as_deref(), Some("\"hÉllo\"")); - assert_eq!(test_path("1337").await.as_deref(), Some("\"1337\"")); - assert_eq!(test_path("_HELLO").await.as_deref(), Some("\"_HELLO\"")); - assert_eq!(test_path("HELLO").await.as_deref(), Some("\"HELLO\"")); - assert_eq!(test_path("HELLO$").await.as_deref(), Some("\"HELLO$\"")); - assert_eq!(test_path("ÀBRACADABRA").await.as_deref(), Some("\"ÀBRACADABRA\"")); - - for ident in RESERVED_KEYWORDS { - assert_eq!(test_path(ident).await.as_deref(), Some(format!("\"{ident}\"").as_str())); - } - - for ident in RESERVED_TYPE_FUNCTION_NAMES { - assert_eq!(test_path(ident).await.as_deref(), Some(format!("\"{ident}\"").as_str())); - } - } - - #[tokio::test] - async fn test_custom_search_path_pg_pgbouncer() { - async fn test_path(schema_name: &str) -> Option { - let mut url = Url::parse(&CONN_STR).unwrap(); - url.query_pairs_mut().append_pair("schema", schema_name); - url.query_pairs_mut().append_pair("pbbouncer", "true"); - - let mut pg_url = PostgresUrl::new(url).unwrap(); - pg_url.set_flavour(PostgresFlavour::Postgres); - - let client = PostgreSql::new(pg_url).await.unwrap(); - - let result_set = client.query_raw("SHOW search_path", &[]).await.unwrap(); - let row = result_set.first().unwrap(); - - row[0].typed.to_string() - } - - // Safe - assert_eq!(test_path("hello").await.as_deref(), Some("\"hello\"")); - assert_eq!(test_path("_hello").await.as_deref(), Some("\"_hello\"")); - assert_eq!(test_path("àbracadabra").await.as_deref(), Some("\"àbracadabra\"")); - assert_eq!(test_path("h3ll0").await.as_deref(), Some("\"h3ll0\"")); - assert_eq!(test_path("héllo").await.as_deref(), Some("\"héllo\"")); - assert_eq!(test_path("héll0$").await.as_deref(), Some("\"héll0$\"")); - assert_eq!(test_path("héll_0$").await.as_deref(), Some("\"héll_0$\"")); - - // Not safe - assert_eq!(test_path("Hello").await.as_deref(), Some("\"Hello\"")); - assert_eq!(test_path("hEllo").await.as_deref(), Some("\"hEllo\"")); - assert_eq!(test_path("$hello").await.as_deref(), Some("\"$hello\"")); - assert_eq!(test_path("hello!").await.as_deref(), Some("\"hello!\"")); - assert_eq!(test_path("hello#").await.as_deref(), Some("\"hello#\"")); - assert_eq!(test_path("he llo").await.as_deref(), Some("\"he llo\"")); - assert_eq!(test_path(" hello").await.as_deref(), Some("\" hello\"")); - assert_eq!(test_path("he-llo").await.as_deref(), Some("\"he-llo\"")); - assert_eq!(test_path("hÉllo").await.as_deref(), Some("\"hÉllo\"")); - assert_eq!(test_path("1337").await.as_deref(), Some("\"1337\"")); - assert_eq!(test_path("_HELLO").await.as_deref(), Some("\"_HELLO\"")); - assert_eq!(test_path("HELLO").await.as_deref(), Some("\"HELLO\"")); - assert_eq!(test_path("HELLO$").await.as_deref(), Some("\"HELLO$\"")); - assert_eq!(test_path("ÀBRACADABRA").await.as_deref(), Some("\"ÀBRACADABRA\"")); - - for ident in RESERVED_KEYWORDS { - assert_eq!(test_path(ident).await.as_deref(), Some(format!("\"{ident}\"").as_str())); - } - - for ident in RESERVED_TYPE_FUNCTION_NAMES { - assert_eq!(test_path(ident).await.as_deref(), Some(format!("\"{ident}\"").as_str())); - } - } - - #[tokio::test] - async fn test_custom_search_path_crdb() { - async fn test_path(schema_name: &str) -> Option { - let mut url = Url::parse(&CRDB_CONN_STR).unwrap(); - url.query_pairs_mut().append_pair("schema", schema_name); - - let mut pg_url = PostgresUrl::new(url).unwrap(); - pg_url.set_flavour(PostgresFlavour::Cockroach); - - let client = PostgreSql::new(pg_url).await.unwrap(); - - let result_set = client.query_raw("SHOW search_path", &[]).await.unwrap(); - let row = result_set.first().unwrap(); - - row[0].typed.to_string() - } - - // Safe - assert_eq!(test_path("hello").await.as_deref(), Some("hello")); - assert_eq!(test_path("_hello").await.as_deref(), Some("_hello")); - assert_eq!(test_path("àbracadabra").await.as_deref(), Some("àbracadabra")); - assert_eq!(test_path("h3ll0").await.as_deref(), Some("h3ll0")); - assert_eq!(test_path("héllo").await.as_deref(), Some("héllo")); - assert_eq!(test_path("héll0$").await.as_deref(), Some("héll0$")); - assert_eq!(test_path("héll_0$").await.as_deref(), Some("héll_0$")); - - // Not safe - assert_eq!(test_path("Hello").await.as_deref(), Some("\"Hello\"")); - assert_eq!(test_path("hEllo").await.as_deref(), Some("\"hEllo\"")); - assert_eq!(test_path("$hello").await.as_deref(), Some("\"$hello\"")); - assert_eq!(test_path("hello!").await.as_deref(), Some("\"hello!\"")); - assert_eq!(test_path("hello#").await.as_deref(), Some("\"hello#\"")); - assert_eq!(test_path("he llo").await.as_deref(), Some("\"he llo\"")); - assert_eq!(test_path(" hello").await.as_deref(), Some("\" hello\"")); - assert_eq!(test_path("he-llo").await.as_deref(), Some("\"he-llo\"")); - assert_eq!(test_path("hÉllo").await.as_deref(), Some("\"hÉllo\"")); - assert_eq!(test_path("1337").await.as_deref(), Some("\"1337\"")); - assert_eq!(test_path("_HELLO").await.as_deref(), Some("\"_HELLO\"")); - assert_eq!(test_path("HELLO").await.as_deref(), Some("\"HELLO\"")); - assert_eq!(test_path("HELLO$").await.as_deref(), Some("\"HELLO$\"")); - assert_eq!(test_path("ÀBRACADABRA").await.as_deref(), Some("\"ÀBRACADABRA\"")); - - for ident in RESERVED_KEYWORDS { - assert_eq!(test_path(ident).await.as_deref(), Some(format!("\"{ident}\"").as_str())); - } - - for ident in RESERVED_TYPE_FUNCTION_NAMES { - assert_eq!(test_path(ident).await.as_deref(), Some(format!("\"{ident}\"").as_str())); - } - } - - #[tokio::test] - async fn test_custom_search_path_unknown_pg() { - async fn test_path(schema_name: &str) -> Option { - let mut url = Url::parse(&CONN_STR).unwrap(); - url.query_pairs_mut().append_pair("schema", schema_name); - - let mut pg_url = PostgresUrl::new(url).unwrap(); - pg_url.set_flavour(PostgresFlavour::Unknown); - - let client = PostgreSql::new(pg_url).await.unwrap(); - - let result_set = client.query_raw("SHOW search_path", &[]).await.unwrap(); - let row = result_set.first().unwrap(); - - row[0].typed.to_string() - } - - // Safe - assert_eq!(test_path("hello").await.as_deref(), Some("hello")); - assert_eq!(test_path("_hello").await.as_deref(), Some("_hello")); - assert_eq!(test_path("àbracadabra").await.as_deref(), Some("\"àbracadabra\"")); - assert_eq!(test_path("h3ll0").await.as_deref(), Some("h3ll0")); - assert_eq!(test_path("héllo").await.as_deref(), Some("\"héllo\"")); - assert_eq!(test_path("héll0$").await.as_deref(), Some("\"héll0$\"")); - assert_eq!(test_path("héll_0$").await.as_deref(), Some("\"héll_0$\"")); - - // Not safe - assert_eq!(test_path("Hello").await.as_deref(), Some("\"Hello\"")); - assert_eq!(test_path("hEllo").await.as_deref(), Some("\"hEllo\"")); - assert_eq!(test_path("$hello").await.as_deref(), Some("\"$hello\"")); - assert_eq!(test_path("hello!").await.as_deref(), Some("\"hello!\"")); - assert_eq!(test_path("hello#").await.as_deref(), Some("\"hello#\"")); - assert_eq!(test_path("he llo").await.as_deref(), Some("\"he llo\"")); - assert_eq!(test_path(" hello").await.as_deref(), Some("\" hello\"")); - assert_eq!(test_path("he-llo").await.as_deref(), Some("\"he-llo\"")); - assert_eq!(test_path("hÉllo").await.as_deref(), Some("\"hÉllo\"")); - assert_eq!(test_path("1337").await.as_deref(), Some("\"1337\"")); - assert_eq!(test_path("_HELLO").await.as_deref(), Some("\"_HELLO\"")); - assert_eq!(test_path("HELLO").await.as_deref(), Some("\"HELLO\"")); - assert_eq!(test_path("HELLO$").await.as_deref(), Some("\"HELLO$\"")); - assert_eq!(test_path("ÀBRACADABRA").await.as_deref(), Some("\"ÀBRACADABRA\"")); - - for ident in RESERVED_KEYWORDS { - assert_eq!(test_path(ident).await.as_deref(), Some(format!("\"{ident}\"").as_str())); - } - - for ident in RESERVED_TYPE_FUNCTION_NAMES { - assert_eq!(test_path(ident).await.as_deref(), Some(format!("\"{ident}\"").as_str())); - } - } - - #[tokio::test] - async fn test_custom_search_path_unknown_crdb() { - async fn test_path(schema_name: &str) -> Option { - let mut url = Url::parse(&CONN_STR).unwrap(); - url.query_pairs_mut().append_pair("schema", schema_name); - - let mut pg_url = PostgresUrl::new(url).unwrap(); - pg_url.set_flavour(PostgresFlavour::Unknown); - - let client = PostgreSql::new(pg_url).await.unwrap(); - - let result_set = client.query_raw("SHOW search_path", &[]).await.unwrap(); - let row = result_set.first().unwrap(); - - row[0].typed.to_string() - } - - // Safe - assert_eq!(test_path("hello").await.as_deref(), Some("hello")); - assert_eq!(test_path("_hello").await.as_deref(), Some("_hello")); - assert_eq!(test_path("àbracadabra").await.as_deref(), Some("\"àbracadabra\"")); - assert_eq!(test_path("h3ll0").await.as_deref(), Some("h3ll0")); - assert_eq!(test_path("héllo").await.as_deref(), Some("\"héllo\"")); - assert_eq!(test_path("héll0$").await.as_deref(), Some("\"héll0$\"")); - assert_eq!(test_path("héll_0$").await.as_deref(), Some("\"héll_0$\"")); - - // Not safe - assert_eq!(test_path("Hello").await.as_deref(), Some("\"Hello\"")); - assert_eq!(test_path("hEllo").await.as_deref(), Some("\"hEllo\"")); - assert_eq!(test_path("$hello").await.as_deref(), Some("\"$hello\"")); - assert_eq!(test_path("hello!").await.as_deref(), Some("\"hello!\"")); - assert_eq!(test_path("hello#").await.as_deref(), Some("\"hello#\"")); - assert_eq!(test_path("he llo").await.as_deref(), Some("\"he llo\"")); - assert_eq!(test_path(" hello").await.as_deref(), Some("\" hello\"")); - assert_eq!(test_path("he-llo").await.as_deref(), Some("\"he-llo\"")); - assert_eq!(test_path("hÉllo").await.as_deref(), Some("\"hÉllo\"")); - assert_eq!(test_path("1337").await.as_deref(), Some("\"1337\"")); - assert_eq!(test_path("_HELLO").await.as_deref(), Some("\"_HELLO\"")); - assert_eq!(test_path("HELLO").await.as_deref(), Some("\"HELLO\"")); - assert_eq!(test_path("HELLO$").await.as_deref(), Some("\"HELLO$\"")); - assert_eq!(test_path("ÀBRACADABRA").await.as_deref(), Some("\"ÀBRACADABRA\"")); - - for ident in RESERVED_KEYWORDS { - assert_eq!(test_path(ident).await.as_deref(), Some(format!("\"{ident}\"").as_str())); - } - - for ident in RESERVED_TYPE_FUNCTION_NAMES { - assert_eq!(test_path(ident).await.as_deref(), Some(format!("\"{ident}\"").as_str())); - } - } - - #[tokio::test] - async fn should_map_nonexisting_database_error() { - let mut url = Url::parse(&CONN_STR).unwrap(); - url.set_path("/this_does_not_exist"); - - let res = Quaint::new(url.as_str()).await; - - assert!(res.is_err()); - - match res { - Ok(_) => unreachable!(), - Err(e) => match e.kind() { - ErrorKind::DatabaseDoesNotExist { db_name } => { - assert_eq!(Some("3D000"), e.original_code()); - assert_eq!( - Some("database \"this_does_not_exist\" does not exist"), - e.original_message() - ); - assert_eq!(&Name::available("this_does_not_exist"), db_name) - } - kind => panic!("Expected `DatabaseDoesNotExist`, got {:?}", kind), - }, - } - } - - #[tokio::test] - async fn should_map_wrong_credentials_error() { - let mut url = Url::parse(&CONN_STR).unwrap(); - url.set_username("WRONG").unwrap(); - - let res = Quaint::new(url.as_str()).await; - assert!(res.is_err()); - - let err = res.unwrap_err(); - assert!(matches!(err.kind(), ErrorKind::AuthenticationFailed { user } if user == &Name::available("WRONG"))); - } - - #[tokio::test] - async fn should_map_tls_errors() { - let mut url = Url::parse(&CONN_STR).expect("parsing url"); - url.set_query(Some("sslmode=require&sslaccept=strict")); - - let res = Quaint::new(url.as_str()).await; - - assert!(res.is_err()); - - match res { - Ok(_) => unreachable!(), - Err(e) => match e.kind() { - ErrorKind::TlsError { .. } => (), - other => panic!("{:#?}", other), - }, - } - } - - #[tokio::test] - async fn should_map_incorrect_parameters_error() { - let url = Url::parse(&CONN_STR).unwrap(); - let conn = Quaint::new(url.as_str()).await.unwrap(); - - let res = conn.query_raw("SELECT $1", &[Value::int32(1), Value::int32(2)]).await; - - assert!(res.is_err()); - - match res { - Ok(_) => unreachable!(), - Err(e) => match e.kind() { - ErrorKind::IncorrectNumberOfParameters { expected, actual } => { - assert_eq!(1, *expected); - assert_eq!(2, *actual); - } - other => panic!("{:#?}", other), - }, - } - } - - #[test] - fn test_safe_ident() { - // Safe - assert!(is_safe_identifier("hello")); - assert!(is_safe_identifier("_hello")); - assert!(is_safe_identifier("àbracadabra")); - assert!(is_safe_identifier("h3ll0")); - assert!(is_safe_identifier("héllo")); - assert!(is_safe_identifier("héll0$")); - assert!(is_safe_identifier("héll_0$")); - assert!(is_safe_identifier("disconnect_security_must_honor_connect_scope_one2m")); - - // Not safe - assert!(!is_safe_identifier("")); - assert!(!is_safe_identifier("Hello")); - assert!(!is_safe_identifier("hEllo")); - assert!(!is_safe_identifier("$hello")); - assert!(!is_safe_identifier("hello!")); - assert!(!is_safe_identifier("hello#")); - assert!(!is_safe_identifier("he llo")); - assert!(!is_safe_identifier(" hello")); - assert!(!is_safe_identifier("he-llo")); - assert!(!is_safe_identifier("hÉllo")); - assert!(!is_safe_identifier("1337")); - assert!(!is_safe_identifier("_HELLO")); - assert!(!is_safe_identifier("HELLO")); - assert!(!is_safe_identifier("HELLO$")); - assert!(!is_safe_identifier("ÀBRACADABRA")); - - for ident in RESERVED_KEYWORDS { - assert!(!is_safe_identifier(ident)); - } - - for ident in RESERVED_TYPE_FUNCTION_NAMES { - assert!(!is_safe_identifier(ident)); - } - } - - #[test] - fn search_path_pgbouncer_should_be_set_with_query() { - let mut url = Url::parse(&CONN_STR).unwrap(); - url.query_pairs_mut().append_pair("schema", "hello"); - url.query_pairs_mut().append_pair("pgbouncer", "true"); - - let mut pg_url = PostgresUrl::new(url).unwrap(); - pg_url.set_flavour(PostgresFlavour::Postgres); - - let config = pg_url.to_config(); - - // PGBouncer does not support the `search_path` connection parameter. - // When `pgbouncer=true`, config.search_path should be None, - // And the `search_path` should be set via a db query after connection. - assert_eq!(config.get_search_path(), None); - } - - #[test] - fn search_path_pg_should_be_set_with_param() { - let mut url = Url::parse(&CONN_STR).unwrap(); - url.query_pairs_mut().append_pair("schema", "hello"); - - let mut pg_url = PostgresUrl::new(url).unwrap(); - pg_url.set_flavour(PostgresFlavour::Postgres); - - let config = pg_url.to_config(); - - // Postgres supports setting the search_path via a connection parameter. - assert_eq!(config.get_search_path(), Some(&"\"hello\"".to_owned())); - } - - #[test] - fn search_path_crdb_safe_ident_should_be_set_with_param() { - let mut url = Url::parse(&CONN_STR).unwrap(); - url.query_pairs_mut().append_pair("schema", "hello"); - - let mut pg_url = PostgresUrl::new(url).unwrap(); - pg_url.set_flavour(PostgresFlavour::Cockroach); - - let config = pg_url.to_config(); - - // CRDB supports setting the search_path via a connection parameter if the identifier is safe. - assert_eq!(config.get_search_path(), Some(&"hello".to_owned())); - } - - #[test] - fn search_path_crdb_unsafe_ident_should_be_set_with_query() { - let mut url = Url::parse(&CONN_STR).unwrap(); - url.query_pairs_mut().append_pair("schema", "HeLLo"); - - let mut pg_url = PostgresUrl::new(url).unwrap(); - pg_url.set_flavour(PostgresFlavour::Cockroach); - - let config = pg_url.to_config(); - - // CRDB does NOT support setting the search_path via a connection parameter if the identifier is unsafe. - assert_eq!(config.get_search_path(), None); - } -} +#[cfg(feature = "postgresql-native")] +pub(crate) mod native; diff --git a/quaint/src/connector/postgres/error.rs b/quaint/src/connector/postgres/error.rs index d4e5ec7837fe..ab6ec7b07847 100644 --- a/quaint/src/connector/postgres/error.rs +++ b/quaint/src/connector/postgres/error.rs @@ -1,7 +1,5 @@ use std::fmt::{Display, Formatter}; -use tokio_postgres::error::DbError; - use crate::error::{DatabaseConstraint, Error, ErrorKind, Name}; #[derive(Debug)] @@ -17,7 +15,7 @@ pub struct PostgresError { impl std::error::Error for PostgresError {} impl Display for PostgresError { - // copy of DbError::fmt + // copy of tokio_postgres::error::DbError::fmt fn fmt(&self, fmt: &mut Formatter<'_>) -> std::fmt::Result { write!(fmt, "{}: {}", self.severity, self.message)?; if let Some(detail) = &self.detail { @@ -30,19 +28,6 @@ impl Display for PostgresError { } } -impl From<&DbError> for PostgresError { - fn from(value: &DbError) -> Self { - PostgresError { - code: value.code().code().to_string(), - severity: value.severity().to_string(), - message: value.message().to_string(), - detail: value.detail().map(ToString::to_string), - column: value.column().map(ToString::to_string), - hint: value.hint().map(ToString::to_string), - } - } -} - impl From for Error { fn from(value: PostgresError) -> Self { match value.code.as_str() { @@ -245,110 +230,3 @@ impl From for Error { } } } - -impl From for Error { - fn from(e: tokio_postgres::error::Error) -> Error { - if e.is_closed() { - return Error::builder(ErrorKind::ConnectionClosed).build(); - } - - if let Some(db_error) = e.as_db_error() { - return PostgresError::from(db_error).into(); - } - - if let Some(tls_error) = try_extracting_tls_error(&e) { - return tls_error; - } - - // Same for IO errors. - if let Some(io_error) = try_extracting_io_error(&e) { - return io_error; - } - - if let Some(uuid_error) = try_extracting_uuid_error(&e) { - return uuid_error; - } - - let reason = format!("{e}"); - let code = e.code().map(|c| c.code()); - - match reason.as_str() { - "error connecting to server: timed out" => { - let mut builder = Error::builder(ErrorKind::ConnectTimeout); - - if let Some(code) = code { - builder.set_original_code(code); - }; - - builder.set_original_message(reason); - builder.build() - } // sigh... - // https://github.com/sfackler/rust-postgres/blob/0c84ed9f8201f4e5b4803199a24afa2c9f3723b2/tokio-postgres/src/connect_tls.rs#L37 - "error performing TLS handshake: server does not support TLS" => { - let mut builder = Error::builder(ErrorKind::TlsError { - message: reason.clone(), - }); - - if let Some(code) = code { - builder.set_original_code(code); - }; - - builder.set_original_message(reason); - builder.build() - } // double sigh - _ => { - let code = code.map(|c| c.to_string()); - let mut builder = Error::builder(ErrorKind::QueryError(e.into())); - - if let Some(code) = code { - builder.set_original_code(code); - }; - - builder.set_original_message(reason); - builder.build() - } - } - } -} - -fn try_extracting_uuid_error(err: &tokio_postgres::error::Error) -> Option { - use std::error::Error as _; - - err.source() - .and_then(|err| err.downcast_ref::()) - .map(|err| ErrorKind::UUIDError(format!("{err}"))) - .map(|kind| Error::builder(kind).build()) -} - -fn try_extracting_tls_error(err: &tokio_postgres::error::Error) -> Option { - use std::error::Error; - - err.source() - .and_then(|err| err.downcast_ref::()) - .map(|err| err.into()) -} - -fn try_extracting_io_error(err: &tokio_postgres::error::Error) -> Option { - use std::error::Error as _; - - err.source() - .and_then(|err| err.downcast_ref::()) - .map(|err| ErrorKind::ConnectionError(Box::new(std::io::Error::new(err.kind(), format!("{err}"))))) - .map(|kind| Error::builder(kind).build()) -} - -impl From for Error { - fn from(e: native_tls::Error) -> Error { - Error::from(&e) - } -} - -impl From<&native_tls::Error> for Error { - fn from(e: &native_tls::Error) -> Error { - let kind = ErrorKind::TlsError { - message: format!("{e}"), - }; - - Error::builder(kind).build() - } -} diff --git a/quaint/src/connector/postgres/conversion.rs b/quaint/src/connector/postgres/native/conversion.rs similarity index 100% rename from quaint/src/connector/postgres/conversion.rs rename to quaint/src/connector/postgres/native/conversion.rs diff --git a/quaint/src/connector/postgres/conversion/decimal.rs b/quaint/src/connector/postgres/native/conversion/decimal.rs similarity index 100% rename from quaint/src/connector/postgres/conversion/decimal.rs rename to quaint/src/connector/postgres/native/conversion/decimal.rs diff --git a/quaint/src/connector/postgres/native/error.rs b/quaint/src/connector/postgres/native/error.rs new file mode 100644 index 000000000000..c353e397705c --- /dev/null +++ b/quaint/src/connector/postgres/native/error.rs @@ -0,0 +1,126 @@ +use tokio_postgres::error::DbError; + +use crate::{ + connector::postgres::error::PostgresError, + error::{Error, ErrorKind}, +}; + +impl From<&DbError> for PostgresError { + fn from(value: &DbError) -> Self { + PostgresError { + code: value.code().code().to_string(), + severity: value.severity().to_string(), + message: value.message().to_string(), + detail: value.detail().map(ToString::to_string), + column: value.column().map(ToString::to_string), + hint: value.hint().map(ToString::to_string), + } + } +} + +impl From for Error { + fn from(e: tokio_postgres::error::Error) -> Error { + if e.is_closed() { + return Error::builder(ErrorKind::ConnectionClosed).build(); + } + + if let Some(db_error) = e.as_db_error() { + return PostgresError::from(db_error).into(); + } + + if let Some(tls_error) = try_extracting_tls_error(&e) { + return tls_error; + } + + // Same for IO errors. + if let Some(io_error) = try_extracting_io_error(&e) { + return io_error; + } + + if let Some(uuid_error) = try_extracting_uuid_error(&e) { + return uuid_error; + } + + let reason = format!("{e}"); + let code = e.code().map(|c| c.code()); + + match reason.as_str() { + "error connecting to server: timed out" => { + let mut builder = Error::builder(ErrorKind::ConnectTimeout); + + if let Some(code) = code { + builder.set_original_code(code); + }; + + builder.set_original_message(reason); + builder.build() + } // sigh... + // https://github.com/sfackler/rust-postgres/blob/0c84ed9f8201f4e5b4803199a24afa2c9f3723b2/tokio-postgres/src/connect_tls.rs#L37 + "error performing TLS handshake: server does not support TLS" => { + let mut builder = Error::builder(ErrorKind::TlsError { + message: reason.clone(), + }); + + if let Some(code) = code { + builder.set_original_code(code); + }; + + builder.set_original_message(reason); + builder.build() + } // double sigh + _ => { + let code = code.map(|c| c.to_string()); + let mut builder = Error::builder(ErrorKind::QueryError(e.into())); + + if let Some(code) = code { + builder.set_original_code(code); + }; + + builder.set_original_message(reason); + builder.build() + } + } + } +} + +fn try_extracting_uuid_error(err: &tokio_postgres::error::Error) -> Option { + use std::error::Error as _; + + err.source() + .and_then(|err| err.downcast_ref::()) + .map(|err| ErrorKind::UUIDError(format!("{err}"))) + .map(|kind| Error::builder(kind).build()) +} + +fn try_extracting_tls_error(err: &tokio_postgres::error::Error) -> Option { + use std::error::Error; + + err.source() + .and_then(|err| err.downcast_ref::()) + .map(|err| err.into()) +} + +fn try_extracting_io_error(err: &tokio_postgres::error::Error) -> Option { + use std::error::Error as _; + + err.source() + .and_then(|err| err.downcast_ref::()) + .map(|err| ErrorKind::ConnectionError(Box::new(std::io::Error::new(err.kind(), format!("{err}"))))) + .map(|kind| Error::builder(kind).build()) +} + +impl From for Error { + fn from(e: native_tls::Error) -> Error { + Error::from(&e) + } +} + +impl From<&native_tls::Error> for Error { + fn from(e: &native_tls::Error) -> Error { + let kind = ErrorKind::TlsError { + message: format!("{e}"), + }; + + Error::builder(kind).build() + } +} diff --git a/quaint/src/connector/postgres/native/mod.rs b/quaint/src/connector/postgres/native/mod.rs new file mode 100644 index 000000000000..30f34e7002be --- /dev/null +++ b/quaint/src/connector/postgres/native/mod.rs @@ -0,0 +1,972 @@ +//! Definitions for the Postgres connector. +//! This module is not compatible with wasm32-* targets. +//! This module is only available with the `postgresql-native` feature. +mod conversion; +mod error; + +pub(crate) use crate::connector::postgres::url::PostgresUrl; +use crate::connector::postgres::url::{Hidden, SslAcceptMode, SslParams}; +use crate::connector::{timeout, IsolationLevel, Transaction}; + +use crate::{ + ast::{Query, Value}, + connector::{metrics, queryable::*, ResultSet}, + error::{Error, ErrorKind}, + visitor::{self, Visitor}, +}; +use async_trait::async_trait; +use futures::{future::FutureExt, lock::Mutex}; +use lru_cache::LruCache; +use native_tls::{Certificate, Identity, TlsConnector}; +use postgres_native_tls::MakeTlsConnector; +use std::{ + borrow::Borrow, + fmt::{Debug, Display}, + fs, + future::Future, + sync::atomic::{AtomicBool, Ordering}, + time::Duration, +}; +use tokio_postgres::{config::ChannelBinding, Client, Config, Statement}; + +/// The underlying postgres driver. Only available with the `expose-drivers` +/// Cargo feature. +#[cfg(feature = "expose-drivers")] +pub use tokio_postgres; + +struct PostgresClient(Client); + +impl Debug for PostgresClient { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.write_str("PostgresClient") + } +} + +/// A connector interface for the PostgreSQL database. +#[derive(Debug)] +pub struct PostgreSql { + client: PostgresClient, + pg_bouncer: bool, + socket_timeout: Option, + statement_cache: Mutex>, + is_healthy: AtomicBool, +} + +#[derive(Debug)] +struct SslAuth { + certificate: Hidden>, + identity: Hidden>, + ssl_accept_mode: SslAcceptMode, +} + +impl Default for SslAuth { + fn default() -> Self { + Self { + certificate: Hidden(None), + identity: Hidden(None), + ssl_accept_mode: SslAcceptMode::AcceptInvalidCerts, + } + } +} + +impl SslAuth { + fn certificate(&mut self, certificate: Certificate) -> &mut Self { + self.certificate = Hidden(Some(certificate)); + self + } + + fn identity(&mut self, identity: Identity) -> &mut Self { + self.identity = Hidden(Some(identity)); + self + } + + fn accept_mode(&mut self, mode: SslAcceptMode) -> &mut Self { + self.ssl_accept_mode = mode; + self + } +} + +impl SslParams { + async fn into_auth(self) -> crate::Result { + let mut auth = SslAuth::default(); + auth.accept_mode(self.ssl_accept_mode); + + if let Some(ref cert_file) = self.certificate_file { + let cert = fs::read(cert_file).map_err(|err| { + Error::builder(ErrorKind::TlsError { + message: format!("cert file not found ({err})"), + }) + .build() + })?; + + auth.certificate(Certificate::from_pem(&cert)?); + } + + if let Some(ref identity_file) = self.identity_file { + let db = fs::read(identity_file).map_err(|err| { + Error::builder(ErrorKind::TlsError { + message: format!("identity file not found ({err})"), + }) + .build() + })?; + let password = self.identity_password.0.as_deref().unwrap_or(""); + let identity = Identity::from_pkcs12(&db, password)?; + + auth.identity(identity); + } + + Ok(auth) + } +} + +impl PostgresUrl { + pub(crate) fn cache(&self) -> LruCache { + if self.query_params.pg_bouncer { + LruCache::new(0) + } else { + LruCache::new(self.query_params.statement_cache_size) + } + } + + pub fn channel_binding(&self) -> ChannelBinding { + self.query_params.channel_binding + } + + /// On Postgres, we set the SEARCH_PATH and client-encoding through client connection parameters to save a network roundtrip on connection. + /// We can't always do it for CockroachDB because it does not expect quotes for unsafe identifiers (https://github.com/cockroachdb/cockroach/issues/101328), which might change once the issue is fixed. + /// To circumvent that problem, we only set the SEARCH_PATH through client connection parameters for Cockroach when the identifier is safe, so that the quoting does not matter. + fn set_search_path(&self, config: &mut Config) { + // PGBouncer does not support the search_path connection parameter. + // https://www.pgbouncer.org/config.html#ignore_startup_parameters + if self.query_params.pg_bouncer { + return; + } + + if let Some(schema) = &self.query_params.schema { + if self.flavour().is_cockroach() && is_safe_identifier(schema) { + config.search_path(CockroachSearchPath(schema).to_string()); + } + + if self.flavour().is_postgres() { + config.search_path(PostgresSearchPath(schema).to_string()); + } + } + } + + pub(crate) fn to_config(&self) -> Config { + let mut config = Config::new(); + + config.user(self.username().borrow()); + config.password(self.password().borrow() as &str); + config.host(self.host()); + config.port(self.port()); + config.dbname(self.dbname()); + config.pgbouncer_mode(self.query_params.pg_bouncer); + + if let Some(options) = self.options() { + config.options(options); + } + + if let Some(application_name) = self.application_name() { + config.application_name(application_name); + } + + if let Some(connect_timeout) = self.query_params.connect_timeout { + config.connect_timeout(connect_timeout); + } + + self.set_search_path(&mut config); + + config.ssl_mode(self.query_params.ssl_mode); + + config.channel_binding(self.query_params.channel_binding); + + config + } +} + +impl PostgreSql { + /// Create a new connection to the database. + pub async fn new(url: PostgresUrl) -> crate::Result { + let config = url.to_config(); + + let mut tls_builder = TlsConnector::builder(); + + { + let ssl_params = url.ssl_params(); + let auth = ssl_params.to_owned().into_auth().await?; + + if let Some(certificate) = auth.certificate.0 { + tls_builder.add_root_certificate(certificate); + } + + tls_builder.danger_accept_invalid_certs(auth.ssl_accept_mode == SslAcceptMode::AcceptInvalidCerts); + + if let Some(identity) = auth.identity.0 { + tls_builder.identity(identity); + } + } + + let tls = MakeTlsConnector::new(tls_builder.build()?); + let (client, conn) = timeout::connect(url.connect_timeout(), config.connect(tls)).await?; + + tokio::spawn(conn.map(|r| match r { + Ok(_) => (), + Err(e) => { + tracing::error!("Error in PostgreSQL connection: {:?}", e); + } + })); + + // On Postgres, we set the SEARCH_PATH and client-encoding through client connection parameters to save a network roundtrip on connection. + // We can't always do it for CockroachDB because it does not expect quotes for unsafe identifiers (https://github.com/cockroachdb/cockroach/issues/101328), which might change once the issue is fixed. + // To circumvent that problem, we only set the SEARCH_PATH through client connection parameters for Cockroach when the identifier is safe, so that the quoting does not matter. + // Finally, to ensure backward compatibility, we keep sending a database query in case the flavour is set to Unknown. + if let Some(schema) = &url.query_params.schema { + // PGBouncer does not support the search_path connection parameter. + // https://www.pgbouncer.org/config.html#ignore_startup_parameters + if url.query_params.pg_bouncer + || url.flavour().is_unknown() + || (url.flavour().is_cockroach() && !is_safe_identifier(schema)) + { + let session_variables = format!( + r##"{set_search_path}"##, + set_search_path = SetSearchPath(url.query_params.schema.as_deref()) + ); + + client.simple_query(session_variables.as_str()).await?; + } + } + + Ok(Self { + client: PostgresClient(client), + socket_timeout: url.query_params.socket_timeout, + pg_bouncer: url.query_params.pg_bouncer, + statement_cache: Mutex::new(url.cache()), + is_healthy: AtomicBool::new(true), + }) + } + + /// The underlying tokio_postgres::Client. Only available with the + /// `expose-drivers` Cargo feature. This is a lower level API when you need + /// to get into database specific features. + #[cfg(feature = "expose-drivers")] + pub fn client(&self) -> &tokio_postgres::Client { + &self.client.0 + } + + async fn fetch_cached(&self, sql: &str, params: &[Value<'_>]) -> crate::Result { + let mut cache = self.statement_cache.lock().await; + let capacity = cache.capacity(); + let stored = cache.len(); + + match cache.get_mut(sql) { + Some(stmt) => { + tracing::trace!( + message = "CACHE HIT!", + query = sql, + capacity = capacity, + stored = stored, + ); + + Ok(stmt.clone()) // arc'd + } + None => { + tracing::trace!( + message = "CACHE MISS!", + query = sql, + capacity = capacity, + stored = stored, + ); + + let param_types = conversion::params_to_types(params); + let stmt = self.perform_io(self.client.0.prepare_typed(sql, ¶m_types)).await?; + + cache.insert(sql.to_string(), stmt.clone()); + + Ok(stmt) + } + } + } + + async fn perform_io(&self, fut: F) -> crate::Result + where + F: Future>, + { + match timeout::socket(self.socket_timeout, fut).await { + Err(e) if e.is_closed() => { + self.is_healthy.store(false, Ordering::SeqCst); + Err(e) + } + res => res, + } + } + + fn check_bind_variables_len(&self, params: &[Value<'_>]) -> crate::Result<()> { + if params.len() > i16::MAX as usize { + // tokio_postgres would return an error here. Let's avoid calling the driver + // and return an error early. + let kind = ErrorKind::QueryInvalidInput(format!( + "too many bind variables in prepared statement, expected maximum of {}, received {}", + i16::MAX, + params.len() + )); + Err(Error::builder(kind).build()) + } else { + Ok(()) + } + } +} + +// A SearchPath connection parameter (Display-impl) for connection initialization. +struct CockroachSearchPath<'a>(&'a str); + +impl Display for CockroachSearchPath<'_> { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.write_str(self.0) + } +} + +// A SearchPath connection parameter (Display-impl) for connection initialization. +struct PostgresSearchPath<'a>(&'a str); + +impl Display for PostgresSearchPath<'_> { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.write_str("\"")?; + f.write_str(self.0)?; + f.write_str("\"")?; + + Ok(()) + } +} + +// A SetSearchPath statement (Display-impl) for connection initialization. +struct SetSearchPath<'a>(Option<&'a str>); + +impl Display for SetSearchPath<'_> { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + if let Some(schema) = self.0 { + f.write_str("SET search_path = \"")?; + f.write_str(schema)?; + f.write_str("\";\n")?; + } + + Ok(()) + } +} + +impl_default_TransactionCapable!(PostgreSql); + +#[async_trait] +impl Queryable for PostgreSql { + async fn query(&self, q: Query<'_>) -> crate::Result { + let (sql, params) = visitor::Postgres::build(q)?; + + self.query_raw(sql.as_str(), ¶ms[..]).await + } + + async fn query_raw(&self, sql: &str, params: &[Value<'_>]) -> crate::Result { + self.check_bind_variables_len(params)?; + + metrics::query("postgres.query_raw", sql, params, move || async move { + let stmt = self.fetch_cached(sql, &[]).await?; + + if stmt.params().len() != params.len() { + let kind = ErrorKind::IncorrectNumberOfParameters { + expected: stmt.params().len(), + actual: params.len(), + }; + + return Err(Error::builder(kind).build()); + } + + let rows = self + .perform_io(self.client.0.query(&stmt, conversion::conv_params(params).as_slice())) + .await?; + + let mut result = ResultSet::new(stmt.to_column_names(), Vec::new()); + + for row in rows { + result.rows.push(row.get_result_row()?); + } + + Ok(result) + }) + .await + } + + async fn query_raw_typed(&self, sql: &str, params: &[Value<'_>]) -> crate::Result { + self.check_bind_variables_len(params)?; + + metrics::query("postgres.query_raw", sql, params, move || async move { + let stmt = self.fetch_cached(sql, params).await?; + + if stmt.params().len() != params.len() { + let kind = ErrorKind::IncorrectNumberOfParameters { + expected: stmt.params().len(), + actual: params.len(), + }; + + return Err(Error::builder(kind).build()); + } + + let rows = self + .perform_io(self.client.0.query(&stmt, conversion::conv_params(params).as_slice())) + .await?; + + let mut result = ResultSet::new(stmt.to_column_names(), Vec::new()); + + for row in rows { + result.rows.push(row.get_result_row()?); + } + + Ok(result) + }) + .await + } + + async fn execute(&self, q: Query<'_>) -> crate::Result { + let (sql, params) = visitor::Postgres::build(q)?; + + self.execute_raw(sql.as_str(), ¶ms[..]).await + } + + async fn execute_raw(&self, sql: &str, params: &[Value<'_>]) -> crate::Result { + self.check_bind_variables_len(params)?; + + metrics::query("postgres.execute_raw", sql, params, move || async move { + let stmt = self.fetch_cached(sql, &[]).await?; + + if stmt.params().len() != params.len() { + let kind = ErrorKind::IncorrectNumberOfParameters { + expected: stmt.params().len(), + actual: params.len(), + }; + + return Err(Error::builder(kind).build()); + } + + let changes = self + .perform_io(self.client.0.execute(&stmt, conversion::conv_params(params).as_slice())) + .await?; + + Ok(changes) + }) + .await + } + + async fn execute_raw_typed(&self, sql: &str, params: &[Value<'_>]) -> crate::Result { + self.check_bind_variables_len(params)?; + + metrics::query("postgres.execute_raw", sql, params, move || async move { + let stmt = self.fetch_cached(sql, params).await?; + + if stmt.params().len() != params.len() { + let kind = ErrorKind::IncorrectNumberOfParameters { + expected: stmt.params().len(), + actual: params.len(), + }; + + return Err(Error::builder(kind).build()); + } + + let changes = self + .perform_io(self.client.0.execute(&stmt, conversion::conv_params(params).as_slice())) + .await?; + + Ok(changes) + }) + .await + } + + async fn raw_cmd(&self, cmd: &str) -> crate::Result<()> { + metrics::query("postgres.raw_cmd", cmd, &[], move || async move { + self.perform_io(self.client.0.simple_query(cmd)).await?; + Ok(()) + }) + .await + } + + async fn version(&self) -> crate::Result> { + let query = r#"SELECT version()"#; + let rows = self.query_raw(query, &[]).await?; + + let version_string = rows + .get(0) + .and_then(|row| row.get("version").and_then(|version| version.to_string())); + + Ok(version_string) + } + + fn is_healthy(&self) -> bool { + self.is_healthy.load(Ordering::SeqCst) + } + + async fn server_reset_query(&self, tx: &dyn Transaction) -> crate::Result<()> { + if self.pg_bouncer { + tx.raw_cmd("DEALLOCATE ALL").await + } else { + Ok(()) + } + } + + async fn set_tx_isolation_level(&self, isolation_level: IsolationLevel) -> crate::Result<()> { + if matches!(isolation_level, IsolationLevel::Snapshot) { + return Err(Error::builder(ErrorKind::invalid_isolation_level(&isolation_level)).build()); + } + + self.raw_cmd(&format!("SET TRANSACTION ISOLATION LEVEL {isolation_level}")) + .await?; + + Ok(()) + } + + fn requires_isolation_first(&self) -> bool { + false + } +} + +/// Sorted list of CockroachDB's reserved keywords. +/// Taken from https://www.cockroachlabs.com/docs/stable/keywords-and-identifiers.html#keywords +const RESERVED_KEYWORDS: [&str; 79] = [ + "all", + "analyse", + "analyze", + "and", + "any", + "array", + "as", + "asc", + "asymmetric", + "both", + "case", + "cast", + "check", + "collate", + "column", + "concurrently", + "constraint", + "create", + "current_catalog", + "current_date", + "current_role", + "current_schema", + "current_time", + "current_timestamp", + "current_user", + "default", + "deferrable", + "desc", + "distinct", + "do", + "else", + "end", + "except", + "false", + "fetch", + "for", + "foreign", + "from", + "grant", + "group", + "having", + "in", + "initially", + "intersect", + "into", + "lateral", + "leading", + "limit", + "localtime", + "localtimestamp", + "not", + "null", + "offset", + "on", + "only", + "or", + "order", + "placing", + "primary", + "references", + "returning", + "select", + "session_user", + "some", + "symmetric", + "table", + "then", + "to", + "trailing", + "true", + "union", + "unique", + "user", + "using", + "variadic", + "when", + "where", + "window", + "with", +]; + +/// Sorted list of CockroachDB's reserved type function names. +/// Taken from https://www.cockroachlabs.com/docs/stable/keywords-and-identifiers.html#keywords +const RESERVED_TYPE_FUNCTION_NAMES: [&str; 18] = [ + "authorization", + "collation", + "cross", + "full", + "ilike", + "inner", + "is", + "isnull", + "join", + "left", + "like", + "natural", + "none", + "notnull", + "outer", + "overlaps", + "right", + "similar", +]; + +/// Returns true if a Postgres identifier is considered "safe". +/// +/// In this context, "safe" means that the value of an identifier would be the same quoted and unquoted or that it's not part of reserved keywords. In other words, that it does _not_ need to be quoted. +/// +/// Spec can be found here: https://www.postgresql.org/docs/current/sql-syntax-lexical.html#SQL-SYNTAX-IDENTIFIERS +/// or here: https://www.cockroachlabs.com/docs/stable/keywords-and-identifiers.html#rules-for-identifiers +fn is_safe_identifier(ident: &str) -> bool { + if ident.is_empty() { + return false; + } + + // 1. Not equal any SQL keyword unless the keyword is accepted by the element's syntax. For example, name accepts Unreserved or Column Name keywords. + if RESERVED_KEYWORDS.binary_search(&ident).is_ok() || RESERVED_TYPE_FUNCTION_NAMES.binary_search(&ident).is_ok() { + return false; + } + + let mut chars = ident.chars(); + + let first = chars.next().unwrap(); + + // 2. SQL identifiers must begin with a letter (a-z, but also letters with diacritical marks and non-Latin letters) or an underscore (_). + if (!first.is_alphabetic() || !first.is_lowercase()) && first != '_' { + return false; + } + + for c in chars { + // 3. Subsequent characters in an identifier can be letters, underscores, digits (0-9), or dollar signs ($). + if (!c.is_alphabetic() || !c.is_lowercase()) && c != '_' && !c.is_ascii_digit() && c != '$' { + return false; + } + } + + true +} + +#[cfg(test)] +mod tests { + use super::*; + pub(crate) use crate::connector::postgres::url::PostgresFlavour; + use crate::connector::Queryable; + use crate::tests::test_api::postgres::CONN_STR; + use crate::tests::test_api::CRDB_CONN_STR; + use url::Url; + + #[tokio::test] + async fn test_custom_search_path_pg() { + async fn test_path(schema_name: &str) -> Option { + let mut url = Url::parse(&CONN_STR).unwrap(); + url.query_pairs_mut().append_pair("schema", schema_name); + + let mut pg_url = PostgresUrl::new(url).unwrap(); + pg_url.set_flavour(PostgresFlavour::Postgres); + + let client = PostgreSql::new(pg_url).await.unwrap(); + + let result_set = client.query_raw("SHOW search_path", &[]).await.unwrap(); + let row = result_set.first().unwrap(); + + row[0].typed.to_string() + } + + // Safe + assert_eq!(test_path("hello").await.as_deref(), Some("\"hello\"")); + assert_eq!(test_path("_hello").await.as_deref(), Some("\"_hello\"")); + assert_eq!(test_path("àbracadabra").await.as_deref(), Some("\"àbracadabra\"")); + assert_eq!(test_path("h3ll0").await.as_deref(), Some("\"h3ll0\"")); + assert_eq!(test_path("héllo").await.as_deref(), Some("\"héllo\"")); + assert_eq!(test_path("héll0$").await.as_deref(), Some("\"héll0$\"")); + assert_eq!(test_path("héll_0$").await.as_deref(), Some("\"héll_0$\"")); + + // Not safe + assert_eq!(test_path("Hello").await.as_deref(), Some("\"Hello\"")); + assert_eq!(test_path("hEllo").await.as_deref(), Some("\"hEllo\"")); + assert_eq!(test_path("$hello").await.as_deref(), Some("\"$hello\"")); + assert_eq!(test_path("hello!").await.as_deref(), Some("\"hello!\"")); + assert_eq!(test_path("hello#").await.as_deref(), Some("\"hello#\"")); + assert_eq!(test_path("he llo").await.as_deref(), Some("\"he llo\"")); + assert_eq!(test_path(" hello").await.as_deref(), Some("\" hello\"")); + assert_eq!(test_path("he-llo").await.as_deref(), Some("\"he-llo\"")); + assert_eq!(test_path("hÉllo").await.as_deref(), Some("\"hÉllo\"")); + assert_eq!(test_path("1337").await.as_deref(), Some("\"1337\"")); + assert_eq!(test_path("_HELLO").await.as_deref(), Some("\"_HELLO\"")); + assert_eq!(test_path("HELLO").await.as_deref(), Some("\"HELLO\"")); + assert_eq!(test_path("HELLO$").await.as_deref(), Some("\"HELLO$\"")); + assert_eq!(test_path("ÀBRACADABRA").await.as_deref(), Some("\"ÀBRACADABRA\"")); + + for ident in RESERVED_KEYWORDS { + assert_eq!(test_path(ident).await.as_deref(), Some(format!("\"{ident}\"").as_str())); + } + + for ident in RESERVED_TYPE_FUNCTION_NAMES { + assert_eq!(test_path(ident).await.as_deref(), Some(format!("\"{ident}\"").as_str())); + } + } + + #[tokio::test] + async fn test_custom_search_path_pg_pgbouncer() { + async fn test_path(schema_name: &str) -> Option { + let mut url = Url::parse(&CONN_STR).unwrap(); + url.query_pairs_mut().append_pair("schema", schema_name); + url.query_pairs_mut().append_pair("pbbouncer", "true"); + + let mut pg_url = PostgresUrl::new(url).unwrap(); + pg_url.set_flavour(PostgresFlavour::Postgres); + + let client = PostgreSql::new(pg_url).await.unwrap(); + + let result_set = client.query_raw("SHOW search_path", &[]).await.unwrap(); + let row = result_set.first().unwrap(); + + row[0].typed.to_string() + } + + // Safe + assert_eq!(test_path("hello").await.as_deref(), Some("\"hello\"")); + assert_eq!(test_path("_hello").await.as_deref(), Some("\"_hello\"")); + assert_eq!(test_path("àbracadabra").await.as_deref(), Some("\"àbracadabra\"")); + assert_eq!(test_path("h3ll0").await.as_deref(), Some("\"h3ll0\"")); + assert_eq!(test_path("héllo").await.as_deref(), Some("\"héllo\"")); + assert_eq!(test_path("héll0$").await.as_deref(), Some("\"héll0$\"")); + assert_eq!(test_path("héll_0$").await.as_deref(), Some("\"héll_0$\"")); + + // Not safe + assert_eq!(test_path("Hello").await.as_deref(), Some("\"Hello\"")); + assert_eq!(test_path("hEllo").await.as_deref(), Some("\"hEllo\"")); + assert_eq!(test_path("$hello").await.as_deref(), Some("\"$hello\"")); + assert_eq!(test_path("hello!").await.as_deref(), Some("\"hello!\"")); + assert_eq!(test_path("hello#").await.as_deref(), Some("\"hello#\"")); + assert_eq!(test_path("he llo").await.as_deref(), Some("\"he llo\"")); + assert_eq!(test_path(" hello").await.as_deref(), Some("\" hello\"")); + assert_eq!(test_path("he-llo").await.as_deref(), Some("\"he-llo\"")); + assert_eq!(test_path("hÉllo").await.as_deref(), Some("\"hÉllo\"")); + assert_eq!(test_path("1337").await.as_deref(), Some("\"1337\"")); + assert_eq!(test_path("_HELLO").await.as_deref(), Some("\"_HELLO\"")); + assert_eq!(test_path("HELLO").await.as_deref(), Some("\"HELLO\"")); + assert_eq!(test_path("HELLO$").await.as_deref(), Some("\"HELLO$\"")); + assert_eq!(test_path("ÀBRACADABRA").await.as_deref(), Some("\"ÀBRACADABRA\"")); + + for ident in RESERVED_KEYWORDS { + assert_eq!(test_path(ident).await.as_deref(), Some(format!("\"{ident}\"").as_str())); + } + + for ident in RESERVED_TYPE_FUNCTION_NAMES { + assert_eq!(test_path(ident).await.as_deref(), Some(format!("\"{ident}\"").as_str())); + } + } + + #[tokio::test] + async fn test_custom_search_path_crdb() { + async fn test_path(schema_name: &str) -> Option { + let mut url = Url::parse(&CRDB_CONN_STR).unwrap(); + url.query_pairs_mut().append_pair("schema", schema_name); + + let mut pg_url = PostgresUrl::new(url).unwrap(); + pg_url.set_flavour(PostgresFlavour::Cockroach); + + let client = PostgreSql::new(pg_url).await.unwrap(); + + let result_set = client.query_raw("SHOW search_path", &[]).await.unwrap(); + let row = result_set.first().unwrap(); + + row[0].typed.to_string() + } + + // Safe + assert_eq!(test_path("hello").await.as_deref(), Some("hello")); + assert_eq!(test_path("_hello").await.as_deref(), Some("_hello")); + assert_eq!(test_path("àbracadabra").await.as_deref(), Some("àbracadabra")); + assert_eq!(test_path("h3ll0").await.as_deref(), Some("h3ll0")); + assert_eq!(test_path("héllo").await.as_deref(), Some("héllo")); + assert_eq!(test_path("héll0$").await.as_deref(), Some("héll0$")); + assert_eq!(test_path("héll_0$").await.as_deref(), Some("héll_0$")); + + // Not safe + assert_eq!(test_path("Hello").await.as_deref(), Some("\"Hello\"")); + assert_eq!(test_path("hEllo").await.as_deref(), Some("\"hEllo\"")); + assert_eq!(test_path("$hello").await.as_deref(), Some("\"$hello\"")); + assert_eq!(test_path("hello!").await.as_deref(), Some("\"hello!\"")); + assert_eq!(test_path("hello#").await.as_deref(), Some("\"hello#\"")); + assert_eq!(test_path("he llo").await.as_deref(), Some("\"he llo\"")); + assert_eq!(test_path(" hello").await.as_deref(), Some("\" hello\"")); + assert_eq!(test_path("he-llo").await.as_deref(), Some("\"he-llo\"")); + assert_eq!(test_path("hÉllo").await.as_deref(), Some("\"hÉllo\"")); + assert_eq!(test_path("1337").await.as_deref(), Some("\"1337\"")); + assert_eq!(test_path("_HELLO").await.as_deref(), Some("\"_HELLO\"")); + assert_eq!(test_path("HELLO").await.as_deref(), Some("\"HELLO\"")); + assert_eq!(test_path("HELLO$").await.as_deref(), Some("\"HELLO$\"")); + assert_eq!(test_path("ÀBRACADABRA").await.as_deref(), Some("\"ÀBRACADABRA\"")); + + for ident in RESERVED_KEYWORDS { + assert_eq!(test_path(ident).await.as_deref(), Some(format!("\"{ident}\"").as_str())); + } + + for ident in RESERVED_TYPE_FUNCTION_NAMES { + assert_eq!(test_path(ident).await.as_deref(), Some(format!("\"{ident}\"").as_str())); + } + } + + #[tokio::test] + async fn test_custom_search_path_unknown_pg() { + async fn test_path(schema_name: &str) -> Option { + let mut url = Url::parse(&CONN_STR).unwrap(); + url.query_pairs_mut().append_pair("schema", schema_name); + + let mut pg_url = PostgresUrl::new(url).unwrap(); + pg_url.set_flavour(PostgresFlavour::Unknown); + + let client = PostgreSql::new(pg_url).await.unwrap(); + + let result_set = client.query_raw("SHOW search_path", &[]).await.unwrap(); + let row = result_set.first().unwrap(); + + row[0].typed.to_string() + } + + // Safe + assert_eq!(test_path("hello").await.as_deref(), Some("hello")); + assert_eq!(test_path("_hello").await.as_deref(), Some("_hello")); + assert_eq!(test_path("àbracadabra").await.as_deref(), Some("\"àbracadabra\"")); + assert_eq!(test_path("h3ll0").await.as_deref(), Some("h3ll0")); + assert_eq!(test_path("héllo").await.as_deref(), Some("\"héllo\"")); + assert_eq!(test_path("héll0$").await.as_deref(), Some("\"héll0$\"")); + assert_eq!(test_path("héll_0$").await.as_deref(), Some("\"héll_0$\"")); + + // Not safe + assert_eq!(test_path("Hello").await.as_deref(), Some("\"Hello\"")); + assert_eq!(test_path("hEllo").await.as_deref(), Some("\"hEllo\"")); + assert_eq!(test_path("$hello").await.as_deref(), Some("\"$hello\"")); + assert_eq!(test_path("hello!").await.as_deref(), Some("\"hello!\"")); + assert_eq!(test_path("hello#").await.as_deref(), Some("\"hello#\"")); + assert_eq!(test_path("he llo").await.as_deref(), Some("\"he llo\"")); + assert_eq!(test_path(" hello").await.as_deref(), Some("\" hello\"")); + assert_eq!(test_path("he-llo").await.as_deref(), Some("\"he-llo\"")); + assert_eq!(test_path("hÉllo").await.as_deref(), Some("\"hÉllo\"")); + assert_eq!(test_path("1337").await.as_deref(), Some("\"1337\"")); + assert_eq!(test_path("_HELLO").await.as_deref(), Some("\"_HELLO\"")); + assert_eq!(test_path("HELLO").await.as_deref(), Some("\"HELLO\"")); + assert_eq!(test_path("HELLO$").await.as_deref(), Some("\"HELLO$\"")); + assert_eq!(test_path("ÀBRACADABRA").await.as_deref(), Some("\"ÀBRACADABRA\"")); + + for ident in RESERVED_KEYWORDS { + assert_eq!(test_path(ident).await.as_deref(), Some(format!("\"{ident}\"").as_str())); + } + + for ident in RESERVED_TYPE_FUNCTION_NAMES { + assert_eq!(test_path(ident).await.as_deref(), Some(format!("\"{ident}\"").as_str())); + } + } + + #[tokio::test] + async fn test_custom_search_path_unknown_crdb() { + async fn test_path(schema_name: &str) -> Option { + let mut url = Url::parse(&CONN_STR).unwrap(); + url.query_pairs_mut().append_pair("schema", schema_name); + + let mut pg_url = PostgresUrl::new(url).unwrap(); + pg_url.set_flavour(PostgresFlavour::Unknown); + + let client = PostgreSql::new(pg_url).await.unwrap(); + + let result_set = client.query_raw("SHOW search_path", &[]).await.unwrap(); + let row = result_set.first().unwrap(); + + row[0].typed.to_string() + } + + // Safe + assert_eq!(test_path("hello").await.as_deref(), Some("hello")); + assert_eq!(test_path("_hello").await.as_deref(), Some("_hello")); + assert_eq!(test_path("àbracadabra").await.as_deref(), Some("\"àbracadabra\"")); + assert_eq!(test_path("h3ll0").await.as_deref(), Some("h3ll0")); + assert_eq!(test_path("héllo").await.as_deref(), Some("\"héllo\"")); + assert_eq!(test_path("héll0$").await.as_deref(), Some("\"héll0$\"")); + assert_eq!(test_path("héll_0$").await.as_deref(), Some("\"héll_0$\"")); + + // Not safe + assert_eq!(test_path("Hello").await.as_deref(), Some("\"Hello\"")); + assert_eq!(test_path("hEllo").await.as_deref(), Some("\"hEllo\"")); + assert_eq!(test_path("$hello").await.as_deref(), Some("\"$hello\"")); + assert_eq!(test_path("hello!").await.as_deref(), Some("\"hello!\"")); + assert_eq!(test_path("hello#").await.as_deref(), Some("\"hello#\"")); + assert_eq!(test_path("he llo").await.as_deref(), Some("\"he llo\"")); + assert_eq!(test_path(" hello").await.as_deref(), Some("\" hello\"")); + assert_eq!(test_path("he-llo").await.as_deref(), Some("\"he-llo\"")); + assert_eq!(test_path("hÉllo").await.as_deref(), Some("\"hÉllo\"")); + assert_eq!(test_path("1337").await.as_deref(), Some("\"1337\"")); + assert_eq!(test_path("_HELLO").await.as_deref(), Some("\"_HELLO\"")); + assert_eq!(test_path("HELLO").await.as_deref(), Some("\"HELLO\"")); + assert_eq!(test_path("HELLO$").await.as_deref(), Some("\"HELLO$\"")); + assert_eq!(test_path("ÀBRACADABRA").await.as_deref(), Some("\"ÀBRACADABRA\"")); + + for ident in RESERVED_KEYWORDS { + assert_eq!(test_path(ident).await.as_deref(), Some(format!("\"{ident}\"").as_str())); + } + + for ident in RESERVED_TYPE_FUNCTION_NAMES { + assert_eq!(test_path(ident).await.as_deref(), Some(format!("\"{ident}\"").as_str())); + } + } + + #[test] + fn test_safe_ident() { + // Safe + assert!(is_safe_identifier("hello")); + assert!(is_safe_identifier("_hello")); + assert!(is_safe_identifier("àbracadabra")); + assert!(is_safe_identifier("h3ll0")); + assert!(is_safe_identifier("héllo")); + assert!(is_safe_identifier("héll0$")); + assert!(is_safe_identifier("héll_0$")); + assert!(is_safe_identifier("disconnect_security_must_honor_connect_scope_one2m")); + + // Not safe + assert!(!is_safe_identifier("")); + assert!(!is_safe_identifier("Hello")); + assert!(!is_safe_identifier("hEllo")); + assert!(!is_safe_identifier("$hello")); + assert!(!is_safe_identifier("hello!")); + assert!(!is_safe_identifier("hello#")); + assert!(!is_safe_identifier("he llo")); + assert!(!is_safe_identifier(" hello")); + assert!(!is_safe_identifier("he-llo")); + assert!(!is_safe_identifier("hÉllo")); + assert!(!is_safe_identifier("1337")); + assert!(!is_safe_identifier("_HELLO")); + assert!(!is_safe_identifier("HELLO")); + assert!(!is_safe_identifier("HELLO$")); + assert!(!is_safe_identifier("ÀBRACADABRA")); + + for ident in RESERVED_KEYWORDS { + assert!(!is_safe_identifier(ident)); + } + + for ident in RESERVED_TYPE_FUNCTION_NAMES { + assert!(!is_safe_identifier(ident)); + } + } +} diff --git a/quaint/src/connector/postgres/url.rs b/quaint/src/connector/postgres/url.rs new file mode 100644 index 000000000000..f0b60d88a848 --- /dev/null +++ b/quaint/src/connector/postgres/url.rs @@ -0,0 +1,695 @@ +#![cfg_attr(target_arch = "wasm32", allow(dead_code))] + +use std::{ + borrow::Cow, + fmt::{Debug, Display}, + time::Duration, +}; + +use percent_encoding::percent_decode; +use url::{Host, Url}; + +use crate::error::{Error, ErrorKind}; + +#[cfg(feature = "postgresql-native")] +use tokio_postgres::config::{ChannelBinding, SslMode}; + +#[derive(Clone)] +pub(crate) struct Hidden(pub(crate) T); + +impl Debug for Hidden { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.write_str("") + } +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum SslAcceptMode { + Strict, + AcceptInvalidCerts, +} + +#[derive(Debug, Clone)] +pub struct SslParams { + pub(crate) certificate_file: Option, + pub(crate) identity_file: Option, + pub(crate) identity_password: Hidden>, + pub(crate) ssl_accept_mode: SslAcceptMode, +} + +#[derive(Debug, Clone, Copy)] +pub enum PostgresFlavour { + Postgres, + Cockroach, + Unknown, +} + +impl PostgresFlavour { + /// Returns `true` if the postgres flavour is [`Postgres`]. + /// + /// [`Postgres`]: PostgresFlavour::Postgres + pub(crate) fn is_postgres(&self) -> bool { + matches!(self, Self::Postgres) + } + + /// Returns `true` if the postgres flavour is [`Cockroach`]. + /// + /// [`Cockroach`]: PostgresFlavour::Cockroach + pub(crate) fn is_cockroach(&self) -> bool { + matches!(self, Self::Cockroach) + } + + /// Returns `true` if the postgres flavour is [`Unknown`]. + /// + /// [`Unknown`]: PostgresFlavour::Unknown + pub(crate) fn is_unknown(&self) -> bool { + matches!(self, Self::Unknown) + } +} + +/// Wraps a connection url and exposes the parsing logic used by Quaint, +/// including default values. +#[derive(Debug, Clone)] +pub struct PostgresUrl { + pub(crate) url: Url, + pub(crate) query_params: PostgresUrlQueryParams, + pub(crate) flavour: PostgresFlavour, +} + +pub(crate) const DEFAULT_SCHEMA: &str = "public"; + +impl PostgresUrl { + /// Parse `Url` to `PostgresUrl`. Returns error for mistyped connection + /// parameters. + pub fn new(url: Url) -> Result { + let query_params = Self::parse_query_params(&url)?; + + Ok(Self { + url, + query_params, + flavour: PostgresFlavour::Unknown, + }) + } + + /// The bare `Url` to the database. + pub fn url(&self) -> &Url { + &self.url + } + + /// The percent-decoded database username. + pub fn username(&self) -> Cow { + match percent_decode(self.url.username().as_bytes()).decode_utf8() { + Ok(username) => username, + Err(_) => { + tracing::warn!("Couldn't decode username to UTF-8, using the non-decoded version."); + + self.url.username().into() + } + } + } + + /// The database host. Taken first from the `host` query parameter, then + /// from the `host` part of the URL. For socket connections, the query + /// parameter must be used. + /// + /// If none of them are set, defaults to `localhost`. + pub fn host(&self) -> &str { + match (self.query_params.host.as_ref(), self.url.host_str(), self.url.host()) { + (Some(host), _, _) => host.as_str(), + (None, Some(""), _) => "localhost", + (None, None, _) => "localhost", + (None, Some(host), Some(Host::Ipv6(_))) => { + // The `url` crate may return an IPv6 address in brackets, which must be stripped. + if host.starts_with('[') && host.ends_with(']') { + &host[1..host.len() - 1] + } else { + host + } + } + (None, Some(host), _) => host, + } + } + + /// Name of the database connected. Defaults to `postgres`. + pub fn dbname(&self) -> &str { + match self.url.path_segments() { + Some(mut segments) => segments.next().unwrap_or("postgres"), + None => "postgres", + } + } + + /// The percent-decoded database password. + pub fn password(&self) -> Cow { + match self + .url + .password() + .and_then(|pw| percent_decode(pw.as_bytes()).decode_utf8().ok()) + { + Some(password) => password, + None => self.url.password().unwrap_or("").into(), + } + } + + /// The database port, defaults to `5432`. + pub fn port(&self) -> u16 { + self.url.port().unwrap_or(5432) + } + + /// The database schema, defaults to `public`. + pub fn schema(&self) -> &str { + self.query_params.schema.as_deref().unwrap_or(DEFAULT_SCHEMA) + } + + /// Whether the pgbouncer mode is enabled. + pub fn pg_bouncer(&self) -> bool { + self.query_params.pg_bouncer + } + + /// The connection timeout. + pub fn connect_timeout(&self) -> Option { + self.query_params.connect_timeout + } + + /// Pool check_out timeout + pub fn pool_timeout(&self) -> Option { + self.query_params.pool_timeout + } + + /// The socket timeout + pub fn socket_timeout(&self) -> Option { + self.query_params.socket_timeout + } + + /// The maximum connection lifetime + pub fn max_connection_lifetime(&self) -> Option { + self.query_params.max_connection_lifetime + } + + /// The maximum idle connection lifetime + pub fn max_idle_connection_lifetime(&self) -> Option { + self.query_params.max_idle_connection_lifetime + } + + /// The custom application name + pub fn application_name(&self) -> Option<&str> { + self.query_params.application_name.as_deref() + } + + pub(crate) fn options(&self) -> Option<&str> { + self.query_params.options.as_deref() + } + + /// Sets whether the URL points to a Postgres, Cockroach or Unknown database. + /// This is used to avoid a network roundtrip at connection to set the search path. + /// + /// The different behaviours are: + /// - Postgres: Always avoid a network roundtrip by setting the search path through client connection parameters. + /// - Cockroach: Avoid a network roundtrip if the schema name is deemed "safe" (i.e. no escape quoting required). Otherwise, set the search path through a database query. + /// - Unknown: Always add a network roundtrip by setting the search path through a database query. + pub fn set_flavour(&mut self, flavour: PostgresFlavour) { + self.flavour = flavour; + } + + fn parse_query_params(url: &Url) -> Result { + #[cfg(feature = "postgresql-native")] + let mut ssl_mode = SslMode::Prefer; + #[cfg(feature = "postgresql-native")] + let mut channel_binding = ChannelBinding::Prefer; + + let mut connection_limit = None; + let mut schema = None; + let mut certificate_file = None; + let mut identity_file = None; + let mut identity_password = None; + let mut ssl_accept_mode = SslAcceptMode::AcceptInvalidCerts; + let mut host = None; + let mut application_name = None; + let mut socket_timeout = None; + let mut connect_timeout = Some(Duration::from_secs(5)); + let mut pool_timeout = Some(Duration::from_secs(10)); + let mut pg_bouncer = false; + let mut statement_cache_size = 100; + let mut max_connection_lifetime = None; + let mut max_idle_connection_lifetime = Some(Duration::from_secs(300)); + let mut options = None; + + for (k, v) in url.query_pairs() { + match k.as_ref() { + "pgbouncer" => { + pg_bouncer = v + .parse() + .map_err(|_| Error::builder(ErrorKind::InvalidConnectionArguments).build())?; + } + #[cfg(feature = "postgresql-native")] + "sslmode" => { + match v.as_ref() { + "disable" => ssl_mode = SslMode::Disable, + "prefer" => ssl_mode = SslMode::Prefer, + "require" => ssl_mode = SslMode::Require, + _ => { + tracing::debug!(message = "Unsupported SSL mode, defaulting to `prefer`", mode = &*v); + } + }; + } + "sslcert" => { + certificate_file = Some(v.to_string()); + } + "sslidentity" => { + identity_file = Some(v.to_string()); + } + "sslpassword" => { + identity_password = Some(v.to_string()); + } + "statement_cache_size" => { + statement_cache_size = v + .parse() + .map_err(|_| Error::builder(ErrorKind::InvalidConnectionArguments).build())?; + } + "sslaccept" => { + match v.as_ref() { + "strict" => { + ssl_accept_mode = SslAcceptMode::Strict; + } + "accept_invalid_certs" => { + ssl_accept_mode = SslAcceptMode::AcceptInvalidCerts; + } + _ => { + tracing::debug!( + message = "Unsupported SSL accept mode, defaulting to `strict`", + mode = &*v + ); + + ssl_accept_mode = SslAcceptMode::Strict; + } + }; + } + "schema" => { + schema = Some(v.to_string()); + } + "connection_limit" => { + let as_int: usize = v + .parse() + .map_err(|_| Error::builder(ErrorKind::InvalidConnectionArguments).build())?; + connection_limit = Some(as_int); + } + "host" => { + host = Some(v.to_string()); + } + "socket_timeout" => { + let as_int = v + .parse() + .map_err(|_| Error::builder(ErrorKind::InvalidConnectionArguments).build())?; + socket_timeout = Some(Duration::from_secs(as_int)); + } + "connect_timeout" => { + let as_int = v + .parse() + .map_err(|_| Error::builder(ErrorKind::InvalidConnectionArguments).build())?; + + if as_int == 0 { + connect_timeout = None; + } else { + connect_timeout = Some(Duration::from_secs(as_int)); + } + } + "pool_timeout" => { + let as_int = v + .parse() + .map_err(|_| Error::builder(ErrorKind::InvalidConnectionArguments).build())?; + + if as_int == 0 { + pool_timeout = None; + } else { + pool_timeout = Some(Duration::from_secs(as_int)); + } + } + "max_connection_lifetime" => { + let as_int = v + .parse() + .map_err(|_| Error::builder(ErrorKind::InvalidConnectionArguments).build())?; + + if as_int == 0 { + max_connection_lifetime = None; + } else { + max_connection_lifetime = Some(Duration::from_secs(as_int)); + } + } + "max_idle_connection_lifetime" => { + let as_int = v + .parse() + .map_err(|_| Error::builder(ErrorKind::InvalidConnectionArguments).build())?; + + if as_int == 0 { + max_idle_connection_lifetime = None; + } else { + max_idle_connection_lifetime = Some(Duration::from_secs(as_int)); + } + } + "application_name" => { + application_name = Some(v.to_string()); + } + #[cfg(feature = "postgresql-native")] + "channel_binding" => { + match v.as_ref() { + "disable" => channel_binding = ChannelBinding::Disable, + "prefer" => channel_binding = ChannelBinding::Prefer, + "require" => channel_binding = ChannelBinding::Require, + _ => { + tracing::debug!( + message = "Unsupported Channel Binding {channel_binding}, defaulting to `prefer`", + channel_binding = &*v + ); + } + }; + } + "options" => { + options = Some(v.to_string()); + } + _ => { + tracing::trace!(message = "Discarding connection string param", param = &*k); + } + }; + } + + Ok(PostgresUrlQueryParams { + ssl_params: SslParams { + certificate_file, + identity_file, + ssl_accept_mode, + identity_password: Hidden(identity_password), + }, + connection_limit, + schema, + host, + connect_timeout, + pool_timeout, + socket_timeout, + pg_bouncer, + statement_cache_size, + max_connection_lifetime, + max_idle_connection_lifetime, + application_name, + options, + #[cfg(feature = "postgresql-native")] + channel_binding, + #[cfg(feature = "postgresql-native")] + ssl_mode, + }) + } + + pub(crate) fn ssl_params(&self) -> &SslParams { + &self.query_params.ssl_params + } + + #[cfg(feature = "pooled")] + pub(crate) fn connection_limit(&self) -> Option { + self.query_params.connection_limit + } + + pub fn flavour(&self) -> PostgresFlavour { + self.flavour + } +} + +#[derive(Debug, Clone)] +pub(crate) struct PostgresUrlQueryParams { + pub(crate) ssl_params: SslParams, + pub(crate) connection_limit: Option, + pub(crate) schema: Option, + pub(crate) pg_bouncer: bool, + pub(crate) host: Option, + pub(crate) socket_timeout: Option, + pub(crate) connect_timeout: Option, + pub(crate) pool_timeout: Option, + pub(crate) statement_cache_size: usize, + pub(crate) max_connection_lifetime: Option, + pub(crate) max_idle_connection_lifetime: Option, + pub(crate) application_name: Option, + pub(crate) options: Option, + + #[cfg(feature = "postgresql-native")] + pub(crate) channel_binding: ChannelBinding, + + #[cfg(feature = "postgresql-native")] + pub(crate) ssl_mode: SslMode, +} + +// A SearchPath connection parameter (Display-impl) for connection initialization. +struct CockroachSearchPath<'a>(&'a str); + +impl Display for CockroachSearchPath<'_> { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.write_str(self.0) + } +} + +// A SearchPath connection parameter (Display-impl) for connection initialization. +struct PostgresSearchPath<'a>(&'a str); + +impl Display for PostgresSearchPath<'_> { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.write_str("\"")?; + f.write_str(self.0)?; + f.write_str("\"")?; + + Ok(()) + } +} + +// A SetSearchPath statement (Display-impl) for connection initialization. +struct SetSearchPath<'a>(Option<&'a str>); + +impl Display for SetSearchPath<'_> { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + if let Some(schema) = self.0 { + f.write_str("SET search_path = \"")?; + f.write_str(schema)?; + f.write_str("\";\n")?; + } + + Ok(()) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::ast::Value; + pub(crate) use crate::connector::postgres::url::PostgresFlavour; + use crate::tests::test_api::postgres::CONN_STR; + use crate::{connector::Queryable, error::*, single::Quaint}; + use url::Url; + + #[test] + fn should_parse_socket_url() { + let url = PostgresUrl::new(Url::parse("postgresql:///dbname?host=/var/run/psql.sock").unwrap()).unwrap(); + assert_eq!("dbname", url.dbname()); + assert_eq!("/var/run/psql.sock", url.host()); + } + + #[test] + fn should_parse_escaped_url() { + let url = PostgresUrl::new(Url::parse("postgresql:///dbname?host=%2Fvar%2Frun%2Fpostgresql").unwrap()).unwrap(); + assert_eq!("dbname", url.dbname()); + assert_eq!("/var/run/postgresql", url.host()); + } + + #[test] + fn should_allow_changing_of_cache_size() { + let url = + PostgresUrl::new(Url::parse("postgresql:///localhost:5432/foo?statement_cache_size=420").unwrap()).unwrap(); + assert_eq!(420, url.cache().capacity()); + } + + #[test] + fn should_have_default_cache_size() { + let url = PostgresUrl::new(Url::parse("postgresql:///localhost:5432/foo").unwrap()).unwrap(); + assert_eq!(100, url.cache().capacity()); + } + + #[test] + fn should_have_application_name() { + let url = + PostgresUrl::new(Url::parse("postgresql:///localhost:5432/foo?application_name=test").unwrap()).unwrap(); + assert_eq!(Some("test"), url.application_name()); + } + + #[test] + fn should_have_channel_binding() { + let url = + PostgresUrl::new(Url::parse("postgresql:///localhost:5432/foo?channel_binding=require").unwrap()).unwrap(); + assert_eq!(ChannelBinding::Require, url.channel_binding()); + } + + #[test] + fn should_have_default_channel_binding() { + let url = + PostgresUrl::new(Url::parse("postgresql:///localhost:5432/foo?channel_binding=invalid").unwrap()).unwrap(); + assert_eq!(ChannelBinding::Prefer, url.channel_binding()); + + let url = PostgresUrl::new(Url::parse("postgresql:///localhost:5432/foo").unwrap()).unwrap(); + assert_eq!(ChannelBinding::Prefer, url.channel_binding()); + } + + #[test] + fn should_not_enable_caching_with_pgbouncer() { + let url = PostgresUrl::new(Url::parse("postgresql:///localhost:5432/foo?pgbouncer=true").unwrap()).unwrap(); + assert_eq!(0, url.cache().capacity()); + } + + #[test] + fn should_parse_default_host() { + let url = PostgresUrl::new(Url::parse("postgresql:///dbname").unwrap()).unwrap(); + assert_eq!("dbname", url.dbname()); + assert_eq!("localhost", url.host()); + } + + #[test] + fn should_parse_ipv6_host() { + let url = PostgresUrl::new(Url::parse("postgresql://[2001:db8:1234::ffff]:5432/dbname").unwrap()).unwrap(); + assert_eq!("2001:db8:1234::ffff", url.host()); + } + + #[test] + fn should_handle_options_field() { + let url = PostgresUrl::new(Url::parse("postgresql:///localhost:5432?options=--cluster%3Dmy_cluster").unwrap()) + .unwrap(); + + assert_eq!("--cluster=my_cluster", url.options().unwrap()); + } + + #[tokio::test] + async fn should_map_nonexisting_database_error() { + let mut url = Url::parse(&CONN_STR).unwrap(); + url.set_path("/this_does_not_exist"); + + let res = Quaint::new(url.as_str()).await; + + assert!(res.is_err()); + + match res { + Ok(_) => unreachable!(), + Err(e) => match e.kind() { + ErrorKind::DatabaseDoesNotExist { db_name } => { + assert_eq!(Some("3D000"), e.original_code()); + assert_eq!( + Some("database \"this_does_not_exist\" does not exist"), + e.original_message() + ); + assert_eq!(&Name::available("this_does_not_exist"), db_name) + } + kind => panic!("Expected `DatabaseDoesNotExist`, got {:?}", kind), + }, + } + } + + #[tokio::test] + async fn should_map_wrong_credentials_error() { + let mut url = Url::parse(&CONN_STR).unwrap(); + url.set_username("WRONG").unwrap(); + + let res = Quaint::new(url.as_str()).await; + assert!(res.is_err()); + + let err = res.unwrap_err(); + assert!(matches!(err.kind(), ErrorKind::AuthenticationFailed { user } if user == &Name::available("WRONG"))); + } + + #[tokio::test] + async fn should_map_tls_errors() { + let mut url = Url::parse(&CONN_STR).expect("parsing url"); + url.set_query(Some("sslmode=require&sslaccept=strict")); + + let res = Quaint::new(url.as_str()).await; + + assert!(res.is_err()); + + match res { + Ok(_) => unreachable!(), + Err(e) => match e.kind() { + ErrorKind::TlsError { .. } => (), + other => panic!("{:#?}", other), + }, + } + } + + #[tokio::test] + async fn should_map_incorrect_parameters_error() { + let url = Url::parse(&CONN_STR).unwrap(); + let conn = Quaint::new(url.as_str()).await.unwrap(); + + let res = conn.query_raw("SELECT $1", &[Value::int32(1), Value::int32(2)]).await; + + assert!(res.is_err()); + + match res { + Ok(_) => unreachable!(), + Err(e) => match e.kind() { + ErrorKind::IncorrectNumberOfParameters { expected, actual } => { + assert_eq!(1, *expected); + assert_eq!(2, *actual); + } + other => panic!("{:#?}", other), + }, + } + } + + #[test] + fn search_path_pgbouncer_should_be_set_with_query() { + let mut url = Url::parse(&CONN_STR).unwrap(); + url.query_pairs_mut().append_pair("schema", "hello"); + url.query_pairs_mut().append_pair("pgbouncer", "true"); + + let mut pg_url = PostgresUrl::new(url).unwrap(); + pg_url.set_flavour(PostgresFlavour::Postgres); + + let config = pg_url.to_config(); + + // PGBouncer does not support the `search_path` connection parameter. + // When `pgbouncer=true`, config.search_path should be None, + // And the `search_path` should be set via a db query after connection. + assert_eq!(config.get_search_path(), None); + } + + #[test] + fn search_path_pg_should_be_set_with_param() { + let mut url = Url::parse(&CONN_STR).unwrap(); + url.query_pairs_mut().append_pair("schema", "hello"); + + let mut pg_url = PostgresUrl::new(url).unwrap(); + pg_url.set_flavour(PostgresFlavour::Postgres); + + let config = pg_url.to_config(); + + // Postgres supports setting the search_path via a connection parameter. + assert_eq!(config.get_search_path(), Some(&"\"hello\"".to_owned())); + } + + #[test] + fn search_path_crdb_safe_ident_should_be_set_with_param() { + let mut url = Url::parse(&CONN_STR).unwrap(); + url.query_pairs_mut().append_pair("schema", "hello"); + + let mut pg_url = PostgresUrl::new(url).unwrap(); + pg_url.set_flavour(PostgresFlavour::Cockroach); + + let config = pg_url.to_config(); + + // CRDB supports setting the search_path via a connection parameter if the identifier is safe. + assert_eq!(config.get_search_path(), Some(&"hello".to_owned())); + } + + #[test] + fn search_path_crdb_unsafe_ident_should_be_set_with_query() { + let mut url = Url::parse(&CONN_STR).unwrap(); + url.query_pairs_mut().append_pair("schema", "HeLLo"); + + let mut pg_url = PostgresUrl::new(url).unwrap(); + pg_url.set_flavour(PostgresFlavour::Cockroach); + + let config = pg_url.to_config(); + + // CRDB does NOT support setting the search_path via a connection parameter if the identifier is unsafe. + assert_eq!(config.get_search_path(), None); + } +} diff --git a/quaint/src/connector/sqlite.rs b/quaint/src/connector/sqlite.rs index 3a1ef72b4883..c59c947b8dc1 100644 --- a/quaint/src/connector/sqlite.rs +++ b/quaint/src/connector/sqlite.rs @@ -1,353 +1,11 @@ -mod conversion; -mod error; +//! Wasm-compatible definitions for the SQLite connector. +//! This module is only available with the `sqlite` feature. +pub(crate) mod error; +mod ffi; +pub(crate) mod params; pub use error::SqliteError; +pub use params::*; -pub use rusqlite::{params_from_iter, version as sqlite_version}; - -use super::IsolationLevel; -use crate::{ - ast::{Query, Value}, - connector::{metrics, queryable::*, ResultSet}, - error::{Error, ErrorKind}, - visitor::{self, Visitor}, -}; -use async_trait::async_trait; -use std::{convert::TryFrom, path::Path, time::Duration}; -use tokio::sync::Mutex; - -pub(crate) const DEFAULT_SQLITE_SCHEMA_NAME: &str = "main"; - -/// The underlying sqlite driver. Only available with the `expose-drivers` Cargo feature. -#[cfg(feature = "expose-drivers")] -pub use rusqlite; - -/// A connector interface for the SQLite database -pub struct Sqlite { - pub(crate) client: Mutex, -} - -/// Wraps a connection url and exposes the parsing logic used by Quaint, -/// including default values. -#[derive(Debug)] -pub struct SqliteParams { - pub connection_limit: Option, - /// This is not a `PathBuf` because we need to `ATTACH` the database to the path, and this can - /// only be done with UTF-8 paths. - pub file_path: String, - pub db_name: String, - pub socket_timeout: Option, - pub max_connection_lifetime: Option, - pub max_idle_connection_lifetime: Option, -} - -impl TryFrom<&str> for SqliteParams { - type Error = Error; - - fn try_from(path: &str) -> crate::Result { - let path = if path.starts_with("file:") { - path.trim_start_matches("file:") - } else { - path.trim_start_matches("sqlite:") - }; - - let path_parts: Vec<&str> = path.split('?').collect(); - let path_str = path_parts[0]; - let path = Path::new(path_str); - - if path.is_dir() { - Err(Error::builder(ErrorKind::DatabaseUrlIsInvalid(path.to_str().unwrap().to_string())).build()) - } else { - let mut connection_limit = None; - let mut socket_timeout = None; - let mut max_connection_lifetime = None; - let mut max_idle_connection_lifetime = None; - - if path_parts.len() > 1 { - let params = path_parts.last().unwrap().split('&').map(|kv| { - let splitted: Vec<&str> = kv.split('=').collect(); - (splitted[0], splitted[1]) - }); - - for (k, v) in params { - match k { - "connection_limit" => { - let as_int: usize = v - .parse() - .map_err(|_| Error::builder(ErrorKind::InvalidConnectionArguments).build())?; - - connection_limit = Some(as_int); - } - "socket_timeout" => { - let as_int = v - .parse() - .map_err(|_| Error::builder(ErrorKind::InvalidConnectionArguments).build())?; - - socket_timeout = Some(Duration::from_secs(as_int)); - } - "max_connection_lifetime" => { - let as_int = v - .parse() - .map_err(|_| Error::builder(ErrorKind::InvalidConnectionArguments).build())?; - - if as_int == 0 { - max_connection_lifetime = None; - } else { - max_connection_lifetime = Some(Duration::from_secs(as_int)); - } - } - "max_idle_connection_lifetime" => { - let as_int = v - .parse() - .map_err(|_| Error::builder(ErrorKind::InvalidConnectionArguments).build())?; - - if as_int == 0 { - max_idle_connection_lifetime = None; - } else { - max_idle_connection_lifetime = Some(Duration::from_secs(as_int)); - } - } - _ => { - tracing::trace!(message = "Discarding connection string param", param = k); - } - }; - } - } - - Ok(Self { - connection_limit, - file_path: path_str.to_owned(), - db_name: DEFAULT_SQLITE_SCHEMA_NAME.to_owned(), - socket_timeout, - max_connection_lifetime, - max_idle_connection_lifetime, - }) - } - } -} - -impl TryFrom<&str> for Sqlite { - type Error = Error; - - fn try_from(path: &str) -> crate::Result { - let params = SqliteParams::try_from(path)?; - let file_path = params.file_path; - - let conn = rusqlite::Connection::open(file_path.as_str())?; - - if let Some(timeout) = params.socket_timeout { - conn.busy_timeout(timeout)?; - }; - - let client = Mutex::new(conn); - - Ok(Sqlite { client }) - } -} - -impl Sqlite { - pub fn new(file_path: &str) -> crate::Result { - Self::try_from(file_path) - } - - /// Open a new SQLite database in memory. - pub fn new_in_memory() -> crate::Result { - let client = rusqlite::Connection::open_in_memory()?; - - Ok(Sqlite { - client: Mutex::new(client), - }) - } - - /// The underlying rusqlite::Connection. Only available with the `expose-drivers` Cargo - /// feature. This is a lower level API when you need to get into database specific features. - #[cfg(feature = "expose-drivers")] - pub fn connection(&self) -> &Mutex { - &self.client - } -} - -impl_default_TransactionCapable!(Sqlite); - -#[async_trait] -impl Queryable for Sqlite { - async fn query(&self, q: Query<'_>) -> crate::Result { - let (sql, params) = visitor::Sqlite::build(q)?; - self.query_raw(&sql, ¶ms).await - } - - async fn query_raw(&self, sql: &str, params: &[Value<'_>]) -> crate::Result { - metrics::query("sqlite.query_raw", sql, params, move || async move { - let client = self.client.lock().await; - - let mut stmt = client.prepare_cached(sql)?; - - let mut rows = stmt.query(params_from_iter(params.iter()))?; - let mut result = ResultSet::new(rows.to_column_names(), Vec::new()); - - while let Some(row) = rows.next()? { - result.rows.push(row.get_result_row()?); - } - - result.set_last_insert_id(u64::try_from(client.last_insert_rowid()).unwrap_or(0)); - - Ok(result) - }) - .await - } - - async fn query_raw_typed(&self, sql: &str, params: &[Value<'_>]) -> crate::Result { - self.query_raw(sql, params).await - } - - async fn execute(&self, q: Query<'_>) -> crate::Result { - let (sql, params) = visitor::Sqlite::build(q)?; - self.execute_raw(&sql, ¶ms).await - } - - async fn execute_raw(&self, sql: &str, params: &[Value<'_>]) -> crate::Result { - metrics::query("sqlite.query_raw", sql, params, move || async move { - let client = self.client.lock().await; - let mut stmt = client.prepare_cached(sql)?; - let res = u64::try_from(stmt.execute(params_from_iter(params.iter()))?)?; - - Ok(res) - }) - .await - } - - async fn execute_raw_typed(&self, sql: &str, params: &[Value<'_>]) -> crate::Result { - self.execute_raw(sql, params).await - } - - async fn raw_cmd(&self, cmd: &str) -> crate::Result<()> { - metrics::query("sqlite.raw_cmd", cmd, &[], move || async move { - let client = self.client.lock().await; - client.execute_batch(cmd)?; - Ok(()) - }) - .await - } - - async fn version(&self) -> crate::Result> { - Ok(Some(rusqlite::version().into())) - } - - fn is_healthy(&self) -> bool { - true - } - - async fn set_tx_isolation_level(&self, isolation_level: IsolationLevel) -> crate::Result<()> { - // SQLite is always "serializable", other modes involve pragmas - // and shared cache mode, which is out of scope for now and should be implemented - // as part of a separate effort. - if !matches!(isolation_level, IsolationLevel::Serializable) { - let kind = ErrorKind::invalid_isolation_level(&isolation_level); - return Err(Error::builder(kind).build()); - } - - Ok(()) - } - - fn requires_isolation_first(&self) -> bool { - false - } -} - -#[cfg(test)] -mod tests { - use super::*; - use crate::{ - ast::*, - connector::Queryable, - error::{ErrorKind, Name}, - }; - - #[test] - fn sqlite_params_from_str_should_resolve_path_correctly_with_file_scheme() { - let path = "file:dev.db"; - let params = SqliteParams::try_from(path).unwrap(); - assert_eq!(params.file_path, "dev.db"); - } - - #[test] - fn sqlite_params_from_str_should_resolve_path_correctly_with_sqlite_scheme() { - let path = "sqlite:dev.db"; - let params = SqliteParams::try_from(path).unwrap(); - assert_eq!(params.file_path, "dev.db"); - } - - #[test] - fn sqlite_params_from_str_should_resolve_path_correctly_with_no_scheme() { - let path = "dev.db"; - let params = SqliteParams::try_from(path).unwrap(); - assert_eq!(params.file_path, "dev.db"); - } - - #[tokio::test] - async fn unknown_table_should_give_a_good_error() { - let conn = Sqlite::try_from("file:db/test.db").unwrap(); - let select = Select::from_table("not_there"); - - let err = conn.select(select).await.unwrap_err(); - - match err.kind() { - ErrorKind::TableDoesNotExist { table } => { - assert_eq!(&Name::available("not_there"), table); - } - e => panic!("Expected error TableDoesNotExist, got {:?}", e), - } - } - - #[tokio::test] - async fn in_memory_sqlite_works() { - let conn = Sqlite::new_in_memory().unwrap(); - - conn.raw_cmd("CREATE TABLE test (id INTEGER PRIMARY KEY, txt TEXT NOT NULL);") - .await - .unwrap(); - - let insert = Insert::single_into("test").value("txt", "henlo"); - conn.insert(insert.into()).await.unwrap(); - - let select = Select::from_table("test").value(asterisk()); - let result = conn.select(select.clone()).await.unwrap(); - let result = result.into_single().unwrap(); - - assert_eq!(result.get("id").unwrap(), &Value::int32(1)); - assert_eq!(result.get("txt").unwrap(), &Value::text("henlo")); - - // Check that we do get a separate, new database. - let other_conn = Sqlite::new_in_memory().unwrap(); - - let err = other_conn.select(select).await.unwrap_err(); - assert!(matches!(err.kind(), ErrorKind::TableDoesNotExist { .. })); - } - - #[tokio::test] - async fn quoting_in_returning_in_sqlite_works() { - let conn = Sqlite::new_in_memory().unwrap(); - - conn.raw_cmd("CREATE TABLE test (id INTEGER PRIMARY KEY, `txt space` TEXT NOT NULL);") - .await - .unwrap(); - - let insert = Insert::single_into("test").value("txt space", "henlo"); - conn.insert(insert.into()).await.unwrap(); - - let select = Select::from_table("test").value(asterisk()); - let result = conn.select(select.clone()).await.unwrap(); - let result = result.into_single().unwrap(); - - assert_eq!(result.get("id").unwrap(), &Value::int32(1)); - assert_eq!(result.get("txt space").unwrap(), &Value::text("henlo")); - - let insert = Insert::single_into("test").value("txt space", "henlo"); - let insert: Insert = Insert::from(insert).returning(["txt space"]); - - let result = conn.insert(insert).await.unwrap(); - let result = result.into_single().unwrap(); - - assert_eq!(result.get("txt space").unwrap(), &Value::text("henlo")); - } -} +#[cfg(feature = "sqlite-native")] +pub(crate) mod native; diff --git a/quaint/src/connector/sqlite/error.rs b/quaint/src/connector/sqlite/error.rs index c10b335cb3c0..2c6ff11350fd 100644 --- a/quaint/src/connector/sqlite/error.rs +++ b/quaint/src/connector/sqlite/error.rs @@ -1,8 +1,4 @@ -use std::fmt; - use crate::error::*; -use rusqlite::ffi; -use rusqlite::types::FromSqlError; #[derive(Debug)] pub struct SqliteError { @@ -10,14 +6,10 @@ pub struct SqliteError { pub message: Option, } -impl fmt::Display for SqliteError { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - write!( - f, - "Error code {}: {}", - self.extended_code, - ffi::code_to_str(self.extended_code) - ) +#[cfg(not(feature = "sqlite-native"))] +impl std::fmt::Display for SqliteError { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "Error code {}", self.extended_code) } } @@ -37,7 +29,7 @@ impl From for Error { fn from(error: SqliteError) -> Self { match error { SqliteError { - extended_code: ffi::SQLITE_CONSTRAINT_UNIQUE | ffi::SQLITE_CONSTRAINT_PRIMARYKEY, + extended_code: super::ffi::SQLITE_CONSTRAINT_UNIQUE | super::ffi::SQLITE_CONSTRAINT_PRIMARYKEY, message: Some(description), } => { let constraint = description @@ -58,7 +50,7 @@ impl From for Error { } SqliteError { - extended_code: ffi::SQLITE_CONSTRAINT_NOTNULL, + extended_code: super::ffi::SQLITE_CONSTRAINT_NOTNULL, message: Some(description), } => { let constraint = description @@ -79,7 +71,7 @@ impl From for Error { } SqliteError { - extended_code: ffi::SQLITE_CONSTRAINT_FOREIGNKEY | ffi::SQLITE_CONSTRAINT_TRIGGER, + extended_code: super::ffi::SQLITE_CONSTRAINT_FOREIGNKEY | super::ffi::SQLITE_CONSTRAINT_TRIGGER, message: Some(description), } => { let mut builder = Error::builder(ErrorKind::ForeignKeyConstraintViolation { @@ -92,7 +84,7 @@ impl From for Error { builder.build() } - SqliteError { extended_code, message } if error.primary_code() == ffi::SQLITE_BUSY => { + SqliteError { extended_code, message } if error.primary_code() == super::ffi::SQLITE_BUSY => { let mut builder = Error::builder(ErrorKind::SocketTimeout); builder.set_original_code(format!("{extended_code}")); @@ -152,55 +144,3 @@ impl From for Error { } } } - -impl From for Error { - fn from(e: rusqlite::Error) -> Error { - match e { - rusqlite::Error::ToSqlConversionFailure(error) => match error.downcast::() { - Ok(error) => *error, - Err(error) => { - let mut builder = Error::builder(ErrorKind::QueryError(error)); - - builder.set_original_message("Could not interpret parameters in an SQLite query."); - - builder.build() - } - }, - rusqlite::Error::InvalidQuery => { - let mut builder = Error::builder(ErrorKind::QueryError(e.into())); - - builder.set_original_message( - "Could not interpret the query or its parameters. Check the syntax and parameter types.", - ); - - builder.build() - } - rusqlite::Error::ExecuteReturnedResults => { - let mut builder = Error::builder(ErrorKind::QueryError(e.into())); - builder.set_original_message("Execute returned results, which is not allowed in SQLite."); - - builder.build() - } - - rusqlite::Error::QueryReturnedNoRows => Error::builder(ErrorKind::NotFound).build(), - - rusqlite::Error::SqliteFailure(ffi::Error { code: _, extended_code }, message) => { - SqliteError::new(extended_code, message).into() - } - - rusqlite::Error::SqlInputError { - error: ffi::Error { extended_code, .. }, - msg, - .. - } => SqliteError::new(extended_code, Some(msg)).into(), - - e => Error::builder(ErrorKind::QueryError(e.into())).build(), - } - } -} - -impl From for Error { - fn from(e: FromSqlError) -> Error { - Error::builder(ErrorKind::ColumnReadFailure(e.into())).build() - } -} diff --git a/quaint/src/connector/sqlite/ffi.rs b/quaint/src/connector/sqlite/ffi.rs new file mode 100644 index 000000000000..c510a459be81 --- /dev/null +++ b/quaint/src/connector/sqlite/ffi.rs @@ -0,0 +1,8 @@ +//! Here, we export only the constants we need to avoid pulling in `rusqlite::ffi::*`, in the sibling `error.rs` file, +//! which would break Wasm compilation. +pub const SQLITE_BUSY: i32 = 5; +pub const SQLITE_CONSTRAINT_FOREIGNKEY: i32 = 787; +pub const SQLITE_CONSTRAINT_NOTNULL: i32 = 1299; +pub const SQLITE_CONSTRAINT_PRIMARYKEY: i32 = 1555; +pub const SQLITE_CONSTRAINT_TRIGGER: i32 = 1811; +pub const SQLITE_CONSTRAINT_UNIQUE: i32 = 2067; diff --git a/quaint/src/connector/sqlite/conversion.rs b/quaint/src/connector/sqlite/native/conversion.rs similarity index 100% rename from quaint/src/connector/sqlite/conversion.rs rename to quaint/src/connector/sqlite/native/conversion.rs diff --git a/quaint/src/connector/sqlite/native/error.rs b/quaint/src/connector/sqlite/native/error.rs new file mode 100644 index 000000000000..51b2417ed821 --- /dev/null +++ b/quaint/src/connector/sqlite/native/error.rs @@ -0,0 +1,66 @@ +use crate::connector::sqlite::error::SqliteError; + +use crate::error::*; + +impl std::fmt::Display for SqliteError { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!( + f, + "Error code {}: {}", + self.extended_code, + rusqlite::ffi::code_to_str(self.extended_code) + ) + } +} + +impl From for Error { + fn from(e: rusqlite::Error) -> Error { + match e { + rusqlite::Error::ToSqlConversionFailure(error) => match error.downcast::() { + Ok(error) => *error, + Err(error) => { + let mut builder = Error::builder(ErrorKind::QueryError(error)); + + builder.set_original_message("Could not interpret parameters in an SQLite query."); + + builder.build() + } + }, + rusqlite::Error::InvalidQuery => { + let mut builder = Error::builder(ErrorKind::QueryError(e.into())); + + builder.set_original_message( + "Could not interpret the query or its parameters. Check the syntax and parameter types.", + ); + + builder.build() + } + rusqlite::Error::ExecuteReturnedResults => { + let mut builder = Error::builder(ErrorKind::QueryError(e.into())); + builder.set_original_message("Execute returned results, which is not allowed in SQLite."); + + builder.build() + } + + rusqlite::Error::QueryReturnedNoRows => Error::builder(ErrorKind::NotFound).build(), + + rusqlite::Error::SqliteFailure(rusqlite::ffi::Error { code: _, extended_code }, message) => { + SqliteError::new(extended_code, message).into() + } + + rusqlite::Error::SqlInputError { + error: rusqlite::ffi::Error { extended_code, .. }, + msg, + .. + } => SqliteError::new(extended_code, Some(msg)).into(), + + e => Error::builder(ErrorKind::QueryError(e.into())).build(), + } + } +} + +impl From for Error { + fn from(e: rusqlite::types::FromSqlError) -> Error { + Error::builder(ErrorKind::ColumnReadFailure(e.into())).build() + } +} diff --git a/quaint/src/connector/sqlite/native/mod.rs b/quaint/src/connector/sqlite/native/mod.rs new file mode 100644 index 000000000000..3bf0c46a7db5 --- /dev/null +++ b/quaint/src/connector/sqlite/native/mod.rs @@ -0,0 +1,234 @@ +//! Definitions for the SQLite connector. +//! This module is not compatible with wasm32-* targets. +//! This module is only available with the `sqlite-native` feature. +mod conversion; +mod error; + +use crate::connector::sqlite::params::SqliteParams; +use crate::connector::IsolationLevel; + +pub use rusqlite::{params_from_iter, version as sqlite_version}; + +use crate::{ + ast::{Query, Value}, + connector::{metrics, queryable::*, ResultSet}, + error::{Error, ErrorKind}, + visitor::{self, Visitor}, +}; +use async_trait::async_trait; +use std::convert::TryFrom; +use tokio::sync::Mutex; + +/// The underlying sqlite driver. Only available with the `expose-drivers` Cargo feature. +#[cfg(feature = "expose-drivers")] +pub use rusqlite; + +/// A connector interface for the SQLite database +pub struct Sqlite { + pub(crate) client: Mutex, +} + +impl TryFrom<&str> for Sqlite { + type Error = Error; + + fn try_from(path: &str) -> crate::Result { + let params = SqliteParams::try_from(path)?; + let file_path = params.file_path; + + let conn = rusqlite::Connection::open(file_path.as_str())?; + + if let Some(timeout) = params.socket_timeout { + conn.busy_timeout(timeout)?; + }; + + let client = Mutex::new(conn); + + Ok(Sqlite { client }) + } +} + +impl Sqlite { + pub fn new(file_path: &str) -> crate::Result { + Self::try_from(file_path) + } + + /// Open a new SQLite database in memory. + pub fn new_in_memory() -> crate::Result { + let client = rusqlite::Connection::open_in_memory()?; + + Ok(Sqlite { + client: Mutex::new(client), + }) + } + + /// The underlying rusqlite::Connection. Only available with the `expose-drivers` Cargo + /// feature. This is a lower level API when you need to get into database specific features. + #[cfg(feature = "expose-drivers")] + pub fn connection(&self) -> &Mutex { + &self.client + } +} + +impl_default_TransactionCapable!(Sqlite); + +#[async_trait] +impl Queryable for Sqlite { + async fn query(&self, q: Query<'_>) -> crate::Result { + let (sql, params) = visitor::Sqlite::build(q)?; + self.query_raw(&sql, ¶ms).await + } + + async fn query_raw(&self, sql: &str, params: &[Value<'_>]) -> crate::Result { + metrics::query("sqlite.query_raw", sql, params, move || async move { + let client = self.client.lock().await; + + let mut stmt = client.prepare_cached(sql)?; + + let mut rows = stmt.query(params_from_iter(params.iter()))?; + let mut result = ResultSet::new(rows.to_column_names(), Vec::new()); + + while let Some(row) = rows.next()? { + result.rows.push(row.get_result_row()?); + } + + result.set_last_insert_id(u64::try_from(client.last_insert_rowid()).unwrap_or(0)); + + Ok(result) + }) + .await + } + + async fn query_raw_typed(&self, sql: &str, params: &[Value<'_>]) -> crate::Result { + self.query_raw(sql, params).await + } + + async fn execute(&self, q: Query<'_>) -> crate::Result { + let (sql, params) = visitor::Sqlite::build(q)?; + self.execute_raw(&sql, ¶ms).await + } + + async fn execute_raw(&self, sql: &str, params: &[Value<'_>]) -> crate::Result { + metrics::query("sqlite.query_raw", sql, params, move || async move { + let client = self.client.lock().await; + let mut stmt = client.prepare_cached(sql)?; + let res = u64::try_from(stmt.execute(params_from_iter(params.iter()))?)?; + + Ok(res) + }) + .await + } + + async fn execute_raw_typed(&self, sql: &str, params: &[Value<'_>]) -> crate::Result { + self.execute_raw(sql, params).await + } + + async fn raw_cmd(&self, cmd: &str) -> crate::Result<()> { + metrics::query("sqlite.raw_cmd", cmd, &[], move || async move { + let client = self.client.lock().await; + client.execute_batch(cmd)?; + Ok(()) + }) + .await + } + + async fn version(&self) -> crate::Result> { + Ok(Some(rusqlite::version().into())) + } + + fn is_healthy(&self) -> bool { + true + } + + async fn set_tx_isolation_level(&self, isolation_level: IsolationLevel) -> crate::Result<()> { + // SQLite is always "serializable", other modes involve pragmas + // and shared cache mode, which is out of scope for now and should be implemented + // as part of a separate effort. + if !matches!(isolation_level, IsolationLevel::Serializable) { + let kind = ErrorKind::invalid_isolation_level(&isolation_level); + return Err(Error::builder(kind).build()); + } + + Ok(()) + } + + fn requires_isolation_first(&self) -> bool { + false + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::{ + ast::*, + connector::Queryable, + error::{ErrorKind, Name}, + }; + + #[tokio::test] + async fn unknown_table_should_give_a_good_error() { + let conn = Sqlite::try_from("file:db/test.db").unwrap(); + let select = Select::from_table("not_there"); + + let err = conn.select(select).await.unwrap_err(); + + match err.kind() { + ErrorKind::TableDoesNotExist { table } => { + assert_eq!(&Name::available("not_there"), table); + } + e => panic!("Expected error TableDoesNotExist, got {:?}", e), + } + } + + #[tokio::test] + async fn in_memory_sqlite_works() { + let conn = Sqlite::new_in_memory().unwrap(); + + conn.raw_cmd("CREATE TABLE test (id INTEGER PRIMARY KEY, txt TEXT NOT NULL);") + .await + .unwrap(); + + let insert = Insert::single_into("test").value("txt", "henlo"); + conn.insert(insert.into()).await.unwrap(); + + let select = Select::from_table("test").value(asterisk()); + let result = conn.select(select.clone()).await.unwrap(); + let result = result.into_single().unwrap(); + + assert_eq!(result.get("id").unwrap(), &Value::int32(1)); + assert_eq!(result.get("txt").unwrap(), &Value::text("henlo")); + + // Check that we do get a separate, new database. + let other_conn = Sqlite::new_in_memory().unwrap(); + + let err = other_conn.select(select).await.unwrap_err(); + assert!(matches!(err.kind(), ErrorKind::TableDoesNotExist { .. })); + } + + #[tokio::test] + async fn quoting_in_returning_in_sqlite_works() { + let conn = Sqlite::new_in_memory().unwrap(); + + conn.raw_cmd("CREATE TABLE test (id INTEGER PRIMARY KEY, `txt space` TEXT NOT NULL);") + .await + .unwrap(); + + let insert = Insert::single_into("test").value("txt space", "henlo"); + conn.insert(insert.into()).await.unwrap(); + + let select = Select::from_table("test").value(asterisk()); + let result = conn.select(select.clone()).await.unwrap(); + let result = result.into_single().unwrap(); + + assert_eq!(result.get("id").unwrap(), &Value::int32(1)); + assert_eq!(result.get("txt space").unwrap(), &Value::text("henlo")); + + let insert = Insert::single_into("test").value("txt space", "henlo"); + let insert: Insert = Insert::from(insert).returning(["txt space"]); + + let result = conn.insert(insert).await.unwrap(); + let result = result.into_single().unwrap(); + + assert_eq!(result.get("txt space").unwrap(), &Value::text("henlo")); + } +} diff --git a/quaint/src/connector/sqlite/params.rs b/quaint/src/connector/sqlite/params.rs new file mode 100644 index 000000000000..f024aa97a694 --- /dev/null +++ b/quaint/src/connector/sqlite/params.rs @@ -0,0 +1,131 @@ +#![cfg_attr(target_arch = "wasm32", allow(dead_code))] + +use crate::error::{Error, ErrorKind}; +use std::{convert::TryFrom, path::Path, time::Duration}; + +pub(crate) const DEFAULT_SQLITE_SCHEMA_NAME: &str = "main"; + +/// Wraps a connection url and exposes the parsing logic used by Quaint, +/// including default values. +#[derive(Debug)] +pub struct SqliteParams { + pub connection_limit: Option, + /// This is not a `PathBuf` because we need to `ATTACH` the database to the path, and this can + /// only be done with UTF-8 paths. + pub file_path: String, + pub db_name: String, + pub socket_timeout: Option, + pub max_connection_lifetime: Option, + pub max_idle_connection_lifetime: Option, +} + +impl TryFrom<&str> for SqliteParams { + type Error = Error; + + fn try_from(path: &str) -> crate::Result { + let path = if path.starts_with("file:") { + path.trim_start_matches("file:") + } else { + path.trim_start_matches("sqlite:") + }; + + let path_parts: Vec<&str> = path.split('?').collect(); + let path_str = path_parts[0]; + let path = Path::new(path_str); + + if path.is_dir() { + Err(Error::builder(ErrorKind::DatabaseUrlIsInvalid(path.to_str().unwrap().to_string())).build()) + } else { + let mut connection_limit = None; + let mut socket_timeout = None; + let mut max_connection_lifetime = None; + let mut max_idle_connection_lifetime = None; + + if path_parts.len() > 1 { + let params = path_parts.last().unwrap().split('&').map(|kv| { + let splitted: Vec<&str> = kv.split('=').collect(); + (splitted[0], splitted[1]) + }); + + for (k, v) in params { + match k { + "connection_limit" => { + let as_int: usize = v + .parse() + .map_err(|_| Error::builder(ErrorKind::InvalidConnectionArguments).build())?; + + connection_limit = Some(as_int); + } + "socket_timeout" => { + let as_int = v + .parse() + .map_err(|_| Error::builder(ErrorKind::InvalidConnectionArguments).build())?; + + socket_timeout = Some(Duration::from_secs(as_int)); + } + "max_connection_lifetime" => { + let as_int = v + .parse() + .map_err(|_| Error::builder(ErrorKind::InvalidConnectionArguments).build())?; + + if as_int == 0 { + max_connection_lifetime = None; + } else { + max_connection_lifetime = Some(Duration::from_secs(as_int)); + } + } + "max_idle_connection_lifetime" => { + let as_int = v + .parse() + .map_err(|_| Error::builder(ErrorKind::InvalidConnectionArguments).build())?; + + if as_int == 0 { + max_idle_connection_lifetime = None; + } else { + max_idle_connection_lifetime = Some(Duration::from_secs(as_int)); + } + } + _ => { + tracing::trace!(message = "Discarding connection string param", param = k); + } + }; + } + } + + Ok(Self { + connection_limit, + file_path: path_str.to_owned(), + db_name: DEFAULT_SQLITE_SCHEMA_NAME.to_owned(), + socket_timeout, + max_connection_lifetime, + max_idle_connection_lifetime, + }) + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn sqlite_params_from_str_should_resolve_path_correctly_with_file_scheme() { + let path = "file:dev.db"; + let params = SqliteParams::try_from(path).unwrap(); + assert_eq!(params.file_path, "dev.db"); + } + + #[test] + fn sqlite_params_from_str_should_resolve_path_correctly_with_sqlite_scheme() { + let path = "sqlite:dev.db"; + let params = SqliteParams::try_from(path).unwrap(); + assert_eq!(params.file_path, "dev.db"); + } + + #[test] + fn sqlite_params_from_str_should_resolve_path_correctly_with_no_scheme() { + let path = "dev.db"; + let params = SqliteParams::try_from(path).unwrap(); + assert_eq!(params.file_path, "dev.db"); + } +} diff --git a/quaint/src/error.rs b/quaint/src/error.rs index 705bb6b37ee0..a77513876726 100644 --- a/quaint/src/error.rs +++ b/quaint/src/error.rs @@ -282,7 +282,7 @@ pub enum ErrorKind { } impl ErrorKind { - #[cfg(feature = "mysql")] + #[cfg(feature = "mysql-native")] pub(crate) fn value_out_of_range(msg: impl Into) -> Self { Self::ValueOutOfRange { message: msg.into() } } diff --git a/quaint/src/pooled/manager.rs b/quaint/src/pooled/manager.rs index c0aa8c93b75d..73441b7609ba 100644 --- a/quaint/src/pooled/manager.rs +++ b/quaint/src/pooled/manager.rs @@ -1,8 +1,8 @@ -#[cfg(feature = "mssql")] +#[cfg(feature = "mssql-native")] use crate::connector::MssqlUrl; -#[cfg(feature = "mysql")] +#[cfg(feature = "mysql-native")] use crate::connector::MysqlUrl; -#[cfg(feature = "postgresql")] +#[cfg(feature = "postgresql-native")] use crate::connector::PostgresUrl; use crate::{ ast, @@ -97,7 +97,7 @@ impl Manager for QuaintManager { async fn connect(&self) -> crate::Result { let conn = match self { - #[cfg(feature = "sqlite")] + #[cfg(feature = "sqlite-native")] QuaintManager::Sqlite { url, .. } => { use crate::connector::Sqlite; @@ -106,19 +106,19 @@ impl Manager for QuaintManager { Ok(Box::new(conn) as Self::Connection) } - #[cfg(feature = "mysql")] + #[cfg(feature = "mysql-native")] QuaintManager::Mysql { url } => { use crate::connector::Mysql; Ok(Box::new(Mysql::new(url.clone()).await?) as Self::Connection) } - #[cfg(feature = "postgresql")] + #[cfg(feature = "postgresql-native")] QuaintManager::Postgres { url } => { use crate::connector::PostgreSql; Ok(Box::new(PostgreSql::new(url.clone()).await?) as Self::Connection) } - #[cfg(feature = "mssql")] + #[cfg(feature = "mssql-native")] QuaintManager::Mssql { url } => { use crate::connector::Mssql; Ok(Box::new(Mssql::new(url.clone()).await?) as Self::Connection) @@ -146,7 +146,7 @@ mod tests { use crate::pooled::Quaint; #[tokio::test] - #[cfg(feature = "mysql")] + #[cfg(feature = "mysql-native")] async fn mysql_default_connection_limit() { let conn_string = std::env::var("TEST_MYSQL").expect("TEST_MYSQL connection string not set."); @@ -156,7 +156,7 @@ mod tests { } #[tokio::test] - #[cfg(feature = "mysql")] + #[cfg(feature = "mysql-native")] async fn mysql_custom_connection_limit() { let conn_string = format!( "{}?connection_limit=10", @@ -169,7 +169,7 @@ mod tests { } #[tokio::test] - #[cfg(feature = "postgresql")] + #[cfg(feature = "postgresql-native")] async fn psql_default_connection_limit() { let conn_string = std::env::var("TEST_PSQL").expect("TEST_PSQL connection string not set."); @@ -179,7 +179,7 @@ mod tests { } #[tokio::test] - #[cfg(feature = "postgresql")] + #[cfg(feature = "postgresql-native")] async fn psql_custom_connection_limit() { let conn_string = format!( "{}?connection_limit=10", @@ -192,7 +192,7 @@ mod tests { } #[tokio::test] - #[cfg(feature = "mssql")] + #[cfg(feature = "mssql-native")] async fn mssql_default_connection_limit() { let conn_string = std::env::var("TEST_MSSQL").expect("TEST_MSSQL connection string not set."); @@ -202,7 +202,7 @@ mod tests { } #[tokio::test] - #[cfg(feature = "mssql")] + #[cfg(feature = "mssql-native")] async fn mssql_custom_connection_limit() { let conn_string = format!( "{};connectionLimit=10", @@ -215,7 +215,7 @@ mod tests { } #[tokio::test] - #[cfg(feature = "sqlite")] + #[cfg(feature = "sqlite-native")] async fn test_default_connection_limit() { let conn_string = "file:db/test.db".to_string(); let pool = Quaint::builder(&conn_string).unwrap().build(); @@ -224,7 +224,7 @@ mod tests { } #[tokio::test] - #[cfg(feature = "sqlite")] + #[cfg(feature = "sqlite-native")] async fn test_custom_connection_limit() { let conn_string = "file:db/test.db?connection_limit=10".to_string(); let pool = Quaint::builder(&conn_string).unwrap().build(); diff --git a/quaint/src/single.rs b/quaint/src/single.rs index 82042f58010b..1a4dbdf52a61 100644 --- a/quaint/src/single.rs +++ b/quaint/src/single.rs @@ -1,7 +1,5 @@ //! A single connection abstraction to a SQL database. -#[cfg(feature = "sqlite")] -use crate::connector::DEFAULT_SQLITE_SCHEMA_NAME; use crate::{ ast, connector::{self, impl_default_TransactionCapable, ConnectionInfo, IsolationLevel, Queryable, TransactionCapable}, @@ -9,7 +7,7 @@ use crate::{ use async_trait::async_trait; use std::{fmt, sync::Arc}; -#[cfg(feature = "sqlite")] +#[cfg(feature = "sqlite-native")] use std::convert::TryFrom; /// The main entry point and an abstraction over a database connection. @@ -127,30 +125,31 @@ impl Quaint { /// - `isolationLevel` the transaction isolation level. Possible values: /// `READ UNCOMMITTED`, `READ COMMITTED`, `REPEATABLE READ`, `SNAPSHOT`, /// `SERIALIZABLE`. + #[cfg_attr(target_arch = "wasm32", allow(unused_variables))] #[allow(unreachable_code)] pub async fn new(url_str: &str) -> crate::Result { let inner = match url_str { - #[cfg(feature = "sqlite")] + #[cfg(feature = "sqlite-native")] s if s.starts_with("file") => { let params = connector::SqliteParams::try_from(s)?; let sqlite = connector::Sqlite::new(¶ms.file_path)?; Arc::new(sqlite) as Arc } - #[cfg(feature = "mysql")] + #[cfg(feature = "mysql-native")] s if s.starts_with("mysql") => { let url = connector::MysqlUrl::new(url::Url::parse(s)?)?; let mysql = connector::Mysql::new(url).await?; Arc::new(mysql) as Arc } - #[cfg(feature = "postgresql")] + #[cfg(feature = "postgresql-native")] s if s.starts_with("postgres") || s.starts_with("postgresql") => { let url = connector::PostgresUrl::new(url::Url::parse(s)?)?; let psql = connector::PostgreSql::new(url).await?; Arc::new(psql) as Arc } - #[cfg(feature = "mssql")] + #[cfg(feature = "mssql-native")] s if s.starts_with("jdbc:sqlserver") | s.starts_with("sqlserver") => { let url = connector::MssqlUrl::new(s)?; let psql = connector::Mssql::new(url).await?; @@ -166,9 +165,11 @@ impl Quaint { Ok(Self { inner, connection_info }) } - #[cfg(feature = "sqlite")] + #[cfg(feature = "sqlite-native")] /// Open a new SQLite database in memory. pub fn new_in_memory() -> crate::Result { + use crate::connector::DEFAULT_SQLITE_SCHEMA_NAME; + Ok(Quaint { inner: Arc::new(connector::Sqlite::new_in_memory()?), connection_info: Arc::new(ConnectionInfo::InMemorySqlite { diff --git a/quaint/src/visitor/postgres.rs b/quaint/src/visitor/postgres.rs index 35921637c051..648b3f0dc1ec 100644 --- a/quaint/src/visitor/postgres.rs +++ b/quaint/src/visitor/postgres.rs @@ -78,33 +78,27 @@ impl<'a> Visitor<'a> for Postgres<'a> { variants: Vec>, name: Option>, ) -> visitor::Result { - let len = variants.len(); - // Since enums are user-defined custom types, tokio-postgres fires an additional query // when parameterizing values of type enum to know which custom type the value refers to. // Casting the enum value to `TEXT` avoid this roundtrip since `TEXT` is a builtin type. if let Some(enum_name) = name.clone() { - self.surround_with("ARRAY[", "]", |s| { - for (i, variant) in variants.into_iter().enumerate() { - s.add_parameter(variant.into_text()); - s.parameter_substitution()?; - s.write("::text")?; - - if i < (len - 1) { - s.write(", ")?; - } + self.add_parameter(Value::array(variants.into_iter().map(|v| v.into_text()))); + + self.surround_with("CAST(", ")", |s| { + s.parameter_substitution()?; + s.write("::text[]")?; + s.write(" AS ")?; + + if let Some(schema_name) = enum_name.schema_name { + s.surround_with_backticks(schema_name.deref())?; + s.write(".")? } + s.surround_with_backticks(enum_name.name.deref())?; + s.write("[]")?; + Ok(()) })?; - - self.write("::")?; - if let Some(schema_name) = enum_name.schema_name { - self.surround_with_backticks(schema_name.deref())?; - self.write(".")? - } - self.surround_with_backticks(enum_name.name.deref())?; - self.write("[]")?; } else { self.visit_parameterized(Value::array( variants.into_iter().map(|variant| variant.into_enum(name.clone())), diff --git a/query-engine/black-box-tests/tests/metrics/smoke_tests.rs b/query-engine/black-box-tests/tests/metrics/smoke_tests.rs index 5ff7ec8ad9ba..69207f3fff5d 100644 --- a/query-engine/black-box-tests/tests/metrics/smoke_tests.rs +++ b/query-engine/black-box-tests/tests/metrics/smoke_tests.rs @@ -17,7 +17,7 @@ mod smoke_tests { fn assert_value_in_range(metrics: &str, metric: &str, low: f64, high: f64) { let regex = Regex::new(format!(r"{metric}\s+([+-]?\d+(\.\d+)?)").as_str()).unwrap(); - match regex.captures(&metrics) { + match regex.captures(metrics) { Some(capture) => { let value = capture.get(1).unwrap().as_str().parse::().unwrap(); assert!( diff --git a/query-engine/connector-test-kit-rs/README.md b/query-engine/connector-test-kit-rs/README.md index 97d19467879a..993f636e0d28 100644 --- a/query-engine/connector-test-kit-rs/README.md +++ b/query-engine/connector-test-kit-rs/README.md @@ -82,15 +82,16 @@ drivers the code that actually communicates with the databases. See [`adapter-*` To run tests through a driver adapters, you should also configure the following environment variables: -* `EXTERNAL_TEST_EXECUTOR`: tells the query engine test kit to use an external process to run the queries, this is a node process running a program that will read the queries to run from STDIN, and return responses to STDOUT. The connector kit follows a protocol over JSON RPC for this communication. * `DRIVER_ADAPTER`: tells the test executor to use a particular driver adapter. Set to `neon`, `planetscale` or any other supported adapter. * `DRIVER_ADAPTER_CONFIG`: a json string with the configuration for the driver adapter. This is adapter specific. See the [github workflow for driver adapter tests](.github/workflows/query-engine-driver-adapters.yml) for examples on how to configure the driver adapters. +* `ENGINE`: can be used to run either `wasm` or `napi` version of the engine. Example: ```shell export EXTERNAL_TEST_EXECUTOR="$WORKSPACE_ROOT/query-engine/driver-adapters/connector-test-kit-executor/script/start_node.sh" export DRIVER_ADAPTER=neon +export ENGINE=wasm export DRIVER_ADAPTER_CONFIG ='{ "proxyUrl": "127.0.0.1:5488/v1" }' ```` @@ -98,7 +99,7 @@ We have provided helpers to run the query-engine tests with driver adapters, the variables for you: ```shell -DRIVER_ADAPTER=$adapter make test-qe +DRIVER_ADAPTER=$adapter ENGINE=$engine make test-qe ``` Where `$adapter` is one of the supported adapters: `neon`, `planetscale`, `libsql`. diff --git a/query-engine/connector-test-kit-rs/query-engine-tests/tests/new/regressions/mod.rs b/query-engine/connector-test-kit-rs/query-engine-tests/tests/new/regressions/mod.rs index 8a2cbc7f24a2..0714015efd06 100644 --- a/query-engine/connector-test-kit-rs/query-engine-tests/tests/new/regressions/mod.rs +++ b/query-engine/connector-test-kit-rs/query-engine-tests/tests/new/regressions/mod.rs @@ -20,6 +20,7 @@ mod prisma_17103; mod prisma_18517; mod prisma_20799; mod prisma_21369; +mod prisma_21901; mod prisma_5952; mod prisma_6173; mod prisma_7010; diff --git a/query-engine/connector-test-kit-rs/query-engine-tests/tests/new/regressions/prisma_15204.rs b/query-engine/connector-test-kit-rs/query-engine-tests/tests/new/regressions/prisma_15204.rs index ccf04dd2f4af..179011108cb7 100644 --- a/query-engine/connector-test-kit-rs/query-engine-tests/tests/new/regressions/prisma_15204.rs +++ b/query-engine/connector-test-kit-rs/query-engine-tests/tests/new/regressions/prisma_15204.rs @@ -46,7 +46,7 @@ mod conversion_error { runner, r#"query { findManyTestModel { field } }"#, 2023, - "Inconsistent column data: Conversion failed: number must be an integer in column 'field'" + "Inconsistent column data: Conversion failed: number must be an integer in column 'field', got '1.84467440724388e19'" ); Ok(()) @@ -74,7 +74,7 @@ mod conversion_error { runner, r#"query { findManyTestModel { field } }"#, 2023, - "Inconsistent column data: Conversion failed: number must be an i64 in column 'field'" + "Inconsistent column data: Conversion failed: number must be an integer in column 'field', got '1.84467440724388e19'" ); Ok(()) diff --git a/query-engine/connector-test-kit-rs/query-engine-tests/tests/new/regressions/prisma_21901.rs b/query-engine/connector-test-kit-rs/query-engine-tests/tests/new/regressions/prisma_21901.rs new file mode 100644 index 000000000000..5b9dd4f46dcc --- /dev/null +++ b/query-engine/connector-test-kit-rs/query-engine-tests/tests/new/regressions/prisma_21901.rs @@ -0,0 +1,50 @@ +use indoc::indoc; +use query_engine_tests::*; + +#[test_suite(schema(schema), capabilities(Enums, ScalarLists), exclude(MongoDb))] +mod prisma_21901 { + fn schema() -> String { + let schema = indoc! { + r#"model Test { + #id(id, Int, @id) + colors Color[] + } + + enum Color { + red + blue + green + } + "# + }; + + schema.to_owned() + } + + // fixes https://github.com/prisma/prisma/issues/21901 + #[connector_test] + async fn test(runner: Runner) -> TestResult<()> { + insta::assert_snapshot!( + run_query!( + runner, + r#"mutation { createOneTest(data: { id: 1, colors: ["red"] }) { colors } }"# + ), + @r###"{"data":{"createOneTest":{"colors":["red"]}}}"### + ); + + insta::assert_snapshot!( + run_query!(runner, fmt_execute_raw(r#"TRUNCATE TABLE "prisma_21901_test"."Test" CASCADE;"#, [])), + @r###"{"data":{"executeRaw":0}}"### + ); + + insta::assert_snapshot!( + run_query!( + runner, + r#"mutation { createOneTest(data: { id: 2, colors: ["blue"] }) { colors } }"# + ), + @r###"{"data":{"createOneTest":{"colors":["blue"]}}}"### + ); + + Ok(()) + } +} diff --git a/query-engine/connector-test-kit-rs/query-tests-setup/Cargo.toml b/query-engine/connector-test-kit-rs/query-tests-setup/Cargo.toml index 088a0d4b2d34..f257d9e52162 100644 --- a/query-engine/connector-test-kit-rs/query-tests-setup/Cargo.toml +++ b/query-engine/connector-test-kit-rs/query-tests-setup/Cargo.toml @@ -10,7 +10,7 @@ once_cell = "1" qe-setup = { path = "../qe-setup" } request-handlers = { path = "../../request-handlers" } tokio.workspace = true -query-core = { path = "../../core" } +query-core = { path = "../../core", features = ["metrics"] } sql-query-connector = { path = "../../connectors/sql-query-connector" } query-engine = { path = "../../query-engine"} psl.workspace = true diff --git a/query-engine/connector-test-kit-rs/query-tests-setup/src/config.rs b/query-engine/connector-test-kit-rs/query-tests-setup/src/config.rs index 4af4e763298a..07ceca784ff9 100644 --- a/query-engine/connector-test-kit-rs/query-tests-setup/src/config.rs +++ b/query-engine/connector-test-kit-rs/query-tests-setup/src/config.rs @@ -3,10 +3,25 @@ use crate::{ PostgresConnectorTag, SqlServerConnectorTag, SqliteConnectorTag, TestResult, VitessConnectorTag, }; use serde::Deserialize; -use std::{convert::TryFrom, env, fs::File, io::Read, path::PathBuf}; +use std::{convert::TryFrom, env, fmt::Display, fs::File, io::Read, path::PathBuf}; static TEST_CONFIG_FILE_NAME: &str = ".test_config"; +#[derive(Debug, Deserialize, Clone)] +pub enum TestExecutor { + Napi, + Wasm, +} + +impl Display for TestExecutor { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + TestExecutor::Napi => f.write_str("Napi"), + TestExecutor::Wasm => f.write_str("Wasm"), + } + } +} + /// The central test configuration. #[derive(Debug, Default, Deserialize)] pub struct TestConfig { @@ -24,8 +39,9 @@ pub struct TestConfig { /// Used when testing driver adapters, this process is expected to be a javascript process /// loading the library engine (as a library, or WASM modules) and providing it with a /// driver adapter. + /// Possible values: Napi, Wasm /// Env key: `EXTERNAL_TEST_EXECUTOR` - external_test_executor: Option, + external_test_executor: Option, /// The driver adapter to use when running tests, will be forwarded to the external test /// executor by setting the `DRIVER_ADAPTER` env var when spawning the executor process @@ -85,12 +101,11 @@ fn exit_with_message(msg: &str) -> ! { impl TestConfig { /// Loads a configuration. File-based config has precedence over env config. pub(crate) fn load() -> Self { - let mut config = match Self::from_file().or_else(Self::from_env) { + let config = match Self::from_file().or_else(Self::from_env) { Some(config) => config, None => exit_with_message(CONFIG_LOAD_FAILED), }; - config.fill_defaults(); config.validate(); config.log_info(); @@ -107,8 +122,8 @@ impl TestConfig { self.connector_version().unwrap_or_default() ); println!("* CI? {}", self.is_ci); - if self.external_test_executor.as_ref().is_some() { - println!("* External test executor: {}", self.external_test_executor().unwrap_or_default()); + if let Some(external_test_executor) = self.external_test_executor.as_ref() { + println!("* External test executor: {}", external_test_executor); println!("* Driver adapter: {}", self.driver_adapter().unwrap_or_default()); println!("* Driver adapter url override: {}", self.json_stringify_driver_adapter_config()); } @@ -118,7 +133,10 @@ impl TestConfig { fn from_env() -> Option { let connector = std::env::var("TEST_CONNECTOR").ok(); let connector_version = std::env::var("TEST_CONNECTOR_VERSION").ok(); - let external_test_executor = std::env::var("EXTERNAL_TEST_EXECUTOR").ok(); + let external_test_executor = std::env::var("EXTERNAL_TEST_EXECUTOR") + .map(|value| serde_json::from_str::(&value).ok()) + .unwrap_or_default(); + let driver_adapter = std::env::var("DRIVER_ADAPTER").ok(); let driver_adapter_config = std::env::var("DRIVER_ADAPTER_CONFIG") .map(|config| serde_json::from_str::(config.as_str()).ok()) @@ -155,31 +173,24 @@ impl TestConfig { }) } - /// if the loaded value for external_test_executor is "default" (case insensitive), - /// and the workspace_root is set, then use the default external test executor. - fn fill_defaults(&mut self) { + fn workspace_root() -> Option { + env::var("WORKSPACE_ROOT").ok().map(PathBuf::from) + } + + pub fn external_test_executor_path(&self) -> Option { const DEFAULT_TEST_EXECUTOR: &str = "query-engine/driver-adapters/connector-test-kit-executor/script/start_node.sh"; - - if self - .external_test_executor + self.external_test_executor .as_ref() - .filter(|s| s.eq_ignore_ascii_case("default")) - .is_some() - { - self.external_test_executor = Self::workspace_root() - .map(|path| path.join(DEFAULT_TEST_EXECUTOR)) - .or_else(|| { + .and_then(|_| { + Self::workspace_root().or_else(|| { exit_with_message( "WORKSPACE_ROOT needs to be correctly set to the root of the prisma-engines repository", ) }) - .and_then(|path| path.to_str().map(|s| s.to_owned())); - } - } - - fn workspace_root() -> Option { - env::var("WORKSPACE_ROOT").ok().map(PathBuf::from) + }) + .map(|path| path.join(DEFAULT_TEST_EXECUTOR)) + .and_then(|path| path.to_str().map(|s| s.to_owned())) } fn validate(&self) { @@ -206,7 +217,7 @@ impl TestConfig { Err(err) => exit_with_message(&err.to_string()), } - if let Some(file) = self.external_test_executor.as_ref() { + if let Some(file) = self.external_test_executor_path().as_ref() { let path = PathBuf::from(file); let md = path.metadata(); if !path.exists() || md.is_err() || !md.unwrap().is_file() { @@ -259,8 +270,8 @@ impl TestConfig { self.is_ci } - pub fn external_test_executor(&self) -> Option<&str> { - self.external_test_executor.as_deref() + pub fn external_test_executor(&self) -> Option { + self.external_test_executor.clone() } pub fn driver_adapter(&self) -> Option<&str> { @@ -294,11 +305,16 @@ impl TestConfig { vec!( ( "DRIVER_ADAPTER".to_string(), - self.driver_adapter.clone().unwrap_or_default()), + self.driver_adapter.clone().unwrap_or_default() + ), ( "DRIVER_ADAPTER_CONFIG".to_string(), self.json_stringify_driver_adapter_config() ), + ( + "EXTERNAL_TEST_EXECUTOR".to_string(), + self.external_test_executor.clone().unwrap_or(TestExecutor::Napi).to_string(), + ), ( "PRISMA_DISABLE_QUAINT_EXECUTORS".to_string(), "1".to_string(), diff --git a/query-engine/connector-test-kit-rs/query-tests-setup/src/connector_tag/js/external_process.rs b/query-engine/connector-test-kit-rs/query-tests-setup/src/connector_tag/js/external_process.rs index 583d5058c62e..1abfedbaf8ee 100644 --- a/query-engine/connector-test-kit-rs/query-tests-setup/src/connector_tag/js/external_process.rs +++ b/query-engine/connector-test-kit-rs/query-tests-setup/src/connector_tag/js/external_process.rs @@ -74,7 +74,7 @@ impl ExecutorProcess { }; self.task_handle.send((method_call, sender)).await?; - let raw_response = receiver.await?; + let raw_response = receiver.await??; tracing::debug!(%raw_response); let response = serde_json::from_value(raw_response)?; Ok(response) @@ -91,14 +91,17 @@ pub(super) static EXTERNAL_PROCESS: Lazy = } }); -type ReqImpl = (jsonrpc_core::MethodCall, oneshot::Sender); +type ReqImpl = ( + jsonrpc_core::MethodCall, + oneshot::Sender>, +); fn start_rpc_thread(mut receiver: mpsc::Receiver) -> Result<()> { use std::process::Stdio; use tokio::process::Command; let path = crate::CONFIG - .external_test_executor() + .external_test_executor_path() .unwrap_or_else(|| exit_with_message(1, "start_rpc_thread() error: external test executor is not set")); tokio::runtime::Builder::new_current_thread() @@ -106,7 +109,7 @@ fn start_rpc_thread(mut receiver: mpsc::Receiver) -> Result<()> { .build() .unwrap() .block_on(async move { - let process = match Command::new(path) + let process = match Command::new(&path) .envs(CONFIG.for_external_executor()) .stdin(Stdio::piped()) .stdout(Stdio::piped()) @@ -119,7 +122,7 @@ fn start_rpc_thread(mut receiver: mpsc::Receiver) -> Result<()> { let mut stdout = BufReader::new(process.stdout.unwrap()).lines(); let mut stdin = process.stdin.unwrap(); - let mut pending_requests: HashMap> = + let mut pending_requests: HashMap>> = HashMap::new(); loop { @@ -140,10 +143,11 @@ fn start_rpc_thread(mut receiver: mpsc::Receiver) -> Result<()> { // The other end may be dropped if the whole // request future was dropped and not polled to // completion, so we ignore send errors here. - _ = sender.send(success.result); + _ = sender.send(Ok(success.result)); } jsonrpc_core::Output::Failure(err) => { - panic!("error response from jsonrpc: {err:?}") + tracing::error!("error response from jsonrpc: {err:?}"); + _ = sender.send(Err(Box::new(err.error))); } } } diff --git a/query-engine/connector-test-kit-rs/query-tests-setup/src/connector_tag/mod.rs b/query-engine/connector-test-kit-rs/query-tests-setup/src/connector_tag/mod.rs index 8c21dd93f903..ecc055d5d8d2 100644 --- a/query-engine/connector-test-kit-rs/query-tests-setup/src/connector_tag/mod.rs +++ b/query-engine/connector-test-kit-rs/query-tests-setup/src/connector_tag/mod.rs @@ -196,7 +196,6 @@ pub(crate) fn connection_string( None => unreachable!("A versioned connector must have a concrete version to run."), } } - ConnectorVersion::Vitess(Some(VitessVersion::V5_7)) => "mysql://root@localhost:33577/test".into(), ConnectorVersion::Vitess(Some(VitessVersion::V8_0)) => "mysql://root@localhost:33807/test".into(), ConnectorVersion::Vitess(None) => unreachable!("A versioned connector must have a concrete version to run."), } diff --git a/query-engine/connector-test-kit-rs/query-tests-setup/src/connector_tag/vitess.rs b/query-engine/connector-test-kit-rs/query-tests-setup/src/connector_tag/vitess.rs index 7afb78bab630..0376f45abbcf 100644 --- a/query-engine/connector-test-kit-rs/query-tests-setup/src/connector_tag/vitess.rs +++ b/query-engine/connector-test-kit-rs/query-tests-setup/src/connector_tag/vitess.rs @@ -33,7 +33,6 @@ impl ConnectorTagInterface for VitessConnectorTag { #[derive(Debug, Clone, Copy, PartialEq)] pub enum VitessVersion { - V5_7, V8_0, } @@ -42,7 +41,6 @@ impl FromStr for VitessVersion { fn from_str(s: &str) -> Result { let version = match s { - "5.7" => Self::V5_7, "8.0" => Self::V8_0, _ => return Err(TestError::parse_error(format!("Unknown Vitess version `{s}`"))), }; @@ -54,7 +52,6 @@ impl FromStr for VitessVersion { impl Display for VitessVersion { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { match self { - Self::V5_7 => write!(f, "5.7"), Self::V8_0 => write!(f, "8.0"), } } diff --git a/query-engine/connector-test-kit-rs/test-configs/libsql-sqlite b/query-engine/connector-test-kit-rs/test-configs/libsql-sqlite index 9638e3a22840..d1532fc12584 100644 --- a/query-engine/connector-test-kit-rs/test-configs/libsql-sqlite +++ b/query-engine/connector-test-kit-rs/test-configs/libsql-sqlite @@ -1,5 +1,5 @@ { "connector": "sqlite", "driver_adapter": "libsql", - "external_test_executor": "default" + "external_test_executor": "Napi" } \ No newline at end of file diff --git a/query-engine/connector-test-kit-rs/test-configs/libsql-sqlite-wasm b/query-engine/connector-test-kit-rs/test-configs/libsql-sqlite-wasm new file mode 100644 index 000000000000..b93966875dea --- /dev/null +++ b/query-engine/connector-test-kit-rs/test-configs/libsql-sqlite-wasm @@ -0,0 +1,5 @@ +{ + "connector": "sqlite", + "driver_adapter": "libsql", + "external_test_executor": "Wasm" +} \ No newline at end of file diff --git a/query-engine/connector-test-kit-rs/test-configs/neon-ws-postgres13 b/query-engine/connector-test-kit-rs/test-configs/neon-ws-postgres13 index 0097d8c91f57..bb2034d0e460 100644 --- a/query-engine/connector-test-kit-rs/test-configs/neon-ws-postgres13 +++ b/query-engine/connector-test-kit-rs/test-configs/neon-ws-postgres13 @@ -3,5 +3,5 @@ "version": "13", "driver_adapter": "neon:ws", "driver_adapter_config": { "proxyUrl": "127.0.0.1:5488/v1" }, - "external_test_executor": "default" + "external_test_executor": "Napi" } \ No newline at end of file diff --git a/query-engine/connector-test-kit-rs/test-configs/neon-ws-postgres13-wasm b/query-engine/connector-test-kit-rs/test-configs/neon-ws-postgres13-wasm new file mode 100644 index 000000000000..6b1e9c0d1286 --- /dev/null +++ b/query-engine/connector-test-kit-rs/test-configs/neon-ws-postgres13-wasm @@ -0,0 +1,7 @@ +{ + "connector": "postgres", + "version": "13", + "driver_adapter": "neon:ws", + "driver_adapter_config": { "proxyUrl": "127.0.0.1:5488/v1" }, + "external_test_executor": "Wasm" +} \ No newline at end of file diff --git a/query-engine/connector-test-kit-rs/test-configs/pg-postgres13 b/query-engine/connector-test-kit-rs/test-configs/pg-postgres13 index 00f0c75ed736..4a2653dd3d2e 100644 --- a/query-engine/connector-test-kit-rs/test-configs/pg-postgres13 +++ b/query-engine/connector-test-kit-rs/test-configs/pg-postgres13 @@ -2,5 +2,5 @@ "connector": "postgres", "version": "13", "driver_adapter": "pg", - "external_test_executor": "default" + "external_test_executor": "Napi" } \ No newline at end of file diff --git a/query-engine/connector-test-kit-rs/test-configs/pg-postgres13-wasm b/query-engine/connector-test-kit-rs/test-configs/pg-postgres13-wasm new file mode 100644 index 000000000000..b5d8ac3c7b15 --- /dev/null +++ b/query-engine/connector-test-kit-rs/test-configs/pg-postgres13-wasm @@ -0,0 +1,6 @@ +{ + "connector": "postgres", + "version": "13", + "driver_adapter": "pg", + "external_test_executor": "Wasm" +} \ No newline at end of file diff --git a/query-engine/connector-test-kit-rs/test-configs/planetscale-vitess8 b/query-engine/connector-test-kit-rs/test-configs/planetscale-vitess8 index 48c89c79427c..b823cc106997 100644 --- a/query-engine/connector-test-kit-rs/test-configs/planetscale-vitess8 +++ b/query-engine/connector-test-kit-rs/test-configs/planetscale-vitess8 @@ -3,5 +3,5 @@ "version": "8.0", "driver_adapter": "planetscale", "driver_adapter_config": { "proxyUrl": "http://root:root@127.0.0.1:8085" }, - "external_test_executor": "default" + "external_test_executor": "Napi" } diff --git a/query-engine/connector-test-kit-rs/test-configs/planetscale-vitess8-wasm b/query-engine/connector-test-kit-rs/test-configs/planetscale-vitess8-wasm new file mode 100644 index 000000000000..d4ee0627759a --- /dev/null +++ b/query-engine/connector-test-kit-rs/test-configs/planetscale-vitess8-wasm @@ -0,0 +1,7 @@ +{ + "connector": "vitess", + "version": "8.0", + "driver_adapter": "planetscale", + "driver_adapter_config": { "proxyUrl": "http://root:root@127.0.0.1:8085" }, + "external_test_executor": "Wasm" +} diff --git a/query-engine/connector-test-kit-rs/test-configs/vitess_5_7 b/query-engine/connector-test-kit-rs/test-configs/vitess_5_7 deleted file mode 100644 index 64fb5162ac41..000000000000 --- a/query-engine/connector-test-kit-rs/test-configs/vitess_5_7 +++ /dev/null @@ -1,3 +0,0 @@ -{ - "connector": "vitess", - "version": "5.7"} \ No newline at end of file diff --git a/query-engine/connectors/query-connector/Cargo.toml b/query-engine/connectors/query-connector/Cargo.toml index d16771aa3daf..788b8ca65576 100644 --- a/query-engine/connectors/query-connector/Cargo.toml +++ b/query-engine/connectors/query-connector/Cargo.toml @@ -14,6 +14,6 @@ prisma-value = {path = "../../../libs/prisma-value"} serde.workspace = true serde_json.workspace = true thiserror = "1.0" -user-facing-errors = {path = "../../../libs/user-facing-errors"} +user-facing-errors = {path = "../../../libs/user-facing-errors", features = ["sql"]} uuid = "1" indexmap = "1.7" diff --git a/query-engine/connectors/sql-query-connector/Cargo.toml b/query-engine/connectors/sql-query-connector/Cargo.toml index 62d0be640761..9ed0b4070056 100644 --- a/query-engine/connectors/sql-query-connector/Cargo.toml +++ b/query-engine/connectors/sql-query-connector/Cargo.toml @@ -5,6 +5,8 @@ version = "0.1.0" [features] vendored-openssl = ["quaint/vendored-openssl"] + +# Enable Driver Adapters driver-adapters = [] [dependencies] @@ -18,15 +20,20 @@ once_cell = "1.3" rand = "0.7" serde_json = {version = "1.0", features = ["float_roundtrip"]} thiserror = "1.0" -tokio.workspace = true +tokio = { version = "1.0", features = ["macros", "time"] } tracing = "0.1" tracing-futures = "0.2" uuid.workspace = true opentelemetry = { version = "0.17", features = ["tokio"] } tracing-opentelemetry = "0.17.3" -quaint.workspace = true cuid = { git = "https://github.com/prisma/cuid-rust", branch = "wasm32-support" } +[target.'cfg(not(target_arch = "wasm32"))'.dependencies] +quaint.workspace = true + +[target.'cfg(target_arch = "wasm32")'.dependencies] +quaint = { path = "../../../quaint" } + [dependencies.connector-interface] package = "query-connector" path = "../query-connector" diff --git a/query-engine/connectors/sql-query-connector/src/database/connection.rs b/query-engine/connectors/sql-query-connector/src/database/connection.rs index 0247e8c4b601..7895e838399a 100644 --- a/query-engine/connectors/sql-query-connector/src/database/connection.rs +++ b/query-engine/connectors/sql-query-connector/src/database/connection.rs @@ -1,3 +1,5 @@ +#![cfg_attr(target_arch = "wasm32", allow(dead_code))] + use super::{catch, transaction::SqlConnectorTransaction}; use crate::{database::operations::*, Context, SqlError}; use async_trait::async_trait; diff --git a/query-engine/connectors/sql-query-connector/src/database/mod.rs b/query-engine/connectors/sql-query-connector/src/database/mod.rs index 695db13b6620..e693769373b0 100644 --- a/query-engine/connectors/sql-query-connector/src/database/mod.rs +++ b/query-engine/connectors/sql-query-connector/src/database/mod.rs @@ -1,12 +1,16 @@ mod connection; #[cfg(feature = "driver-adapters")] mod js; -mod mssql; -mod mysql; -mod postgresql; -mod sqlite; mod transaction; +#[cfg(not(target_arch = "wasm32"))] +pub(crate) mod native { + pub(crate) mod mssql; + pub(crate) mod mysql; + pub(crate) mod postgresql; + pub(crate) mod sqlite; +} + pub(crate) mod operations; use async_trait::async_trait; @@ -14,10 +18,9 @@ use connector_interface::{error::ConnectorError, Connector}; #[cfg(feature = "driver-adapters")] pub use js::*; -pub use mssql::*; -pub use mysql::*; -pub use postgresql::*; -pub use sqlite::*; + +#[cfg(not(target_arch = "wasm32"))] +pub use native::{mssql::*, mysql::*, postgresql::*, sqlite::*}; #[async_trait] pub trait FromSource { diff --git a/query-engine/connectors/sql-query-connector/src/database/mssql.rs b/query-engine/connectors/sql-query-connector/src/database/native/mssql.rs similarity index 94% rename from query-engine/connectors/sql-query-connector/src/database/mssql.rs rename to query-engine/connectors/sql-query-connector/src/database/native/mssql.rs index 9655d205e4ca..19d3580bba9f 100644 --- a/query-engine/connectors/sql-query-connector/src/database/mssql.rs +++ b/query-engine/connectors/sql-query-connector/src/database/native/mssql.rs @@ -1,4 +1,4 @@ -use super::connection::SqlConnection; +use crate::database::{catch, connection::SqlConnection}; use crate::{FromSource, SqlError}; use async_trait::async_trait; use connector_interface::{ @@ -60,7 +60,7 @@ impl FromSource for Mssql { #[async_trait] impl Connector for Mssql { async fn get_connection<'a>(&'a self) -> connector::Result> { - super::catch(self.connection_info.clone(), async move { + catch(self.connection_info.clone(), async move { let conn = self.pool.check_out().await.map_err(SqlError::from)?; let conn = SqlConnection::new(conn, &self.connection_info, self.features); diff --git a/query-engine/connectors/sql-query-connector/src/database/mysql.rs b/query-engine/connectors/sql-query-connector/src/database/native/mysql.rs similarity index 95% rename from query-engine/connectors/sql-query-connector/src/database/mysql.rs rename to query-engine/connectors/sql-query-connector/src/database/native/mysql.rs index deb3e6a4f35f..477d687b995b 100644 --- a/query-engine/connectors/sql-query-connector/src/database/mysql.rs +++ b/query-engine/connectors/sql-query-connector/src/database/native/mysql.rs @@ -1,4 +1,4 @@ -use super::connection::SqlConnection; +use crate::database::{catch, connection::SqlConnection}; use crate::{FromSource, SqlError}; use async_trait::async_trait; use connector_interface::{ @@ -65,7 +65,7 @@ impl FromSource for Mysql { #[async_trait] impl Connector for Mysql { async fn get_connection<'a>(&'a self) -> connector::Result> { - super::catch(self.connection_info.clone(), async move { + catch(self.connection_info.clone(), async move { let runtime_conn = self.pool.check_out().await?; // Note: `runtime_conn` must be `Sized`, as that's required by `TransactionCapable` diff --git a/query-engine/connectors/sql-query-connector/src/database/postgresql.rs b/query-engine/connectors/sql-query-connector/src/database/native/postgresql.rs similarity index 95% rename from query-engine/connectors/sql-query-connector/src/database/postgresql.rs rename to query-engine/connectors/sql-query-connector/src/database/native/postgresql.rs index 242b2b63090e..0e49a1de8bbd 100644 --- a/query-engine/connectors/sql-query-connector/src/database/postgresql.rs +++ b/query-engine/connectors/sql-query-connector/src/database/native/postgresql.rs @@ -1,4 +1,4 @@ -use super::connection::SqlConnection; +use crate::database::{catch, connection::SqlConnection}; use crate::{FromSource, SqlError}; use async_trait::async_trait; use connector_interface::{ @@ -67,7 +67,7 @@ impl FromSource for PostgreSql { #[async_trait] impl Connector for PostgreSql { async fn get_connection<'a>(&'a self) -> connector_interface::Result> { - super::catch(self.connection_info.clone(), async move { + catch(self.connection_info.clone(), async move { let conn = self.pool.check_out().await.map_err(SqlError::from)?; let conn = SqlConnection::new(conn, &self.connection_info, self.features); Ok(Box::new(conn) as Box) diff --git a/query-engine/connectors/sql-query-connector/src/database/sqlite.rs b/query-engine/connectors/sql-query-connector/src/database/native/sqlite.rs similarity index 96% rename from query-engine/connectors/sql-query-connector/src/database/sqlite.rs rename to query-engine/connectors/sql-query-connector/src/database/native/sqlite.rs index 6be9faeac54d..e38bccb861f4 100644 --- a/query-engine/connectors/sql-query-connector/src/database/sqlite.rs +++ b/query-engine/connectors/sql-query-connector/src/database/native/sqlite.rs @@ -1,4 +1,4 @@ -use super::connection::SqlConnection; +use crate::database::{catch, connection::SqlConnection}; use crate::{FromSource, SqlError}; use async_trait::async_trait; use connector_interface::{ @@ -80,7 +80,7 @@ fn invalid_file_path_error(file_path: &str, connection_info: &ConnectionInfo) -> #[async_trait] impl Connector for Sqlite { async fn get_connection<'a>(&'a self) -> connector::Result> { - super::catch(self.connection_info().clone(), async move { + catch(self.connection_info().clone(), async move { let conn = self.pool.check_out().await.map_err(SqlError::from)?; let conn = SqlConnection::new(conn, self.connection_info(), self.features); diff --git a/query-engine/connectors/sql-query-connector/src/database/operations/write.rs b/query-engine/connectors/sql-query-connector/src/database/operations/write.rs index 425f4ac1d4b3..611557c4f3ba 100644 --- a/query-engine/connectors/sql-query-connector/src/database/operations/write.rs +++ b/query-engine/connectors/sql-query-connector/src/database/operations/write.rs @@ -18,9 +18,28 @@ use std::{ ops::Deref, usize, }; -use tracing::log::trace; use user_facing_errors::query_engine::DatabaseConstraint; +#[cfg(target_arch = "wasm32")] +macro_rules! trace { + (target: $target:expr, $($arg:tt)+) => {{ + // No-op in WebAssembly + }}; + ($($arg:tt)+) => {{ + // No-op in WebAssembly + }}; +} + +#[cfg(not(target_arch = "wasm32"))] +macro_rules! trace { + (target: $target:expr, $($arg:tt)+) => { + tracing::log::trace!(target: $target, $($arg)+); + }; + ($($arg:tt)+) => { + tracing::log::trace!($($arg)+); + }; +} + async fn generate_id( conn: &dyn Queryable, id_field: &FieldSelection, diff --git a/query-engine/connectors/sql-query-connector/src/lib.rs b/query-engine/connectors/sql-query-connector/src/lib.rs index ed1528ded6b5..74c0a4aab5d3 100644 --- a/query-engine/connectors/sql-query-connector/src/lib.rs +++ b/query-engine/connectors/sql-query-connector/src/lib.rs @@ -22,9 +22,12 @@ mod value_ext; use self::{column_metadata::*, context::Context, query_ext::QueryExt, row::*}; use quaint::prelude::Queryable; +pub use database::FromSource; #[cfg(feature = "driver-adapters")] pub use database::{activate_driver_adapter, Js}; -pub use database::{FromSource, Mssql, Mysql, PostgreSql, Sqlite}; pub use error::SqlError; +#[cfg(not(target_arch = "wasm32"))] +pub use database::{Mssql, Mysql, PostgreSql, Sqlite}; + type Result = std::result::Result; diff --git a/query-engine/core-tests/Cargo.toml b/query-engine/core-tests/Cargo.toml index 9a2c3f5686eb..bac9219c3522 100644 --- a/query-engine/core-tests/Cargo.toml +++ b/query-engine/core-tests/Cargo.toml @@ -9,7 +9,7 @@ edition = "2021" dissimilar = "1.0.4" user-facing-errors = { path = "../../libs/user-facing-errors" } request-handlers = { path = "../request-handlers" } -query-core = { path = "../core" } +query-core = { path = "../core", features = ["metrics"] } schema = { path = "../schema" } psl.workspace = true serde_json.workspace = true diff --git a/query-engine/core/Cargo.toml b/query-engine/core/Cargo.toml index caadf6cdba00..9e0f03517cb5 100644 --- a/query-engine/core/Cargo.toml +++ b/query-engine/core/Cargo.toml @@ -3,6 +3,9 @@ edition = "2021" name = "query-core" version = "0.1.0" +[features] +metrics = ["query-engine-metrics"] + [dependencies] async-trait = "0.1" bigdecimal = "0.3" @@ -18,11 +21,11 @@ once_cell = "1" petgraph = "0.4" prisma-models = { path = "../prisma-models", features = ["default_generators"] } opentelemetry = { version = "0.17.0", features = ["rt-tokio", "serialize"] } -query-engine-metrics = {path = "../metrics"} +query-engine-metrics = { path = "../metrics", optional = true } serde.workspace = true serde_json.workspace = true thiserror = "1.0" -tokio.workspace = true +tokio = { version = "1.0", features = ["macros", "time"] } tracing = { version = "0.1", features = ["attributes"] } tracing-futures = "0.2" tracing-subscriber = { version = "0.3", features = ["env-filter"] } @@ -34,3 +37,9 @@ schema = { path = "../schema" } lru = "0.7.7" enumflags2 = "0.7" +pin-project = "1" +wasm-bindgen-futures = "0.4" + +[target.'cfg(target_arch = "wasm32")'.dependencies] +pin-project = "1" +wasm-bindgen-futures = "0.4" diff --git a/query-engine/core/src/executor/execute_operation.rs b/query-engine/core/src/executor/execute_operation.rs index 06452fcdd865..6ba21d37f9ff 100644 --- a/query-engine/core/src/executor/execute_operation.rs +++ b/query-engine/core/src/executor/execute_operation.rs @@ -1,3 +1,5 @@ +#![cfg_attr(target_arch = "wasm32", allow(unused_variables))] + use super::pipeline::QueryPipeline; use crate::{ executor::request_context, protocol::EngineProtocol, CoreError, IrSerializer, Operation, QueryGraph, @@ -5,9 +7,12 @@ use crate::{ }; use connector::{Connection, ConnectionLike, Connector}; use futures::future; + +#[cfg(feature = "metrics")] use query_engine_metrics::{ histogram, increment_counter, metrics, PRISMA_CLIENT_QUERIES_DURATION_HISTOGRAM_MS, PRISMA_CLIENT_QUERIES_TOTAL, }; + use schema::{QuerySchema, QuerySchemaRef}; use std::time::{Duration, Instant}; use tracing::Instrument; @@ -24,6 +29,7 @@ pub async fn execute_single_operation( let (graph, serializer) = build_graph(&query_schema, operation.clone())?; let result = execute_on(conn, graph, serializer, query_schema.as_ref(), trace_id).await; + #[cfg(feature = "metrics")] histogram!(PRISMA_CLIENT_QUERIES_DURATION_HISTOGRAM_MS, operation_timer.elapsed()); result @@ -45,6 +51,8 @@ pub async fn execute_many_operations( for (i, (graph, serializer)) in queries.into_iter().enumerate() { let operation_timer = Instant::now(); let result = execute_on(conn, graph, serializer, query_schema.as_ref(), trace_id.clone()).await; + + #[cfg(feature = "metrics")] histogram!(PRISMA_CLIENT_QUERIES_DURATION_HISTOGRAM_MS, operation_timer.elapsed()); match result { @@ -98,6 +106,7 @@ pub async fn execute_many_self_contained( let dispatcher = crate::get_current_dispatcher(); for op in operations { + #[cfg(feature = "metrics")] increment_counter!(PRISMA_CLIENT_QUERIES_TOTAL); let conn_span = info_span!( @@ -158,6 +167,7 @@ async fn execute_self_contained( execute_self_contained_without_retry(conn, graph, serializer, force_transactions, &query_schema, trace_id).await }; + #[cfg(feature = "metrics")] histogram!(PRISMA_CLIENT_QUERIES_DURATION_HISTOGRAM_MS, operation_timer.elapsed()); result @@ -259,6 +269,7 @@ async fn execute_on<'a>( query_schema: &'a QuerySchema, trace_id: Option, ) -> crate::Result { + #[cfg(feature = "metrics")] increment_counter!(PRISMA_CLIENT_QUERIES_TOTAL); let interpreter = QueryInterpreter::new(conn); diff --git a/query-engine/core/src/executor/mod.rs b/query-engine/core/src/executor/mod.rs index ddbb7dfc8429..ba2784d3c71a 100644 --- a/query-engine/core/src/executor/mod.rs +++ b/query-engine/core/src/executor/mod.rs @@ -10,6 +10,7 @@ mod execute_operation; mod interpreting_executor; mod pipeline; mod request_context; +pub(crate) mod task; pub use self::{execute_operation::*, interpreting_executor::InterpretingExecutor}; diff --git a/query-engine/core/src/executor/task.rs b/query-engine/core/src/executor/task.rs new file mode 100644 index 000000000000..8d1c39bbcd06 --- /dev/null +++ b/query-engine/core/src/executor/task.rs @@ -0,0 +1,59 @@ +//! This module provides a unified interface for spawning asynchronous tasks, regardless of the target platform. + +pub use arch::{spawn, JoinHandle}; +use futures::Future; + +// On native targets, `tokio::spawn` spawns a new asynchronous task. +#[cfg(not(target_arch = "wasm32"))] +mod arch { + use super::*; + + pub type JoinHandle = tokio::task::JoinHandle; + + pub fn spawn(future: T) -> JoinHandle + where + T: Future + Send + 'static, + T::Output: Send + 'static, + { + tokio::spawn(future) + } +} + +// On Wasm targets, `wasm_bindgen_futures::spawn_local` spawns a new asynchronous task. +#[cfg(target_arch = "wasm32")] +mod arch { + use super::*; + use tokio::sync::oneshot::{self}; + + // Wasm-compatible alternative to `tokio::task::JoinHandle`. + // `pin_project` enables pin-projection and a `Pin`-compatible implementation of the `Future` trait. + pub struct JoinHandle(oneshot::Receiver); + + impl Future for JoinHandle { + type Output = Result; + + fn poll(mut self: std::pin::Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> std::task::Poll { + // the `self.project()` method is provided by the `pin_project` macro + core::pin::Pin::new(&mut self.0).poll(cx) + } + } + + impl JoinHandle { + pub fn abort(&mut self) { + // abort is noop on Wasm targets + } + } + + pub fn spawn(future: T) -> JoinHandle + where + T: Future + Send + 'static, + T::Output: Send + 'static, + { + let (sender, receiver) = oneshot::channel(); + wasm_bindgen_futures::spawn_local(async move { + let result = future.await; + sender.send(result).ok(); + }); + JoinHandle(receiver) + } +} diff --git a/query-engine/core/src/interactive_transactions/actor_manager.rs b/query-engine/core/src/interactive_transactions/actor_manager.rs index 98208343d28a..105733be4166 100644 --- a/query-engine/core/src/interactive_transactions/actor_manager.rs +++ b/query-engine/core/src/interactive_transactions/actor_manager.rs @@ -1,3 +1,4 @@ +use crate::executor::task::JoinHandle; use crate::{protocol::EngineProtocol, ClosedTx, Operation, ResponseData}; use connector::Connection; use lru::LruCache; @@ -9,7 +10,6 @@ use tokio::{ mpsc::{channel, Sender}, RwLock, }, - task::JoinHandle, time::Duration, }; diff --git a/query-engine/core/src/interactive_transactions/actors.rs b/query-engine/core/src/interactive_transactions/actors.rs index 88402d86fedd..104ffc26812f 100644 --- a/query-engine/core/src/interactive_transactions/actors.rs +++ b/query-engine/core/src/interactive_transactions/actors.rs @@ -1,7 +1,8 @@ use super::{CachedTx, TransactionError, TxOpRequest, TxOpRequestMsg, TxOpResponse}; +use crate::executor::task::{spawn, JoinHandle}; use crate::{ - execute_many_operations, execute_single_operation, protocol::EngineProtocol, - telemetry::helpers::set_span_link_from_traceparent, ClosedTx, Operation, ResponseData, TxId, + execute_many_operations, execute_single_operation, protocol::EngineProtocol, ClosedTx, Operation, ResponseData, + TxId, }; use connector::Connection; use schema::QuerySchemaRef; @@ -11,13 +12,15 @@ use tokio::{ mpsc::{channel, Receiver, Sender}, oneshot, RwLock, }, - task::JoinHandle, time::{self, Duration, Instant}, }; use tracing::Span; use tracing_futures::Instrument; use tracing_futures::WithSubscriber; +#[cfg(feature = "metrics")] +use crate::telemetry::helpers::set_span_link_from_traceparent; + #[derive(PartialEq)] enum RunState { Continue, @@ -81,6 +84,8 @@ impl<'a> ITXServer<'a> { traceparent: Option, ) -> crate::Result { let span = info_span!("prisma:engine:itx_query_builder", user_facing = true); + + #[cfg(feature = "metrics")] set_span_link_from_traceparent(&span, traceparent.clone()); let conn = self.cached_tx.as_open()?; @@ -267,7 +272,7 @@ pub(crate) async fn spawn_itx_actor( }; let (open_transaction_send, open_transaction_rcv) = oneshot::channel(); - tokio::task::spawn( + spawn( crate::executor::with_request_context(engine_protocol, async move { // We match on the result in order to send the error to the parent task and abort this // task, on error. This is a separate task (actor), not a function where we can just bubble up the @@ -380,7 +385,7 @@ pub(crate) fn spawn_client_list_clear_actor( closed_txs: Arc>>>, mut rx: Receiver<(TxId, Option)>, ) -> JoinHandle<()> { - tokio::task::spawn(async move { + spawn(async move { loop { if let Some((id, closed_tx)) = rx.recv().await { trace!("removing {} from client list", id); diff --git a/query-engine/core/src/lib.rs b/query-engine/core/src/lib.rs index 7970c96139b7..38f39e9fb5d9 100644 --- a/query-engine/core/src/lib.rs +++ b/query-engine/core/src/lib.rs @@ -9,6 +9,8 @@ pub mod protocol; pub mod query_document; pub mod query_graph_builder; pub mod response_ir; + +#[cfg(feature = "metrics")] pub mod telemetry; pub use self::{ @@ -16,8 +18,11 @@ pub use self::{ executor::{QueryExecutor, TransactionOptions}, interactive_transactions::{ExtendedTransactionUserFacingError, TransactionError, TxId}, query_document::*, - telemetry::*, }; + +#[cfg(feature = "metrics")] +pub use self::telemetry::*; + pub use connector::{ error::{ConnectorError, ErrorKind as ConnectorErrorKind}, Connector, diff --git a/query-engine/driver-adapters/connector-test-kit-executor/package.json b/query-engine/driver-adapters/connector-test-kit-executor/package.json index 4648887f5063..e872d5684450 100644 --- a/query-engine/driver-adapters/connector-test-kit-executor/package.json +++ b/query-engine/driver-adapters/connector-test-kit-executor/package.json @@ -12,6 +12,11 @@ "scripts": { "build": "tsup ./src/index.ts --format esm --dts" }, + "tsup": { + "external": [ + "../../../query-engine-wasm/pkg/query_engine_bg.js" + ] + }, "keywords": [], "author": "", "sideEffects": false, diff --git a/query-engine/driver-adapters/connector-test-kit-executor/src/index.ts b/query-engine/driver-adapters/connector-test-kit-executor/src/index.ts index 2318c0525760..3ea8aaf147b9 100644 --- a/query-engine/driver-adapters/connector-test-kit-executor/src/index.ts +++ b/query-engine/driver-adapters/connector-test-kit-executor/src/index.ts @@ -1,5 +1,4 @@ import * as qe from './qe' -import * as engines from './engines/Library' import * as readline from 'node:readline' import * as jsonRpc from './jsonRpc' @@ -76,7 +75,7 @@ async function main(): Promise { } const state: Record = {} @@ -215,10 +214,10 @@ function respondOk(requestId: number, payload: unknown) { console.log(JSON.stringify(msg)) } -async function initQe(url: string, prismaSchema: string, logCallback: qe.QueryLogCallback): Promise<[engines.QueryEngineInstance, ErrorCapturingDriverAdapter]> { +async function initQe(url: string, prismaSchema: string, logCallback: qe.QueryLogCallback): Promise<[qe.QueryEngine, ErrorCapturingDriverAdapter]> { const adapter = await adapterFromEnv(url) as DriverAdapter const errorCapturingAdapter = bindAdapter(adapter) - const engineInstance = qe.initQueryEngine(errorCapturingAdapter, prismaSchema, logCallback, debug) + const engineInstance = await qe.initQueryEngine(errorCapturingAdapter, prismaSchema, logCallback, debug) return [engineInstance, errorCapturingAdapter]; } diff --git a/query-engine/driver-adapters/connector-test-kit-executor/src/qe.ts b/query-engine/driver-adapters/connector-test-kit-executor/src/qe.ts index 186d7a9e80d2..20e9a4917fb5 100644 --- a/query-engine/driver-adapters/connector-test-kit-executor/src/qe.ts +++ b/query-engine/driver-adapters/connector-test-kit-executor/src/qe.ts @@ -1,22 +1,24 @@ import type { ErrorCapturingDriverAdapter } from '@prisma/driver-adapter-utils' -import * as lib from './engines/Library' +import * as napi from './engines/Library' import * as os from 'node:os' import * as path from 'node:path' +import { fileURLToPath } from 'node:url' -export type QueryLogCallback = (log: string) => void +const dirname = path.dirname(fileURLToPath(import.meta.url)) -export function initQueryEngine(adapter: ErrorCapturingDriverAdapter, datamodel: string, queryLogCallback: QueryLogCallback, debug: (...args: any[]) => void): lib.QueryEngineInstance { - // I assume nobody will run this on Windows ¯\_(ツ)_/¯ - const libExt = os.platform() === 'darwin' ? 'dylib' : 'so' - const dirname = path.dirname(new URL(import.meta.url).pathname) +export interface QueryEngine { + connect(trace: string): Promise + disconnect(trace: string): Promise; + query(body: string, trace: string, tx_id?: string): Promise; + startTransaction(input: string, trace: string): Promise; + commitTransaction(tx_id: string, trace: string): Promise; + rollbackTransaction(tx_id: string, trace: string): Promise; +} - const libQueryEnginePath = path.join(dirname, `../../../../target/debug/libquery_engine.${libExt}`) +export type QueryLogCallback = (log: string) => void - const libqueryEngine = { exports: {} as unknown as lib.Library } - // @ts-ignore - process.dlopen(libqueryEngine, libQueryEnginePath) - const QueryEngine = libqueryEngine.exports.QueryEngine +export async function initQueryEngine(adapter: ErrorCapturingDriverAdapter, datamodel: string, queryLogCallback: QueryLogCallback, debug: (...args: any[]) => void): QueryEngine { const queryEngineOptions = { datamodel, @@ -37,5 +39,29 @@ export function initQueryEngine(adapter: ErrorCapturingDriverAdapter, datamodel: debug(parsed) } - return new QueryEngine(queryEngineOptions, logCallback, adapter) + const engineFromEnv = process.env.EXTERNAL_TEST_EXECUTOR ?? 'Napi' + if (engineFromEnv === 'Wasm') { + const { WasmQueryEngine } = await import('./wasm') + return new WasmQueryEngine(queryEngineOptions, logCallback, adapter) + } else if (engineFromEnv === 'Napi') { + const { QueryEngine } = loadNapiEngine() + return new QueryEngine(queryEngineOptions, logCallback, adapter) + } else { + throw new TypeError(`Invalid EXTERNAL_TEST_EXECUTOR value: ${engineFromEnv}. Expected Napi or Wasm`) + } + + } + +function loadNapiEngine(): napi.Library { + // I assume nobody will run this on Windows ¯\_(ツ)_/¯ + const libExt = os.platform() === 'darwin' ? 'dylib' : 'so' + + const libQueryEnginePath = path.join(dirname, `../../../../target/debug/libquery_engine.${libExt}`) + + const libqueryEngine = { exports: {} as unknown as napi.Library } + // @ts-ignore + process.dlopen(libqueryEngine, libQueryEnginePath) + + return libqueryEngine.exports +} \ No newline at end of file diff --git a/query-engine/driver-adapters/connector-test-kit-executor/src/wasm.ts b/query-engine/driver-adapters/connector-test-kit-executor/src/wasm.ts new file mode 100644 index 000000000000..439fd0c3f94f --- /dev/null +++ b/query-engine/driver-adapters/connector-test-kit-executor/src/wasm.ts @@ -0,0 +1,14 @@ +import * as wasm from '../../../query-engine-wasm/pkg/query_engine_bg.js' +import fs from 'node:fs/promises' +import path from 'node:path' +import { fileURLToPath } from 'node:url' + +const dirname = path.dirname(fileURLToPath(import.meta.url)) + +const bytes = await fs.readFile(path.resolve(dirname, '..', '..', '..', 'query-engine-wasm', 'pkg', 'query_engine_bg.wasm')) +const module = new WebAssembly.Module(bytes) +const instance = new WebAssembly.Instance(module, { './query_engine_bg.js': wasm }) +wasm.__wbg_set_wasm(instance.exports); +wasm.init() + +export const WasmQueryEngine = wasm.QueryEngine \ No newline at end of file diff --git a/query-engine/driver-adapters/connector-test-kit-executor/tsconfig.json b/query-engine/driver-adapters/connector-test-kit-executor/tsconfig.json index 516c114b3e15..20fc4bd62ff7 100644 --- a/query-engine/driver-adapters/connector-test-kit-executor/tsconfig.json +++ b/query-engine/driver-adapters/connector-test-kit-executor/tsconfig.json @@ -2,7 +2,7 @@ "compilerOptions": { "target": "ES2022", "module": "ESNext", - "lib": ["ES2022"], + "lib": ["ES2022", "DOM"], "moduleResolution": "Bundler", "esModuleInterop": false, "isolatedModules": true, @@ -17,7 +17,7 @@ "skipDefaultLibCheck": true, "skipLibCheck": true, "emitDeclarationOnly": true, - "resolveJsonModule": true + "resolveJsonModule": true, }, "exclude": ["**/dist", "**/declaration", "**/node_modules", "**/src/__tests__"] } \ No newline at end of file diff --git a/query-engine/driver-adapters/src/proxy.rs b/query-engine/driver-adapters/src/proxy.rs index a708d75c0e32..642c2491757a 100644 --- a/query-engine/driver-adapters/src/proxy.rs +++ b/query-engine/driver-adapters/src/proxy.rs @@ -853,7 +853,7 @@ mod proxy_test { let s = "13:02:20.321"; let json_value = serde_json::Value::String(s.to_string()); let quaint_value = js_value_to_quaint(json_value, column_type, "column_name").unwrap(); - let time: NaiveTime = NaiveTime::from_hms_milli_opt(13, 02, 20, 321).unwrap(); + let time: NaiveTime = NaiveTime::from_hms_milli_opt(13, 2, 20, 321).unwrap(); assert_eq!(quaint_value, QuaintValue::time(time)); } diff --git a/query-engine/query-engine-node-api/Cargo.toml b/query-engine/query-engine-node-api/Cargo.toml index 74f9686189fc..0eaed9eff7ce 100644 --- a/query-engine/query-engine-node-api/Cargo.toml +++ b/query-engine/query-engine-node-api/Cargo.toml @@ -16,7 +16,7 @@ driver-adapters = ["request-handlers/driver-adapters", "sql-connector/driver-ada [dependencies] anyhow = "1" async-trait = "0.1" -query-core = { path = "../core" } +query-core = { path = "../core", features = ["metrics"] } request-handlers = { path = "../request-handlers" } query-connector = { path = "../connectors/query-connector" } user-facing-errors = { path = "../../libs/user-facing-errors" } diff --git a/query-engine/query-engine-wasm/Cargo.toml b/query-engine/query-engine-wasm/Cargo.toml index a8bc393aee3f..fdccc773eaf3 100644 --- a/query-engine/query-engine-wasm/Cargo.toml +++ b/query-engine/query-engine-wasm/Cargo.toml @@ -14,15 +14,23 @@ async-trait = "0.1" user-facing-errors = { path = "../../libs/user-facing-errors" } psl.workspace = true prisma-models = { path = "../prisma-models" } +quaint = { path = "../../quaint" } +request-handlers = { path = "../request-handlers", default-features = false, features = [ + "sql", + "driver-adapters", +] } +connector = { path = "../connectors/query-connector", package = "query-connector" } +sql-query-connector = { path = "../connectors/sql-query-connector" } +query-core = { path = "../core" } thiserror = "1" -connection-string.workspace = true +connection-string.workspace = true url = "2" serde_json.workspace = true serde.workspace = true tokio = { version = "1.25", features = ["macros", "sync", "io-util", "time"] } futures = "0.3" -wasm-bindgen = "=0.2.87" +wasm-bindgen = "=0.2.88" wasm-bindgen-futures = "0.4" serde-wasm-bindgen = "0.5" js-sys = "0.3" diff --git a/query-engine/query-engine-wasm/pnpm-lock.yaml b/query-engine/query-engine-wasm/pnpm-lock.yaml new file mode 100644 index 000000000000..89591aef9869 --- /dev/null +++ b/query-engine/query-engine-wasm/pnpm-lock.yaml @@ -0,0 +1,130 @@ +lockfileVersion: '6.0' + +settings: + autoInstallPeers: true + excludeLinksFromLockfile: false + +dependencies: + '@neondatabase/serverless': + specifier: 0.6.0 + version: 0.6.0 + '@prisma/adapter-neon': + specifier: 5.6.0 + version: 5.6.0(@neondatabase/serverless@0.6.0) + '@prisma/driver-adapter-utils': + specifier: 5.6.0 + version: 5.6.0 + +packages: + + /@neondatabase/serverless@0.6.0: + resolution: {integrity: sha512-qXxBRYN0m2v8kVQBfMxbzNGn2xFAhTXFibzQlE++NfJ56Shz3m7+MyBBtXDlEH+3Wfa6lToDXf1MElocY4sJ3w==} + dependencies: + '@types/pg': 8.6.6 + dev: false + + /@prisma/adapter-neon@5.6.0(@neondatabase/serverless@0.6.0): + resolution: {integrity: sha512-IUkIE5NKyP2wCXMMAByM78fizfaJl7YeWDEajvyqQafXgRwmxl+2HhxsevvHly8jT4RlELdhjK6IP1eciGvXVA==} + peerDependencies: + '@neondatabase/serverless': ^0.6.0 + dependencies: + '@neondatabase/serverless': 0.6.0 + '@prisma/driver-adapter-utils': 5.6.0 + postgres-array: 3.0.2 + transitivePeerDependencies: + - supports-color + dev: false + + /@prisma/driver-adapter-utils@5.6.0: + resolution: {integrity: sha512-/TSrfCGLAQghNf+bwg5/e8iKAgecCYU/gMN0IyNra3183/VTQJneLFgbacuSK9bBXiIRUmpbuUIrJ6dhENzfjA==} + dependencies: + debug: 4.3.4 + transitivePeerDependencies: + - supports-color + dev: false + + /@types/node@20.9.1: + resolution: {integrity: sha512-HhmzZh5LSJNS5O8jQKpJ/3ZcrrlG6L70hpGqMIAoM9YVD0YBRNWYsfwcXq8VnSjlNpCpgLzMXdiPo+dxcvSmiA==} + dependencies: + undici-types: 5.26.5 + dev: false + + /@types/pg@8.6.6: + resolution: {integrity: sha512-O2xNmXebtwVekJDD+02udOncjVcMZQuTEQEMpKJ0ZRf5E7/9JJX3izhKUcUifBkyKpljyUM6BTgy2trmviKlpw==} + dependencies: + '@types/node': 20.9.1 + pg-protocol: 1.6.0 + pg-types: 2.2.0 + dev: false + + /debug@4.3.4: + resolution: {integrity: sha512-PRWFHuSU3eDtQJPvnNY7Jcket1j0t5OuOsFzPPzsekD52Zl8qUfFIPEiswXqIvHWGVHOgX+7G/vCNNhehwxfkQ==} + engines: {node: '>=6.0'} + peerDependencies: + supports-color: '*' + peerDependenciesMeta: + supports-color: + optional: true + dependencies: + ms: 2.1.2 + dev: false + + /ms@2.1.2: + resolution: {integrity: sha512-sGkPx+VjMtmA6MX27oA4FBFELFCZZ4S4XqeGOXCv68tT+jb3vk/RyaKWP0PTKyWtmLSM0b+adUTEvbs1PEaH2w==} + dev: false + + /pg-int8@1.0.1: + resolution: {integrity: sha512-WCtabS6t3c8SkpDBUlb1kjOs7l66xsGdKpIPZsg4wR+B3+u9UAum2odSsF9tnvxg80h4ZxLWMy4pRjOsFIqQpw==} + engines: {node: '>=4.0.0'} + dev: false + + /pg-protocol@1.6.0: + resolution: {integrity: sha512-M+PDm637OY5WM307051+bsDia5Xej6d9IR4GwJse1qA1DIhiKlksvrneZOYQq42OM+spubpcNYEo2FcKQrDk+Q==} + dev: false + + /pg-types@2.2.0: + resolution: {integrity: sha512-qTAAlrEsl8s4OiEQY69wDvcMIdQN6wdz5ojQiOy6YRMuynxenON0O5oCpJI6lshc6scgAY8qvJ2On/p+CXY0GA==} + engines: {node: '>=4'} + dependencies: + pg-int8: 1.0.1 + postgres-array: 2.0.0 + postgres-bytea: 1.0.0 + postgres-date: 1.0.7 + postgres-interval: 1.2.0 + dev: false + + /postgres-array@2.0.0: + resolution: {integrity: sha512-VpZrUqU5A69eQyW2c5CA1jtLecCsN2U/bD6VilrFDWq5+5UIEVO7nazS3TEcHf1zuPYO/sqGvUvW62g86RXZuA==} + engines: {node: '>=4'} + dev: false + + /postgres-array@3.0.2: + resolution: {integrity: sha512-6faShkdFugNQCLwucjPcY5ARoW1SlbnrZjmGl0IrrqewpvxvhSLHimCVzqeuULCbG0fQv7Dtk1yDbG3xv7Veog==} + engines: {node: '>=12'} + dev: false + + /postgres-bytea@1.0.0: + resolution: {integrity: sha512-xy3pmLuQqRBZBXDULy7KbaitYqLcmxigw14Q5sj8QBVLqEwXfeybIKVWiqAXTlcvdvb0+xkOtDbfQMOf4lST1w==} + engines: {node: '>=0.10.0'} + dev: false + + /postgres-date@1.0.7: + resolution: {integrity: sha512-suDmjLVQg78nMK2UZ454hAG+OAW+HQPZ6n++TNDUX+L0+uUlLywnoxJKDou51Zm+zTCjrCl0Nq6J9C5hP9vK/Q==} + engines: {node: '>=0.10.0'} + dev: false + + /postgres-interval@1.2.0: + resolution: {integrity: sha512-9ZhXKM/rw350N1ovuWHbGxnGh/SNJ4cnxHiM0rxE4VN41wsg8P8zWn9hv/buK00RP4WvlOyr/RBDiptyxVbkZQ==} + engines: {node: '>=0.10.0'} + dependencies: + xtend: 4.0.2 + dev: false + + /undici-types@5.26.5: + resolution: {integrity: sha512-JlCMO+ehdEIKqlFxk6IfVoAUVmgz7cU7zD/h9XZ0qzeosSHmUJVOzSQvvYSYWXkFXC+IfLKSIffhv0sVZup6pA==} + dev: false + + /xtend@4.0.2: + resolution: {integrity: sha512-LKYU1iAXJXUgAXn9URjiu+MWhyUXHsvfp7mcuYm9dSUKK0/CjtrUwFAxD82/mCWbtLsGjFIad0wIsod4zrTAEQ==} + engines: {node: '>=0.4'} + dev: false diff --git a/query-engine/query-engine/Cargo.toml b/query-engine/query-engine/Cargo.toml index be36e4f842dc..c70d8590d0ff 100644 --- a/query-engine/query-engine/Cargo.toml +++ b/query-engine/query-engine/Cargo.toml @@ -20,7 +20,7 @@ enumflags2 = { version = "0.7"} psl.workspace = true graphql-parser = { git = "https://github.com/prisma/graphql-parser" } mongodb-connector = { path = "../connectors/mongodb-query-connector", optional = true, package = "mongodb-query-connector" } -query-core = { path = "../core" } +query-core = { path = "../core", features = ["metrics"] } request-handlers = { path = "../request-handlers" } serde.workspace = true serde_json.workspace = true diff --git a/query-engine/request-handlers/Cargo.toml b/query-engine/request-handlers/Cargo.toml index f5fb433b13ba..f04d742c448e 100644 --- a/query-engine/request-handlers/Cargo.toml +++ b/query-engine/request-handlers/Cargo.toml @@ -7,6 +7,7 @@ edition = "2021" prisma-models = { path = "../prisma-models" } query-core = { path = "../core" } user-facing-errors = { path = "../../libs/user-facing-errors" } +quaint = { path = "../../quaint" } psl.workspace = true dmmf_crate = { path = "../dmmf", package = "dmmf" } itertools = "0.10" @@ -20,7 +21,6 @@ thiserror = "1" tracing = "0.1" url = "2" connection-string.workspace = true -quaint.workspace = true once_cell = "1.15" mongodb-query-connector = { path = "../connectors/mongodb-query-connector", optional = true } @@ -32,10 +32,11 @@ schema = { path = "../schema" } codspeed-criterion-compat = "1.1.0" [features] -default = ["mongodb", "sql"] +default = ["sql", "mongodb", "native"] mongodb = ["mongodb-query-connector"] sql = ["sql-query-connector"] -driver-adapters = ["sql-query-connector"] +driver-adapters = ["sql-query-connector/driver-adapters"] +native = ["mongodb", "sql-query-connector", "quaint/native", "query-core/metrics"] [[bench]] name = "query_planning_bench" diff --git a/query-engine/request-handlers/src/connector_mode.rs b/query-engine/request-handlers/src/connector_mode.rs index 00e0515a596e..be03fbab5820 100644 --- a/query-engine/request-handlers/src/connector_mode.rs +++ b/query-engine/request-handlers/src/connector_mode.rs @@ -1,6 +1,7 @@ #[derive(Copy, Clone, PartialEq, Eq)] pub enum ConnectorMode { /// Indicates that Rust drivers are used in Query Engine. + #[cfg(feature = "native")] Rust, /// Indicates that JS drivers are used in Query Engine. diff --git a/query-engine/request-handlers/src/load_executor.rs b/query-engine/request-handlers/src/load_executor.rs index 652ad3108f0d..26728605f92a 100644 --- a/query-engine/request-handlers/src/load_executor.rs +++ b/query-engine/request-handlers/src/load_executor.rs @@ -1,14 +1,12 @@ +#![allow(unused_imports)] + use psl::{builtin_connectors::*, Datasource, PreviewFeatures}; use query_core::{executor::InterpretingExecutor, Connector, QueryExecutor}; use sql_query_connector::*; use std::collections::HashMap; use std::env; -use tracing::trace; use url::Url; -#[cfg(feature = "mongodb")] -use mongodb_query_connector::MongoDb; - use super::ConnectorMode; /// Loads a query executor based on the parsed Prisma schema (datasource). @@ -27,6 +25,7 @@ pub async fn load( driver_adapter(source, url, features).await } + #[cfg(feature = "native")] ConnectorMode::Rust => { if let Ok(value) = env::var("PRISMA_DISABLE_QUAINT_EXECUTORS") { let disable = value.to_uppercase(); @@ -36,14 +35,14 @@ pub async fn load( } match source.active_provider { - p if SQLITE.is_provider(p) => sqlite(source, url, features).await, - p if MYSQL.is_provider(p) => mysql(source, url, features).await, - p if POSTGRES.is_provider(p) => postgres(source, url, features).await, - p if MSSQL.is_provider(p) => mssql(source, url, features).await, - p if COCKROACH.is_provider(p) => postgres(source, url, features).await, + p if SQLITE.is_provider(p) => native::sqlite(source, url, features).await, + p if MYSQL.is_provider(p) => native::mysql(source, url, features).await, + p if POSTGRES.is_provider(p) => native::postgres(source, url, features).await, + p if MSSQL.is_provider(p) => native::mssql(source, url, features).await, + p if COCKROACH.is_provider(p) => native::postgres(source, url, features).await, #[cfg(feature = "mongodb")] - p if MONGODB.is_provider(p) => mongodb(source, url, features).await, + p if MONGODB.is_provider(p) => native::mongodb(source, url, features).await, x => Err(query_core::CoreError::ConfigurationError(format!( "Unsupported connector type: {x}" @@ -53,57 +52,88 @@ pub async fn load( } } -async fn sqlite( +#[cfg(feature = "driver-adapters")] +async fn driver_adapter( source: &Datasource, url: &str, features: PreviewFeatures, -) -> query_core::Result> { - trace!("Loading SQLite query connector..."); - let sqlite = Sqlite::from_source(source, url, features).await?; - trace!("Loaded SQLite query connector."); - Ok(executor_for(sqlite, false)) +) -> Result, query_core::CoreError> { + let js = Js::from_source(source, url, features).await?; + Ok(executor_for(js, false)) } -async fn postgres( - source: &Datasource, - url: &str, - features: PreviewFeatures, -) -> query_core::Result> { - trace!("Loading Postgres query connector..."); - let database_str = url; - let psql = PostgreSql::from_source(source, url, features).await?; - - let url = Url::parse(database_str) - .map_err(|err| query_core::CoreError::ConfigurationError(format!("Error parsing connection string: {err}")))?; - let params: HashMap = url.query_pairs().into_owned().collect(); - - let force_transactions = params - .get("pgbouncer") - .and_then(|flag| flag.parse().ok()) - .unwrap_or(false); - trace!("Loaded Postgres query connector."); - Ok(executor_for(psql, force_transactions)) -} +#[cfg(feature = "native")] +mod native { + use super::*; + use tracing::trace; + + pub(crate) async fn sqlite( + source: &Datasource, + url: &str, + features: PreviewFeatures, + ) -> query_core::Result> { + trace!("Loading SQLite query connector..."); + let sqlite = Sqlite::from_source(source, url, features).await?; + trace!("Loaded SQLite query connector."); + Ok(executor_for(sqlite, false)) + } -async fn mysql( - source: &Datasource, - url: &str, - features: PreviewFeatures, -) -> query_core::Result> { - let mysql = Mysql::from_source(source, url, features).await?; - trace!("Loaded MySQL query connector."); - Ok(executor_for(mysql, false)) -} + pub(crate) async fn postgres( + source: &Datasource, + url: &str, + features: PreviewFeatures, + ) -> query_core::Result> { + trace!("Loading Postgres query connector..."); + let database_str = url; + let psql = PostgreSql::from_source(source, url, features).await?; + + let url = Url::parse(database_str).map_err(|err| { + query_core::CoreError::ConfigurationError(format!("Error parsing connection string: {err}")) + })?; + let params: HashMap = url.query_pairs().into_owned().collect(); + + let force_transactions = params + .get("pgbouncer") + .and_then(|flag| flag.parse().ok()) + .unwrap_or(false); + trace!("Loaded Postgres query connector."); + Ok(executor_for(psql, force_transactions)) + } -async fn mssql( - source: &Datasource, - url: &str, - features: PreviewFeatures, -) -> query_core::Result> { - trace!("Loading SQL Server query connector..."); - let mssql = Mssql::from_source(source, url, features).await?; - trace!("Loaded SQL Server query connector."); - Ok(executor_for(mssql, false)) + pub(crate) async fn mysql( + source: &Datasource, + url: &str, + features: PreviewFeatures, + ) -> query_core::Result> { + let mysql = Mysql::from_source(source, url, features).await?; + trace!("Loaded MySQL query connector."); + Ok(executor_for(mysql, false)) + } + + pub(crate) async fn mssql( + source: &Datasource, + url: &str, + features: PreviewFeatures, + ) -> query_core::Result> { + trace!("Loading SQL Server query connector..."); + let mssql = Mssql::from_source(source, url, features).await?; + trace!("Loaded SQL Server query connector."); + Ok(executor_for(mssql, false)) + } + + #[cfg(feature = "mongodb")] + pub(crate) async fn mongodb( + source: &Datasource, + url: &str, + _features: PreviewFeatures, + ) -> query_core::Result> { + use mongodb_query_connector::MongoDb; + + trace!("Loading MongoDB query connector..."); + let mongo = MongoDb::new(source, url).await?; + trace!("Loaded MongoDB query connector."); + Ok(executor_for(mongo, false)) + } } fn executor_for(connector: T, force_transactions: bool) -> Box @@ -112,27 +142,3 @@ where { Box::new(InterpretingExecutor::new(connector, force_transactions)) } - -#[cfg(feature = "mongodb")] -async fn mongodb( - source: &Datasource, - url: &str, - _features: PreviewFeatures, -) -> query_core::Result> { - trace!("Loading MongoDB query connector..."); - let mongo = MongoDb::new(source, url).await?; - trace!("Loaded MongoDB query connector."); - Ok(executor_for(mongo, false)) -} - -#[cfg(feature = "driver-adapters")] -async fn driver_adapter( - source: &Datasource, - url: &str, - features: PreviewFeatures, -) -> Result, query_core::CoreError> { - trace!("Loading driver adapter..."); - let js = Js::from_source(source, url, features).await?; - trace!("Loaded driver adapter..."); - Ok(executor_for(js, false)) -} diff --git a/schema-engine/connectors/sql-schema-connector/src/sql_schema_calculator/sql_schema_calculator_flavour/mssql.rs b/schema-engine/connectors/sql-schema-connector/src/sql_schema_calculator/sql_schema_calculator_flavour/mssql.rs index 18a0b8e94b3c..51a8f5ef54be 100644 --- a/schema-engine/connectors/sql-schema-connector/src/sql_schema_calculator/sql_schema_calculator_flavour/mssql.rs +++ b/schema-engine/connectors/sql-schema-connector/src/sql_schema_calculator/sql_schema_calculator_flavour/mssql.rs @@ -23,7 +23,7 @@ impl SqlSchemaCalculatorFlavour for MssqlFlavour { } } - fn push_connector_data(&self, context: &mut super::super::Context<'_>) { + fn push_connector_data(&self, context: &mut crate::sql_schema_calculator::Context<'_>) { let mut data = MssqlSchemaExt::default(); for model in context.datamodel.db.walk_models() { diff --git a/schema-engine/connectors/sql-schema-connector/src/sql_schema_calculator/sql_schema_calculator_flavour/postgres.rs b/schema-engine/connectors/sql-schema-connector/src/sql_schema_calculator/sql_schema_calculator_flavour/postgres.rs index 40577d68a35d..656fe432a970 100644 --- a/schema-engine/connectors/sql-schema-connector/src/sql_schema_calculator/sql_schema_calculator_flavour/postgres.rs +++ b/schema-engine/connectors/sql-schema-connector/src/sql_schema_calculator/sql_schema_calculator_flavour/postgres.rs @@ -37,7 +37,7 @@ impl SqlSchemaCalculatorFlavour for PostgresFlavour { } } - fn push_connector_data(&self, context: &mut super::super::Context<'_>) { + fn push_connector_data(&self, context: &mut crate::sql_schema_calculator::Context<'_>) { let mut postgres_ext = PostgresSchemaExt::default(); let db = &context.datamodel.db; diff --git a/schema-engine/sql-introspection-tests/tests/mysql/constraints.rs b/schema-engine/sql-introspection-tests/tests/mysql/constraints.rs index d712b17f684e..537e2233e9ee 100644 --- a/schema-engine/sql-introspection-tests/tests/mysql/constraints.rs +++ b/schema-engine/sql-introspection-tests/tests/mysql/constraints.rs @@ -4,7 +4,7 @@ use indoc::indoc; use sql_introspection_tests::test_api::*; use test_macros::test_connector; -// Note: MySQL 5.6 ad 5.7 do not support check constraints, so this test is only run on MySQL 8.0. +// Note: MySQL 5.6 and 5.7 do not support check constraints, so this test is only run on MySQL 8.0. #[test_connector(tags(Mysql8), exclude(Vitess))] async fn check_constraints_stopgap(api: &mut TestApi) -> TestResult { let raw_sql = indoc! {r#" diff --git a/schema-engine/sql-migration-tests/tests/native_types/mysql.rs b/schema-engine/sql-migration-tests/tests/native_types/mysql.rs index d8cf62f5767c..b74f3dd6bac4 100644 --- a/schema-engine/sql-migration-tests/tests/native_types/mysql.rs +++ b/schema-engine/sql-migration-tests/tests/native_types/mysql.rs @@ -697,8 +697,8 @@ fn filter_from_types(api: &TestApi, cases: Cases) -> Cow<'static, [Case]> { return Cow::Owned( cases .iter() + .filter(|&(ty, _, _)| !type_is_unsupported_mariadb(ty)) .cloned() - .filter(|(ty, _, _)| !type_is_unsupported_mariadb(ty)) .collect(), ); } @@ -707,8 +707,8 @@ fn filter_from_types(api: &TestApi, cases: Cases) -> Cow<'static, [Case]> { return Cow::Owned( cases .iter() + .filter(|&(ty, _, _)| !type_is_unsupported_mysql_5_6(ty)) .cloned() - .filter(|(ty, _, _)| !type_is_unsupported_mysql_5_6(ty)) .collect(), ); }