From 57be15a4cbee9b46f0bb0b7c1440f9acb39bf001 Mon Sep 17 00:00:00 2001 From: DongHun Kwak Date: Tue, 11 Apr 2023 14:43:19 +0900 Subject: [PATCH 1/1] Import rusqlite 0.28.0 --- .cargo_vcs_info.json | 6 + .gitignore | 3 + Cargo.toml | 233 +++ Cargo.toml.orig | 168 +++ LICENSE | 19 + README.md | 250 ++++ benches/cache.rs | 18 + benches/exec.rs | 17 + src/backup.rs | 428 ++++++ src/blob/mod.rs | 551 +++++++ src/blob/pos_io.rs | 274 ++++ src/busy.rs | 174 +++ src/cache.rs | 350 +++++ src/collation.rs | 215 +++ src/column.rs | 241 +++ src/config.rs | 156 ++ src/context.rs | 75 + src/error.rs | 445 ++++++ src/functions.rs | 1099 ++++++++++++++ src/hooks.rs | 815 ++++++++++ src/inner_connection.rs | 456 ++++++ src/lib.rs | 2127 +++++++++++++++++++++++++++ src/limits.rs | 169 +++ src/load_extension_guard.rs | 46 + src/params.rs | 458 ++++++ src/pragma.rs | 459 ++++++ src/raw_statement.rs | 241 +++ src/row.rs | 559 +++++++ src/session.rs | 938 ++++++++++++ src/statement.rs | 1555 ++++++++++++++++++++ src/trace.rs | 184 +++ src/transaction.rs | 759 ++++++++++ src/types/chrono.rs | 323 ++++ src/types/from_sql.rs | 276 ++++ src/types/mod.rs | 449 ++++++ src/types/serde_json.rs | 53 + src/types/time.rs | 168 +++ src/types/to_sql.rs | 429 ++++++ src/types/url.rs | 82 ++ src/types/value.rs | 142 ++ src/types/value_ref.rs | 263 ++++ src/unlock_notify.rs | 117 ++ src/util/mod.rs | 11 + src/util/param_cache.rs | 60 + src/util/small_cstr.rs | 170 +++ src/util/sqlite_string.rs | 236 +++ src/version.rs | 23 + src/vtab/array.rs | 223 +++ src/vtab/csvtab.rs | 396 +++++ src/vtab/mod.rs | 1366 +++++++++++++++++ src/vtab/series.rs | 319 ++++ src/vtab/vtablog.rs | 300 ++++ test.csv | 6 + tests/config_log.rs | 34 + tests/deny_single_threaded_sqlite_config.rs | 20 + tests/vtab.rs | 100 ++ 56 files changed, 19054 insertions(+) create mode 100644 .cargo_vcs_info.json create mode 100644 .gitignore create mode 100644 Cargo.toml create mode 100644 Cargo.toml.orig create mode 100644 LICENSE create 
mode 100644 README.md create mode 100644 benches/cache.rs create mode 100644 benches/exec.rs create mode 100644 src/backup.rs create mode 100644 src/blob/mod.rs create mode 100644 src/blob/pos_io.rs create mode 100644 src/busy.rs create mode 100644 src/cache.rs create mode 100644 src/collation.rs create mode 100644 src/column.rs create mode 100644 src/config.rs create mode 100644 src/context.rs create mode 100644 src/error.rs create mode 100644 src/functions.rs create mode 100644 src/hooks.rs create mode 100644 src/inner_connection.rs create mode 100644 src/lib.rs create mode 100644 src/limits.rs create mode 100644 src/load_extension_guard.rs create mode 100644 src/params.rs create mode 100644 src/pragma.rs create mode 100644 src/raw_statement.rs create mode 100644 src/row.rs create mode 100644 src/session.rs create mode 100644 src/statement.rs create mode 100644 src/trace.rs create mode 100644 src/transaction.rs create mode 100644 src/types/chrono.rs create mode 100644 src/types/from_sql.rs create mode 100644 src/types/mod.rs create mode 100644 src/types/serde_json.rs create mode 100644 src/types/time.rs create mode 100644 src/types/to_sql.rs create mode 100644 src/types/url.rs create mode 100644 src/types/value.rs create mode 100644 src/types/value_ref.rs create mode 100644 src/unlock_notify.rs create mode 100644 src/util/mod.rs create mode 100644 src/util/param_cache.rs create mode 100644 src/util/small_cstr.rs create mode 100644 src/util/sqlite_string.rs create mode 100644 src/version.rs create mode 100644 src/vtab/array.rs create mode 100644 src/vtab/csvtab.rs create mode 100644 src/vtab/mod.rs create mode 100644 src/vtab/series.rs create mode 100644 src/vtab/vtablog.rs create mode 100644 test.csv create mode 100644 tests/config_log.rs create mode 100644 tests/deny_single_threaded_sqlite_config.rs create mode 100644 tests/vtab.rs diff --git a/.cargo_vcs_info.json b/.cargo_vcs_info.json new file mode 100644 index 0000000..96ed8b8 --- /dev/null +++ 
b/.cargo_vcs_info.json @@ -0,0 +1,6 @@ +{ + "git": { + "sha1": "26293a11f595574897e7e5a5b639d1587255c6b9" + }, + "path_in_vcs": "" +} \ No newline at end of file diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..5f0a3e1 --- /dev/null +++ b/.gitignore @@ -0,0 +1,3 @@ +/target/ +/doc/ +Cargo.lock diff --git a/Cargo.toml b/Cargo.toml new file mode 100644 index 0000000..08b3bc0 --- /dev/null +++ b/Cargo.toml @@ -0,0 +1,233 @@ +# THIS FILE IS AUTOMATICALLY GENERATED BY CARGO +# +# When uploading crates to the registry Cargo will automatically +# "normalize" Cargo.toml files for maximal compatibility +# with all versions of Cargo and also rewrite `path` dependencies +# to registry (e.g., crates.io) dependencies. +# +# If you are reading this file be aware that the original Cargo.toml +# will likely look very different (and much more reasonable). +# See Cargo.toml.orig for the original contents. + +[package] +edition = "2018" +name = "rusqlite" +version = "0.28.0" +authors = ["The rusqlite developers"] +exclude = [ + "/.github/*", + "/.gitattributes", + "/appveyor.yml", + "/Changelog.md", + "/clippy.toml", + "/codecov.yml", +] +description = "Ergonomic wrapper for SQLite" +documentation = "http://docs.rs/rusqlite/" +readme = "README.md" +keywords = [ + "sqlite", + "database", + "ffi", +] +categories = ["database"] +license = "MIT" +repository = "https://github.com/rusqlite/rusqlite" + +[package.metadata.docs.rs] +features = ["modern-full"] +all-features = false +no-default-features = true +default-target = "x86_64-unknown-linux-gnu" +rustdoc-args = [ + "--cfg", + "docsrs", +] + +[package.metadata.playground] +features = ["bundled-full"] +all-features = false + +[lib] +name = "rusqlite" + +[[test]] +name = "config_log" +harness = false + +[[test]] +name = "deny_single_threaded_sqlite_config" + +[[test]] +name = "vtab" + +[[bench]] +name = "cache" +harness = false + +[[bench]] +name = "exec" +harness = false + +[dependencies.bitflags] +version = "1.2" 
+ +[dependencies.chrono] +version = "0.4" +features = ["clock"] +optional = true +default-features = false + +[dependencies.csv] +version = "1.1" +optional = true + +[dependencies.fallible-iterator] +version = "0.2" + +[dependencies.fallible-streaming-iterator] +version = "0.1" + +[dependencies.hashlink] +version = "0.8" + +[dependencies.lazy_static] +version = "1.4" +optional = true + +[dependencies.libsqlite3-sys] +version = "0.25.0" + +[dependencies.serde_json] +version = "1.0" +optional = true + +[dependencies.smallvec] +version = "1.6.1" + +[dependencies.time] +version = "0.3.0" +features = [ + "formatting", + "macros", + "parsing", +] +optional = true + +[dependencies.url] +version = "2.1" +optional = true + +[dependencies.uuid] +version = "1.0" +optional = true + +[dev-dependencies.bencher] +version = "0.1" + +[dev-dependencies.doc-comment] +version = "0.3" + +[dev-dependencies.lazy_static] +version = "1.4" + +[dev-dependencies.regex] +version = "1.5.5" + +[dev-dependencies.tempfile] +version = "3.1.0" + +[dev-dependencies.unicase] +version = "2.6.0" + +[dev-dependencies.uuid] +version = "1.0" +features = ["v4"] + +[features] +array = ["vtab"] +backup = ["libsqlite3-sys/min_sqlite_version_3_6_23"] +blob = ["libsqlite3-sys/min_sqlite_version_3_7_7"] +buildtime_bindgen = ["libsqlite3-sys/buildtime_bindgen"] +bundled = [ + "libsqlite3-sys/bundled", + "modern_sqlite", +] +bundled-full = [ + "modern-full", + "bundled", +] +bundled-sqlcipher = [ + "libsqlite3-sys/bundled-sqlcipher", + "bundled", +] +bundled-sqlcipher-vendored-openssl = [ + "libsqlite3-sys/bundled-sqlcipher-vendored-openssl", + "bundled-sqlcipher", +] +bundled-windows = ["libsqlite3-sys/bundled-windows"] +collation = [] +column_decltype = [] +csvtab = [ + "csv", + "vtab", +] +extra_check = [] +functions = ["libsqlite3-sys/min_sqlite_version_3_7_7"] +hooks = [] +i128_blob = [] +in_gecko = [ + "modern_sqlite", + "libsqlite3-sys/in_gecko", +] +limits = [] +load_extension = [] +modern-full = [ + 
"array", + "backup", + "blob", + "modern_sqlite", + "chrono", + "collation", + "column_decltype", + "csvtab", + "extra_check", + "functions", + "hooks", + "i128_blob", + "limits", + "load_extension", + "serde_json", + "series", + "time", + "trace", + "unlock_notify", + "url", + "uuid", + "vtab", + "window", +] +modern_sqlite = ["libsqlite3-sys/bundled_bindings"] +release_memory = ["libsqlite3-sys/min_sqlite_version_3_7_16"] +series = ["vtab"] +session = [ + "libsqlite3-sys/session", + "hooks", +] +sqlcipher = ["libsqlite3-sys/sqlcipher"] +trace = ["libsqlite3-sys/min_sqlite_version_3_6_23"] +unlock_notify = ["libsqlite3-sys/unlock_notify"] +vtab = ["libsqlite3-sys/min_sqlite_version_3_7_7"] +wasm32-wasi-vfs = ["libsqlite3-sys/wasm32-wasi-vfs"] +window = ["functions"] +winsqlite3 = ["libsqlite3-sys/winsqlite3"] +with-asan = ["libsqlite3-sys/with-asan"] + +[badges.appveyor] +repository = "rusqlite/rusqlite" + +[badges.codecov] +repository = "rusqlite/rusqlite" + +[badges.maintenance] +status = "actively-developed" diff --git a/Cargo.toml.orig b/Cargo.toml.orig new file mode 100644 index 0000000..bd81d44 --- /dev/null +++ b/Cargo.toml.orig @@ -0,0 +1,168 @@ +[package] +name = "rusqlite" +# Note: Update version in README.md when you change this. 
+version = "0.28.0" +authors = ["The rusqlite developers"] +edition = "2018" +description = "Ergonomic wrapper for SQLite" +repository = "https://github.com/rusqlite/rusqlite" +documentation = "http://docs.rs/rusqlite/" +readme = "README.md" +keywords = ["sqlite", "database", "ffi"] +license = "MIT" +categories = ["database"] + +exclude = [ + "/.github/*", + "/.gitattributes", + "/appveyor.yml", + "/Changelog.md", + "/clippy.toml", + "/codecov.yml", +] + +[badges] +appveyor = { repository = "rusqlite/rusqlite" } +codecov = { repository = "rusqlite/rusqlite" } +maintenance = { status = "actively-developed" } + +[lib] +name = "rusqlite" + +[workspace] +members = ["libsqlite3-sys"] + +[features] +load_extension = [] +# hot-backup interface: 3.6.11 (2009-02-18) +backup = ["libsqlite3-sys/min_sqlite_version_3_6_23"] +# sqlite3_blob_reopen: 3.7.4 +blob = ["libsqlite3-sys/min_sqlite_version_3_7_7"] +collation = [] +# sqlite3_create_function_v2: 3.7.3 (2010-10-08) +functions = ["libsqlite3-sys/min_sqlite_version_3_7_7"] +# sqlite3_log: 3.6.23 (2010-03-09) +trace = ["libsqlite3-sys/min_sqlite_version_3_6_23"] +# sqlite3_db_release_memory: 3.7.10 (2012-01-16) +release_memory = ["libsqlite3-sys/min_sqlite_version_3_7_16"] +bundled = ["libsqlite3-sys/bundled", "modern_sqlite"] +bundled-sqlcipher = ["libsqlite3-sys/bundled-sqlcipher", "bundled"] +bundled-sqlcipher-vendored-openssl = ["libsqlite3-sys/bundled-sqlcipher-vendored-openssl", "bundled-sqlcipher"] +buildtime_bindgen = ["libsqlite3-sys/buildtime_bindgen"] +limits = [] +hooks = [] +i128_blob = [] +sqlcipher = ["libsqlite3-sys/sqlcipher"] +unlock_notify = ["libsqlite3-sys/unlock_notify"] +# xSavepoint, xRelease and xRollbackTo: 3.7.7 (2011-06-23) +vtab = ["libsqlite3-sys/min_sqlite_version_3_7_7"] +csvtab = ["csv", "vtab"] +# pointer passing interfaces: 3.20.0 +array = ["vtab"] +# session extension: 3.13.0 +session = ["libsqlite3-sys/session", "hooks"] +# window functions: 3.25.0 +window = ["functions"] +# 3.9.0 +series = 
["vtab"] +# check for invalid query. +extra_check = [] +modern_sqlite = ["libsqlite3-sys/bundled_bindings"] +in_gecko = ["modern_sqlite", "libsqlite3-sys/in_gecko"] +bundled-windows = ["libsqlite3-sys/bundled-windows"] +# Build bundled sqlite with -fsanitize=address +with-asan = ["libsqlite3-sys/with-asan"] +column_decltype = [] +wasm32-wasi-vfs = ["libsqlite3-sys/wasm32-wasi-vfs"] +# Note: doesn't support 32-bit. +winsqlite3 = ["libsqlite3-sys/winsqlite3"] + +# Helper feature for enabling most non-build-related optional features +# or dependencies (except `session`). This is useful for running tests / clippy +# / etc. New features and optional dependencies that don't conflict with anything +# else should be added here. +modern-full = [ + "array", + "backup", + "blob", + "modern_sqlite", + "chrono", + "collation", + "column_decltype", + "csvtab", + "extra_check", + "functions", + "hooks", + "i128_blob", + "limits", + "load_extension", + "serde_json", + "series", + "time", + "trace", + "unlock_notify", + "url", + "uuid", + "vtab", + "window", +] + +bundled-full = ["modern-full", "bundled"] + +[dependencies] +time = { version = "0.3.0", features = ["formatting", "macros", "parsing"], optional = true } +bitflags = "1.2" +hashlink = "0.8" +chrono = { version = "0.4", optional = true, default-features = false, features = ["clock"] } +serde_json = { version = "1.0", optional = true } +csv = { version = "1.1", optional = true } +url = { version = "2.1", optional = true } +lazy_static = { version = "1.4", optional = true } +fallible-iterator = "0.2" +fallible-streaming-iterator = "0.1" +uuid = { version = "1.0", optional = true } +smallvec = "1.6.1" + +[dev-dependencies] +doc-comment = "0.3" +tempfile = "3.1.0" +lazy_static = "1.4" +regex = "1.5.5" +uuid = { version = "1.0", features = ["v4"] } +unicase = "2.6.0" +# Use `bencher` over criterion because it builds much faster and we don't have +# many benchmarks +bencher = "0.1" + +[dependencies.libsqlite3-sys] +path = 
"libsqlite3-sys" +version = "0.25.0" + +[[test]] +name = "config_log" +harness = false + +[[test]] +name = "deny_single_threaded_sqlite_config" + +[[test]] +name = "vtab" + +[[bench]] +name = "cache" +harness = false + +[[bench]] +name = "exec" +harness = false + +[package.metadata.docs.rs] +features = ["modern-full"] +all-features = false +no-default-features = true +default-target = "x86_64-unknown-linux-gnu" +rustdoc-args = ["--cfg", "docsrs"] + +[package.metadata.playground] +features = ["bundled-full"] +all-features = false diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000..9e5b9f7 --- /dev/null +++ b/LICENSE @@ -0,0 +1,19 @@ +Copyright (c) 2014-2021 The rusqlite developers + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +THE SOFTWARE. 
diff --git a/README.md b/README.md new file mode 100644 index 0000000..fdc2381 --- /dev/null +++ b/README.md @@ -0,0 +1,250 @@ +# Rusqlite + +[![Latest Version](https://img.shields.io/crates/v/rusqlite.svg)](https://crates.io/crates/rusqlite) +[![Documentation](https://docs.rs/rusqlite/badge.svg)](https://docs.rs/rusqlite) +[![Build Status (GitHub)](https://github.com/rusqlite/rusqlite/workflows/CI/badge.svg)](https://github.com/rusqlite/rusqlite/actions) +[![Build Status (AppVeyor)](https://ci.appveyor.com/api/projects/status/github/rusqlite/rusqlite?branch=master&svg=true)](https://ci.appveyor.com/project/rusqlite/rusqlite) +[![Code Coverage](https://codecov.io/gh/rusqlite/rusqlite/branch/master/graph/badge.svg)](https://codecov.io/gh/rusqlite/rusqlite) +[![Dependency Status](https://deps.rs/repo/github/rusqlite/rusqlite/status.svg)](https://deps.rs/repo/github/rusqlite/rusqlite) +[![Discord Chat](https://img.shields.io/discord/927966344266256434.svg?logo=discord)](https://discord.gg/nFYfGPB8g4) + +Rusqlite is an ergonomic wrapper for using SQLite from Rust. + +Historically, the API was based on the one from [`rust-postgres`](https://github.com/sfackler/rust-postgres). However, the two have diverged in many ways, and no compatibility between the two is intended. + +## Usage + +In your Cargo.toml: + +```toml +[dependencies] +# `bundled` causes us to automatically compile and link in an up to date +# version of SQLite for you. This avoids many common build issues, and +# avoids depending on the version of SQLite on the users system (or your +# system), which may be old or missing. It's the right choice for most +# programs that control their own SQLite databases. +# +# That said, it's not ideal for all scenarios and in particular, generic +# libraries built around `rusqlite` should probably not enable it, which +# is why it is not a default feature -- it could become hard to disable. 
+rusqlite = { version = "0.28.0", features = ["bundled"] } +``` + +Simple example usage: + +```rust +use rusqlite::{Connection, Result}; + +#[derive(Debug)] +struct Person { + id: i32, + name: String, + data: Option<Vec<u8>>, +} + +fn main() -> Result<()> { + let conn = Connection::open_in_memory()?; + + conn.execute( + "CREATE TABLE person ( + id INTEGER PRIMARY KEY, + name TEXT NOT NULL, + data BLOB + )", + (), // empty list of parameters. + )?; + let me = Person { + id: 0, + name: "Steven".to_string(), + data: None, + }; + conn.execute( + "INSERT INTO person (name, data) VALUES (?1, ?2)", + (&me.name, &me.data), + )?; + + let mut stmt = conn.prepare("SELECT id, name, data FROM person")?; + let person_iter = stmt.query_map([], |row| { + Ok(Person { + id: row.get(0)?, + name: row.get(1)?, + data: row.get(2)?, + }) + })?; + + for person in person_iter { + println!("Found person {:?}", person.unwrap()); + } + Ok(()) +} +``` + +### Supported SQLite Versions + +The base `rusqlite` package supports SQLite version 3.6.8 or newer. If you need +support for older versions, please file an issue. Some cargo features require a +newer SQLite version; see details below. + +### Optional Features + +Rusqlite provides several features that are behind [Cargo +features](https://doc.rust-lang.org/cargo/reference/manifest.html#the-features-section). They are: + +* [`load_extension`](https://docs.rs/rusqlite/~0/rusqlite/struct.LoadExtensionGuard.html) + allows loading dynamic library-based SQLite extensions. +* [`backup`](https://docs.rs/rusqlite/~0/rusqlite/backup/index.html) + allows use of SQLite's online backup API. Note: This feature requires SQLite 3.6.11 or later. +* [`functions`](https://docs.rs/rusqlite/~0/rusqlite/functions/index.html) + allows you to load Rust closures into SQLite connections for use in queries. + Note: This feature requires SQLite 3.7.3 or later. +* `window` for [window function](https://www.sqlite.org/windowfunctions.html) support (`fun(...) OVER ...`).
(Implies `functions`.) +* [`trace`](https://docs.rs/rusqlite/~0/rusqlite/trace/index.html) + allows hooks into SQLite's tracing and profiling APIs. Note: This feature + requires SQLite 3.6.23 or later. +* [`blob`](https://docs.rs/rusqlite/~0/rusqlite/blob/index.html) + gives `std::io::{Read, Write, Seek}` access to SQL BLOBs. Note: This feature + requires SQLite 3.7.4 or later. +* [`limits`](https://docs.rs/rusqlite/~0/rusqlite/struct.Connection.html#method.limit) + allows you to set and retrieve SQLite's per connection limits. +* `chrono` implements [`FromSql`](https://docs.rs/rusqlite/~0/rusqlite/types/trait.FromSql.html) + and [`ToSql`](https://docs.rs/rusqlite/~0/rusqlite/types/trait.ToSql.html) for various + types from the [`chrono` crate](https://crates.io/crates/chrono). +* `serde_json` implements [`FromSql`](https://docs.rs/rusqlite/~0/rusqlite/types/trait.FromSql.html) + and [`ToSql`](https://docs.rs/rusqlite/~0/rusqlite/types/trait.ToSql.html) for the + `Value` type from the [`serde_json` crate](https://crates.io/crates/serde_json). +* `time` implements [`FromSql`](https://docs.rs/rusqlite/~0/rusqlite/types/trait.FromSql.html) + and [`ToSql`](https://docs.rs/rusqlite/~0/rusqlite/types/trait.ToSql.html) for the + `time::OffsetDateTime` type from the [`time` crate](https://crates.io/crates/time). +* `url` implements [`FromSql`](https://docs.rs/rusqlite/~0/rusqlite/types/trait.FromSql.html) + and [`ToSql`](https://docs.rs/rusqlite/~0/rusqlite/types/trait.ToSql.html) for the + `Url` type from the [`url` crate](https://crates.io/crates/url). +* `bundled` uses a bundled version of SQLite. This is a good option for cases where linking to SQLite is complicated, such as Windows. +* `sqlcipher` looks for the SQLCipher library to link against instead of SQLite. This feature overrides `bundled`. +* `bundled-sqlcipher` uses a bundled version of SQLCipher. This searches for and links against a system-installed crypto library to provide the crypto implementation. 
+* `bundled-sqlcipher-vendored-openssl` allows using bundled-sqlcipher with a vendored version of OpenSSL (via the `openssl-sys` crate) as the crypto provider. + - As the name implies this depends on the `bundled-sqlcipher` feature, and automatically turns it on. + - If turned on, this uses the [`openssl-sys`](https://crates.io/crates/openssl-sys) crate, with the `vendored` feature enabled in order to build and bundle the OpenSSL crypto library. +* `hooks` for [Commit, Rollback](http://sqlite.org/c3ref/commit_hook.html) and [Data Change](http://sqlite.org/c3ref/update_hook.html) notification callbacks. +* `unlock_notify` for [Unlock](https://sqlite.org/unlock_notify.html) notification. +* `vtab` for [virtual table](https://sqlite.org/vtab.html) support (allows you to write virtual table implementations in Rust). Currently, only read-only virtual tables are supported. +* `series` exposes [`generate_series(...)`](https://www.sqlite.org/series.html) Table-Valued Function. (Implies `vtab`.) +* [`csvtab`](https://sqlite.org/csv.html), CSV virtual table written in Rust. (Implies `vtab`.) +* [`array`](https://sqlite.org/carray.html), The `rarray()` Table-Valued Function. (Implies `vtab`.) +* `i128_blob` allows storing values of type `i128` in SQLite databases. Internally, the data is stored as a 16-byte big-endian blob, with the most significant bit flipped, which allows ordering and comparison between different blobs storing i128s to work as expected. +* `uuid` allows storing and retrieving `Uuid` values from the [`uuid`](https://docs.rs/uuid/) crate using blobs. +* [`session`](https://sqlite.org/sessionintro.html), Session module extension. Requires `buildtime_bindgen` feature. (Implies `hooks`.) +* `extra_check` fails when a query passed to execute is readonly or has a column count > 0. +* `column_decltype` provides `columns()` method for Statements and Rows; omit if linking to a version of SQLite/SQLCipher compiled with `-DSQLITE_OMIT_DECLTYPE`.
+* `collation` exposes [`sqlite3_create_collation_v2`](https://sqlite.org/c3ref/create_collation.html). +* `winsqlite3` allows linking against the SQLite present in newer versions of Windows + +## Notes on building rusqlite and libsqlite3-sys + +`libsqlite3-sys` is a separate crate from `rusqlite` that provides the Rust +declarations for SQLite's C API. By default, `libsqlite3-sys` attempts to find a SQLite library that already exists on your system using pkg-config, or a +[Vcpkg](https://github.com/Microsoft/vcpkg) installation for MSVC ABI builds. + +You can adjust this behavior in a number of ways: + +* If you use the `bundled`, `bundled-sqlcipher`, or `bundled-sqlcipher-vendored-openssl` features, `libsqlite3-sys` will use the + [cc](https://crates.io/crates/cc) crate to compile SQLite or SQLCipher from source and + link against that. This source is embedded in the `libsqlite3-sys` crate and + is currently SQLite 3.39.0 (as of `rusqlite` 0.28.0 / `libsqlite3-sys` + 0.25.0). This is probably the simplest solution to any build problems. You can enable this by adding the following in your `Cargo.toml` file: + ```toml + [dependencies.rusqlite] + version = "0.28.0" + features = ["bundled"] + ``` +* When using any of the `bundled` features, the build script will honor `SQLITE_MAX_VARIABLE_NUMBER` and `SQLITE_MAX_EXPR_DEPTH` variables. It will also honor a `LIBSQLITE3_FLAGS` variable, which can have a format like `"-USQLITE_ALPHA -DSQLITE_BETA SQLITE_GAMMA ..."`. That would disable the `SQLITE_ALPHA` flag, and set the `SQLITE_BETA` and `SQLITE_GAMMA` flags. (The initial `-D` can be omitted, as on the last one.) +* When using `bundled-sqlcipher` (and not also using `bundled-sqlcipher-vendored-openssl`), `libsqlite3-sys` will need to + link against crypto libraries on the system. If the build script can find a `libcrypto` from OpenSSL or LibreSSL (it will consult `OPENSSL_LIB_DIR`/`OPENSSL_INCLUDE_DIR` and `OPENSSL_DIR` environment variables), it will use that. 
If building on and for Macs, and none of those variables are set, it will use the system's SecurityFramework instead. + +* When linking against a SQLite (or SQLCipher) library already on the system (so *not* using any of the `bundled` features), you can set the `SQLITE3_LIB_DIR` (or `SQLCIPHER_LIB_DIR`) environment variable to point to a directory containing the library. You can also set the `SQLITE3_INCLUDE_DIR` (or `SQLCIPHER_INCLUDE_DIR`) variable to point to the directory containing `sqlite3.h`. +* Installing the sqlite3 development packages will usually be all that is required, but + the build helpers for [pkg-config](https://github.com/alexcrichton/pkg-config-rs) + and [vcpkg](https://github.com/mcgoo/vcpkg-rs) have some additional configuration + options. The default when using vcpkg is to dynamically link, + which must be enabled by setting `VCPKGRS_DYNAMIC=1` environment variable before build. + `vcpkg install sqlite3:x64-windows` will install the required library. +* When linking against a SQLite (or SQLCipher) library already on the system, you can set the `SQLITE3_STATIC` (or `SQLCIPHER_STATIC`) environment variable to 1 to request that the library be statically instead of dynamically linked. + + +### Binding generation + +We use [bindgen](https://crates.io/crates/bindgen) to generate the Rust +declarations from SQLite's C header file. `bindgen` +[recommends](https://github.com/servo/rust-bindgen#library-usage-with-buildrs) +running this as part of the build process of libraries that used this. We tried +this briefly (`rusqlite` 0.10.0, specifically), but it had some annoyances: + +* The build time for `libsqlite3-sys` (and therefore `rusqlite`) increased + dramatically. +* Running `bindgen` requires a relatively-recent version of Clang, which many + systems do not have installed by default. +* Running `bindgen` also requires the SQLite header file to be present. 
+ +As of `rusqlite` 0.10.1, we avoid running `bindgen` at build-time by shipping +pregenerated bindings for several versions of SQLite. When compiling +`rusqlite`, we use your selected Cargo features to pick the bindings for the +minimum SQLite version that supports your chosen features. If you are using +`libsqlite3-sys` directly, you can use the same features to choose which +pregenerated bindings are chosen: + +* `min_sqlite_version_3_6_8` - SQLite 3.6.8 bindings (this is the default) +* `min_sqlite_version_3_6_23` - SQLite 3.6.23 bindings +* `min_sqlite_version_3_7_7` - SQLite 3.7.7 bindings + +If you use any of the `bundled` features, you will get pregenerated bindings for the +bundled version of SQLite/SQLCipher. If you need other specific pregenerated binding +versions, please file an issue. If you want to run `bindgen` at buildtime to +produce your own bindings, use the `buildtime_bindgen` Cargo feature. + +If you enable the `modern_sqlite` feature, we'll use the bindings we would have +included with the bundled build. You generally should have `buildtime_bindgen` +enabled if you turn this on, as otherwise you'll need to keep the version of +SQLite you link with in sync with what rusqlite would have bundled, (usually the +most recent release of SQLite). Failing to do this will cause a runtime error. + +## Contributing + +Rusqlite has many features, and many of them impact the build configuration in +incompatible ways. This is unfortunate, and makes testing changes hard. + +To help here: you generally should ensure that you run tests/lint for +`--features bundled`, and `--features "bundled-full session buildtime_bindgen"`. + +If running bindgen is problematic for you, `--features bundled-full` enables +bundled and all features which don't require binding generation, and can be used +instead. + +### Checklist + +- Run `cargo fmt` to ensure your Rust code is correctly formatted. +- Ensure `cargo clippy --workspace --features bundled` passes without warnings. 
+- Ensure `cargo clippy --workspace --features "bundled-full session buildtime_bindgen"` passes without warnings. +- Ensure `cargo test --workspace --features bundled` reports no failures. +- Ensure `cargo test --workspace --features "bundled-full session buildtime_bindgen"` reports no failures. + +## Author + +Rusqlite is the product of hard work by a number of people. A list is available +here: https://github.com/rusqlite/rusqlite/graphs/contributors + +## Community + +Feel free to join the [Rusqlite Discord Server](https://discord.gg/nFYfGPB8g4) to discuss or get help with `rusqlite` or `libsqlite3-sys`. + +## License + +Rusqlite and libsqlite3-sys are available under the MIT license. See the LICENSE file for more info. + +### Licenses of Bundled Software + +Depending on the set of enabled cargo `features`, rusqlite and libsqlite3-sys will also bundle other libraries, which have their own licensing terms: + +- If `--features=bundled-sqlcipher` is enabled, the vendored source of [SQLcipher](https://github.com/sqlcipher/sqlcipher) will be compiled and statically linked in. SQLcipher is distributed under a BSD-style license, as described [here](libsqlite3-sys/sqlcipher/LICENSE). + +- If `--features=bundled` is enabled, the vendored source of SQLite will be compiled and linked in. SQLite is in the public domain, as described [here](https://www.sqlite.org/copyright.html). + +Both of these are quite permissive, have no bearing on the license of the code in `rusqlite` or `libsqlite3-sys` themselves, and can be entirely ignored if you do not use the feature in question. 
diff --git a/benches/cache.rs b/benches/cache.rs new file mode 100644 index 0000000..dd3683e --- /dev/null +++ b/benches/cache.rs @@ -0,0 +1,18 @@ +use bencher::{benchmark_group, benchmark_main, Bencher}; +use rusqlite::Connection; + +fn bench_no_cache(b: &mut Bencher) { // Times `Connection::prepare` on a fixed query with the statement cache capacity set to 0. + let db = Connection::open_in_memory().unwrap(); + db.set_prepared_statement_cache_capacity(0); + let sql = "SELECT 1, 'test', 3.14 UNION SELECT 2, 'exp', 2.71"; + b.iter(|| db.prepare(sql).unwrap()); +} + +fn bench_cache(b: &mut Bencher) { // Times `Connection::prepare_cached` on the same query, leaving the cache capacity unchanged. + let db = Connection::open_in_memory().unwrap(); + let sql = "SELECT 1, 'test', 3.14 UNION SELECT 2, 'exp', 2.71"; + b.iter(|| db.prepare_cached(sql).unwrap()); +} + +benchmark_group!(cache_benches, bench_no_cache, bench_cache); +benchmark_main!(cache_benches); diff --git a/benches/exec.rs b/benches/exec.rs new file mode 100644 index 0000000..b95cb35 --- /dev/null +++ b/benches/exec.rs @@ -0,0 +1,17 @@ +use bencher::{benchmark_group, benchmark_main, Bencher}; +use rusqlite::Connection; + +fn bench_execute(b: &mut Bencher) { // Times `Connection::execute` of a single PRAGMA with an empty parameter list. + let db = Connection::open_in_memory().unwrap(); + let sql = "PRAGMA user_version=1"; + b.iter(|| db.execute(sql, []).unwrap()); +} + +fn bench_execute_batch(b: &mut Bencher) { // Times `Connection::execute_batch` on the same PRAGMA for comparison. + let db = Connection::open_in_memory().unwrap(); + let sql = "PRAGMA user_version=1"; + b.iter(|| db.execute_batch(sql).unwrap()); +} + +benchmark_group!(exec_benches, bench_execute, bench_execute_batch); +benchmark_main!(exec_benches); diff --git a/src/backup.rs b/src/backup.rs new file mode 100644 index 0000000..6da01fd --- /dev/null +++ b/src/backup.rs @@ -0,0 +1,428 @@ +//! Online SQLite backup API. +//! +//! To create a [`Backup`], you must have two distinct [`Connection`]s - one +//! for the source (which can be used while the backup is running) and one for +//! the destination (which cannot). A [`Backup`] handle exposes three methods: +//! [`step`](Backup::step) will attempt to back up a specified number of pages, +//!
[`progress`](Backup::progress) gets the current progress of the backup as of +//! the last call to [`step`](Backup::step), and +//! [`run_to_completion`](Backup::run_to_completion) will attempt to back up the +//! entire source database, allowing you to specify how many pages are backed up +//! at a time and how long the thread should sleep between chunks of pages. +//! +//! The following example is equivalent to "Example 2: Online Backup of a +//! Running Database" from [SQLite's Online Backup API +//! documentation](https://www.sqlite.org/backup.html). +//! +//! ```rust,no_run +//! # use rusqlite::{backup, Connection, Result}; +//! # use std::path::Path; +//! # use std::time; +//! +//! fn backup_db>( +//! src: &Connection, +//! dst: P, +//! progress: fn(backup::Progress), +//! ) -> Result<()> { +//! let mut dst = Connection::open(dst)?; +//! let backup = backup::Backup::new(src, &mut dst)?; +//! backup.run_to_completion(5, time::Duration::from_millis(250), Some(progress)) +//! } +//! ``` + +use std::marker::PhantomData; +use std::path::Path; +use std::ptr; + +use std::os::raw::c_int; +use std::thread; +use std::time::Duration; + +use crate::ffi; + +use crate::error::error_from_handle; +use crate::{Connection, DatabaseName, Result}; + +impl Connection { + /// Back up the `name` database to the given + /// destination path. + /// + /// If `progress` is not `None`, it will be called periodically + /// until the backup completes. + /// + /// For more fine-grained control over the backup process (e.g., + /// to sleep periodically during the backup or to back up to an + /// already-open database connection), see the `backup` module. + /// + /// # Failure + /// + /// Will return `Err` if the destination path cannot be opened + /// or if the backup fails. 
+ pub fn backup>( + &self, + name: DatabaseName<'_>, + dst_path: P, + progress: Option, + ) -> Result<()> { + use self::StepResult::{Busy, Done, Locked, More}; + let mut dst = Connection::open(dst_path)?; + let backup = Backup::new_with_names(self, name, &mut dst, DatabaseName::Main)?; + + let mut r = More; + while r == More { + r = backup.step(100)?; + if let Some(f) = progress { + f(backup.progress()); + } + } + + match r { + Done => Ok(()), + Busy => Err(unsafe { error_from_handle(ptr::null_mut(), ffi::SQLITE_BUSY) }), + Locked => Err(unsafe { error_from_handle(ptr::null_mut(), ffi::SQLITE_LOCKED) }), + More => unreachable!(), + } + } + + /// Restore the given source path into the + /// `name` database. If `progress` is not `None`, it will be + /// called periodically until the restore completes. + /// + /// For more fine-grained control over the restore process (e.g., + /// to sleep periodically during the restore or to restore from an + /// already-open database connection), see the `backup` module. + /// + /// # Failure + /// + /// Will return `Err` if the destination path cannot be opened + /// or if the restore fails. 
+ pub fn restore, F: Fn(Progress)>( + &mut self, + name: DatabaseName<'_>, + src_path: P, + progress: Option, + ) -> Result<()> { + use self::StepResult::{Busy, Done, Locked, More}; + let src = Connection::open(src_path)?; + let restore = Backup::new_with_names(&src, DatabaseName::Main, self, name)?; + + let mut r = More; + let mut busy_count = 0_i32; + 'restore_loop: while r == More || r == Busy { + r = restore.step(100)?; + if let Some(ref f) = progress { + f(restore.progress()); + } + if r == Busy { + busy_count += 1; + if busy_count >= 3 { + break 'restore_loop; + } + thread::sleep(Duration::from_millis(100)); + } + } + + match r { + Done => Ok(()), + Busy => Err(unsafe { error_from_handle(ptr::null_mut(), ffi::SQLITE_BUSY) }), + Locked => Err(unsafe { error_from_handle(ptr::null_mut(), ffi::SQLITE_LOCKED) }), + More => unreachable!(), + } + } +} + +/// Possible successful results of calling +/// [`Backup::step`]. +#[derive(Copy, Clone, Debug, PartialEq, Eq)] +#[non_exhaustive] +pub enum StepResult { + /// The backup is complete. + Done, + + /// The step was successful but there are still more pages that need to be + /// backed up. + More, + + /// The step failed because appropriate locks could not be acquired. This is + /// not a fatal error - the step can be retried. + Busy, + + /// The step failed because the source connection was writing to the + /// database. This is not a fatal error - the step can be retried. + Locked, +} + +/// Struct specifying the progress of a backup. The +/// percentage completion can be calculated as `(pagecount - remaining) / +/// pagecount`. The progress of a backup is as of the last call to +/// [`step`](Backup::step) - if the source database is modified after a call to +/// [`step`](Backup::step), the progress value will become outdated and +/// potentially incorrect. +#[derive(Copy, Clone, Debug)] +pub struct Progress { + /// Number of pages in the source database that still need to be backed up. 
+ pub remaining: c_int, + /// Total number of pages in the source database. + pub pagecount: c_int, +} + +/// A handle to an online backup. +pub struct Backup<'a, 'b> { + phantom_from: PhantomData<&'a Connection>, + to: &'b Connection, + b: *mut ffi::sqlite3_backup, +} + +impl Backup<'_, '_> { + /// Attempt to create a new handle that will allow backups from `from` to + /// `to`. Note that `to` is a `&mut` - this is because SQLite forbids any + /// API calls on the destination of a backup while the backup is taking + /// place. + /// + /// # Failure + /// + /// Will return `Err` if the underlying `sqlite3_backup_init` call returns + /// `NULL`. + #[inline] + pub fn new<'a, 'b>(from: &'a Connection, to: &'b mut Connection) -> Result> { + Backup::new_with_names(from, DatabaseName::Main, to, DatabaseName::Main) + } + + /// Attempt to create a new handle that will allow backups from the + /// `from_name` database of `from` to the `to_name` database of `to`. Note + /// that `to` is a `&mut` - this is because SQLite forbids any API calls on + /// the destination of a backup while the backup is taking place. + /// + /// # Failure + /// + /// Will return `Err` if the underlying `sqlite3_backup_init` call returns + /// `NULL`. + pub fn new_with_names<'a, 'b>( + from: &'a Connection, + from_name: DatabaseName<'_>, + to: &'b mut Connection, + to_name: DatabaseName<'_>, + ) -> Result> { + let to_name = to_name.as_cstring()?; + let from_name = from_name.as_cstring()?; + + let to_db = to.db.borrow_mut().db; + + let b = unsafe { + let b = ffi::sqlite3_backup_init( + to_db, + to_name.as_ptr(), + from.db.borrow_mut().db, + from_name.as_ptr(), + ); + if b.is_null() { + return Err(error_from_handle(to_db, ffi::sqlite3_errcode(to_db))); + } + b + }; + + Ok(Backup { + phantom_from: PhantomData, + to, + b, + }) + } + + /// Gets the progress of the backup as of the last call to + /// [`step`](Backup::step). 
+ #[inline] + #[must_use] + pub fn progress(&self) -> Progress { + unsafe { + Progress { + remaining: ffi::sqlite3_backup_remaining(self.b), + pagecount: ffi::sqlite3_backup_pagecount(self.b), + } + } + } + + /// Attempts to back up the given number of pages. If `num_pages` is + /// negative, will attempt to back up all remaining pages. This will hold a + /// lock on the source database for the duration, so it is probably not + /// what you want for databases that are currently active (see + /// [`run_to_completion`](Backup::run_to_completion) for a better + /// alternative). + /// + /// # Failure + /// + /// Will return `Err` if the underlying `sqlite3_backup_step` call returns + /// an error code other than `DONE`, `OK`, `BUSY`, or `LOCKED`. `BUSY` and + /// `LOCKED` are transient errors and are therefore returned as possible + /// `Ok` values. + #[inline] + pub fn step(&self, num_pages: c_int) -> Result { + use self::StepResult::{Busy, Done, Locked, More}; + + let rc = unsafe { ffi::sqlite3_backup_step(self.b, num_pages) }; + match rc { + ffi::SQLITE_DONE => Ok(Done), + ffi::SQLITE_OK => Ok(More), + ffi::SQLITE_BUSY => Ok(Busy), + ffi::SQLITE_LOCKED => Ok(Locked), + _ => self.to.decode_result(rc).map(|_| More), + } + } + + /// Attempts to run the entire backup. Will call + /// [`step(pages_per_step)`](Backup::step) as many times as necessary, + /// sleeping for `pause_between_pages` between each call to give the + /// source database time to process any pending queries. This is a + /// direct implementation of "Example 2: Online Backup of a Running + /// Database" from [SQLite's Online Backup API documentation](https://www.sqlite.org/backup.html). + /// + /// If `progress` is not `None`, it will be called after each step with the + /// current progress of the backup. Note that is possible the progress may + /// not change if the step returns `Busy` or `Locked` even though the + /// backup is still running. 
+ /// + /// # Failure + /// + /// Will return `Err` if any of the calls to [`step`](Backup::step) return + /// `Err`. + pub fn run_to_completion( + &self, + pages_per_step: c_int, + pause_between_pages: Duration, + progress: Option, + ) -> Result<()> { + use self::StepResult::{Busy, Done, Locked, More}; + + assert!(pages_per_step > 0, "pages_per_step must be positive"); + + loop { + let r = self.step(pages_per_step)?; + if let Some(progress) = progress { + progress(self.progress()); + } + match r { + More | Busy | Locked => thread::sleep(pause_between_pages), + Done => return Ok(()), + } + } + } +} + +impl Drop for Backup<'_, '_> { + #[inline] + fn drop(&mut self) { + unsafe { ffi::sqlite3_backup_finish(self.b) }; + } +} + +#[cfg(test)] +mod test { + use super::Backup; + use crate::{Connection, DatabaseName, Result}; + use std::time::Duration; + + #[test] + fn test_backup() -> Result<()> { + let src = Connection::open_in_memory()?; + let sql = "BEGIN; + CREATE TABLE foo(x INTEGER); + INSERT INTO foo VALUES(42); + END;"; + src.execute_batch(sql)?; + + let mut dst = Connection::open_in_memory()?; + + { + let backup = Backup::new(&src, &mut dst)?; + backup.step(-1)?; + } + + let the_answer: i64 = dst.query_row("SELECT x FROM foo", [], |r| r.get(0))?; + assert_eq!(42, the_answer); + + src.execute_batch("INSERT INTO foo VALUES(43)")?; + + { + let backup = Backup::new(&src, &mut dst)?; + backup.run_to_completion(5, Duration::from_millis(250), None)?; + } + + let the_answer: i64 = dst.query_row("SELECT SUM(x) FROM foo", [], |r| r.get(0))?; + assert_eq!(42 + 43, the_answer); + Ok(()) + } + + #[test] + fn test_backup_temp() -> Result<()> { + let src = Connection::open_in_memory()?; + let sql = "BEGIN; + CREATE TEMPORARY TABLE foo(x INTEGER); + INSERT INTO foo VALUES(42); + END;"; + src.execute_batch(sql)?; + + let mut dst = Connection::open_in_memory()?; + + { + let backup = + Backup::new_with_names(&src, DatabaseName::Temp, &mut dst, DatabaseName::Main)?; + 
backup.step(-1)?; + } + + let the_answer: i64 = dst.query_row("SELECT x FROM foo", [], |r| r.get(0))?; + assert_eq!(42, the_answer); + + src.execute_batch("INSERT INTO foo VALUES(43)")?; + + { + let backup = + Backup::new_with_names(&src, DatabaseName::Temp, &mut dst, DatabaseName::Main)?; + backup.run_to_completion(5, Duration::from_millis(250), None)?; + } + + let the_answer: i64 = dst.query_row("SELECT SUM(x) FROM foo", [], |r| r.get(0))?; + assert_eq!(42 + 43, the_answer); + Ok(()) + } + + #[test] + fn test_backup_attached() -> Result<()> { + let src = Connection::open_in_memory()?; + let sql = "ATTACH DATABASE ':memory:' AS my_attached; + BEGIN; + CREATE TABLE my_attached.foo(x INTEGER); + INSERT INTO my_attached.foo VALUES(42); + END;"; + src.execute_batch(sql)?; + + let mut dst = Connection::open_in_memory()?; + + { + let backup = Backup::new_with_names( + &src, + DatabaseName::Attached("my_attached"), + &mut dst, + DatabaseName::Main, + )?; + backup.step(-1)?; + } + + let the_answer: i64 = dst.query_row("SELECT x FROM foo", [], |r| r.get(0))?; + assert_eq!(42, the_answer); + + src.execute_batch("INSERT INTO foo VALUES(43)")?; + + { + let backup = Backup::new_with_names( + &src, + DatabaseName::Attached("my_attached"), + &mut dst, + DatabaseName::Main, + )?; + backup.run_to_completion(5, Duration::from_millis(250), None)?; + } + + let the_answer: i64 = dst.query_row("SELECT SUM(x) FROM foo", [], |r| r.get(0))?; + assert_eq!(42 + 43, the_answer); + Ok(()) + } +} diff --git a/src/blob/mod.rs b/src/blob/mod.rs new file mode 100644 index 0000000..81c6098 --- /dev/null +++ b/src/blob/mod.rs @@ -0,0 +1,551 @@ +//! Incremental BLOB I/O. +//! +//! Note that SQLite does not provide API-level access to change the size of a +//! BLOB; that must be performed through SQL statements. +//! +//! There are two choices for how to perform IO on a [`Blob`]. +//! +//! 1. The implementations it provides of the `std::io::Read`, `std::io::Write`, +//! and `std::io::Seek` traits. 
+//! +//! 2. A positional IO API, e.g. [`Blob::read_at`], [`Blob::write_at`] and +//! similar. +//! +//! Documenting these in order: +//! +//! ## 1. `std::io` trait implementations. +//! +//! `Blob` conforms to `std::io::Read`, `std::io::Write`, and `std::io::Seek`, +//! so it plays nicely with other types that build on these (such as +//! `std::io::BufReader` and `std::io::BufWriter`). However, you must be careful +//! with the size of the blob. For example, when using a `BufWriter`, the +//! `BufWriter` will accept more data than the `Blob` will allow, so make sure +//! to call `flush` and check for errors. (See the unit tests in this module for +//! an example.) +//! +//! ## 2. Positional IO +//! +//! `Blob`s also offer a `pread` / `pwrite`-style positional IO API in the form +//! of [`Blob::read_at`], [`Blob::write_at`], [`Blob::raw_read_at`], +//! [`Blob::read_at_exact`], and [`Blob::raw_read_at_exact`]. +//! +//! These APIs all take the position to read from or write to as a +//! parameter, instead of using an internal `pos` value. +//! +//! ### Positional IO Read Variants +//! +//! For the `read` functions, there are several functions provided: +//! +//! - [`Blob::read_at`] +//! - [`Blob::raw_read_at`] +//! - [`Blob::read_at_exact`] +//! - [`Blob::raw_read_at_exact`] +//! +//! These can be divided along two axes: raw/not raw, and exact/inexact: +//! +//! 1. Raw/not raw refers to the type of the destination buffer. The raw +//! functions take a `&mut [MaybeUninit<u8>]` as the destination buffer, +//! where the "normal" functions take a `&mut [u8]`. +//! +//! Using `MaybeUninit` here can be more efficient in some cases, but is +//! often inconvenient, so both are provided. +//! +//! 2. Exact/inexact refers to whether or not the entire buffer must be +//! filled in order for the call to be considered a success. +//! +//! The "exact" functions require the provided buffer be entirely filled, or +//!
they return an error, whereas the "inexact" functions read as much out of +//! the blob as is available, and return how much they were able to read. +//! +//! The inexact functions are preferable if you do not know the size of the +//! blob already, and the exact functions are preferable if you do. +//! +//! ### Comparison to using the `std::io` traits: +//! +//! In general, the positional methods offer the following pros and cons compared to +//! using the `std::io::{Read, Write, Seek}` implementations we provide for +//! `Blob`: +//! +//! 1. (Pro) There is no need to first seek to a position in order to perform IO +//! on it as the position is a parameter. +//! +//! 2. (Pro) `Blob`'s positional read functions don't mutate the blob in any +//! way, and take `&self`. No `&mut` access required. +//! +//! 3. (Pro) Positional IO functions return `Err(rusqlite::Error)` on failure, +//! rather than `Err(std::io::Error)`. Returning `rusqlite::Error` is more +//! accurate and convenient. +//! +//! Note that for the `std::io` API, no data is lost however, and it can be +//! recovered with `io_err.downcast::<rusqlite::Error>()` (this can be easy +//! to forget, though). +//! +//! 4. (Pro, for now). A `raw` version of the read API exists which can allow +//! reading into a `&mut [MaybeUninit<u8>]` buffer, which avoids a potentially +//! costly initialization step. (However, `std::io` traits will certainly +//! gain this someday, which is why this is only a "Pro, for now"). +//! +//! 5. (Con) The set of functions is more bare-bones than what is offered in +//! `std::io`, which has a number of adapters, handy algorithms, further +//! traits. +//! +//! 6. (Con) No meaningful interoperability with other crates, so if you need +//! that you must use `std::io`. +//! +//! To generalize: the `std::io` traits are useful because they conform to a +//! standard interface that a lot of code knows how to handle, however that +//! interface is not a perfect fit for [`Blob`], so another small set of +//!
functions is provided as well. +//! +//! # Example (`std::io`) +//! +//! ```rust +//! # use rusqlite::blob::ZeroBlob; +//! # use rusqlite::{Connection, DatabaseName}; +//! # use std::error::Error; +//! # use std::io::{Read, Seek, SeekFrom, Write}; +//! # fn main() -> Result<(), Box<dyn Error>> { +//! let db = Connection::open_in_memory()?; +//! db.execute_batch("CREATE TABLE test_table (content BLOB);")?; +//! +//! // Insert a BLOB into the `content` column of `test_table`. Note that the Blob +//! // I/O API provides no way of inserting or resizing BLOBs in the DB -- this +//! // must be done via SQL. +//! db.execute("INSERT INTO test_table (content) VALUES (ZEROBLOB(10))", [])?; +//! +//! // Get the row id of the BLOB we just inserted. +//! let rowid = db.last_insert_rowid(); +//! // Open the BLOB we just inserted for IO. +//! let mut blob = db.blob_open(DatabaseName::Main, "test_table", "content", rowid, false)?; +//! +//! // Write some data into the blob. Make sure to test that the number of bytes +//! // written matches what you expect; if you try to write too much, the data +//! // will be truncated to the size of the BLOB. +//! let bytes_written = blob.write(b"01234567")?; +//! assert_eq!(bytes_written, 8); +//! +//! // Move back to the start and read into a local buffer. +//! // Same guidance - make sure you check the number of bytes read! +//! blob.seek(SeekFrom::Start(0))?; +//! let mut buf = [0u8; 20]; +//! let bytes_read = blob.read(&mut buf[..])?; +//! assert_eq!(bytes_read, 10); // note we read 10 bytes because the blob has size 10 +//! +//! // Insert another BLOB, this time using a parameter passed in from +//! // rust (potentially with a dynamic size). +//! db.execute( +//! "INSERT INTO test_table (content) VALUES (?)", +//! [ZeroBlob(64)], +//! )?; +//! +//! // given a new row ID, we can reopen the blob on that row +//! let rowid = db.last_insert_rowid(); +//! blob.reopen(rowid)?; +//! // Just check that the size is right. +//!
assert_eq!(blob.len(), 64); +//! # Ok(()) +//! # } +//! ``` +//! +//! # Example (Positional) +//! +//! ```rust +//! # use rusqlite::blob::ZeroBlob; +//! # use rusqlite::{Connection, DatabaseName}; +//! # use std::error::Error; +//! # fn main() -> Result<(), Box> { +//! let db = Connection::open_in_memory()?; +//! db.execute_batch("CREATE TABLE test_table (content BLOB);")?; +//! // Insert a blob into the `content` column of `test_table`. Note that the Blob +//! // I/O API provides no way of inserting or resizing blobs in the DB -- this +//! // must be done via SQL. +//! db.execute("INSERT INTO test_table (content) VALUES (ZEROBLOB(10))", [])?; +//! // Get the row id off the blob we just inserted. +//! let rowid = db.last_insert_rowid(); +//! // Open the blob we just inserted for IO. +//! let mut blob = db.blob_open(DatabaseName::Main, "test_table", "content", rowid, false)?; +//! // Write some data into the blob. +//! blob.write_at(b"ABCDEF", 2)?; +//! +//! // Read the whole blob into a local buffer. +//! let mut buf = [0u8; 10]; +//! blob.read_at_exact(&mut buf, 0)?; +//! assert_eq!(&buf, b"\0\0ABCDEF\0\0"); +//! +//! // Insert another blob, this time using a parameter passed in from +//! // rust (potentially with a dynamic size). +//! db.execute( +//! "INSERT INTO test_table (content) VALUES (?)", +//! [ZeroBlob(64)], +//! )?; +//! +//! // given a new row ID, we can reopen the blob on that row +//! let rowid = db.last_insert_rowid(); +//! blob.reopen(rowid)?; +//! assert_eq!(blob.len(), 64); +//! # Ok(()) +//! # } +//! ``` +use std::cmp::min; +use std::io; +use std::ptr; + +use super::ffi; +use super::types::{ToSql, ToSqlOutput}; +use crate::{Connection, DatabaseName, Result}; + +mod pos_io; + +/// Handle to an open BLOB. See +/// [`rusqlite::blob`](crate::blob) documentation for in-depth discussion. 
+pub struct Blob<'conn> { + conn: &'conn Connection, + blob: *mut ffi::sqlite3_blob, + // used by std::io implementations, + pos: i32, +} + +impl Connection { + /// Open a handle to the BLOB located in `row_id`, + /// `column`, `table` in database `db`. + /// + /// # Failure + /// + /// Will return `Err` if `db`/`table`/`column` cannot be converted to a + /// C-compatible string or if the underlying SQLite BLOB open call + /// fails. + #[inline] + pub fn blob_open<'a>( + &'a self, + db: DatabaseName<'_>, + table: &str, + column: &str, + row_id: i64, + read_only: bool, + ) -> Result> { + let c = self.db.borrow_mut(); + let mut blob = ptr::null_mut(); + let db = db.as_cstring()?; + let table = super::str_to_cstring(table)?; + let column = super::str_to_cstring(column)?; + let rc = unsafe { + ffi::sqlite3_blob_open( + c.db(), + db.as_ptr(), + table.as_ptr(), + column.as_ptr(), + row_id, + if read_only { 0 } else { 1 }, + &mut blob, + ) + }; + c.decode_result(rc).map(|_| Blob { + conn: self, + blob, + pos: 0, + }) + } +} + +impl Blob<'_> { + /// Move a BLOB handle to a new row. + /// + /// # Failure + /// + /// Will return `Err` if the underlying SQLite BLOB reopen call fails. + #[inline] + pub fn reopen(&mut self, row: i64) -> Result<()> { + let rc = unsafe { ffi::sqlite3_blob_reopen(self.blob, row) }; + if rc != ffi::SQLITE_OK { + return self.conn.decode_result(rc); + } + self.pos = 0; + Ok(()) + } + + /// Return the size in bytes of the BLOB. + #[inline] + #[must_use] + pub fn size(&self) -> i32 { + unsafe { ffi::sqlite3_blob_bytes(self.blob) } + } + + /// Return the current size in bytes of the BLOB. + #[inline] + #[must_use] + pub fn len(&self) -> usize { + use std::convert::TryInto; + self.size().try_into().unwrap() + } + + /// Return true if the BLOB is empty. + #[inline] + #[must_use] + pub fn is_empty(&self) -> bool { + self.size() == 0 + } + + /// Close a BLOB handle. 
+ /// + /// Calling `close` explicitly is not required (the BLOB will be closed + /// when the `Blob` is dropped), but it is available so you can get any + /// errors that occur. + /// + /// # Failure + /// + /// Will return `Err` if the underlying SQLite close call fails. + #[inline] + pub fn close(mut self) -> Result<()> { + self.close_() + } + + #[inline] + fn close_(&mut self) -> Result<()> { + let rc = unsafe { ffi::sqlite3_blob_close(self.blob) }; + self.blob = ptr::null_mut(); + self.conn.decode_result(rc) + } +} + +impl io::Read for Blob<'_> { + /// Read data from a BLOB incrementally. Will return Ok(0) if the end of + /// the blob has been reached. + /// + /// # Failure + /// + /// Will return `Err` if the underlying SQLite read call fails. + #[inline] + fn read(&mut self, buf: &mut [u8]) -> io::Result { + let max_allowed_len = (self.size() - self.pos) as usize; + let n = min(buf.len(), max_allowed_len) as i32; + if n <= 0 { + return Ok(0); + } + let rc = unsafe { ffi::sqlite3_blob_read(self.blob, buf.as_mut_ptr().cast(), n, self.pos) }; + self.conn + .decode_result(rc) + .map(|_| { + self.pos += n; + n as usize + }) + .map_err(|err| io::Error::new(io::ErrorKind::Other, err)) + } +} + +impl io::Write for Blob<'_> { + /// Write data into a BLOB incrementally. Will return `Ok(0)` if the end of + /// the blob has been reached; consider using `Write::write_all(buf)` + /// if you want to get an error if the entirety of the buffer cannot be + /// written. + /// + /// This function may only modify the contents of the BLOB; it is not + /// possible to increase the size of a BLOB using this API. + /// + /// # Failure + /// + /// Will return `Err` if the underlying SQLite write call fails. 
+ #[inline] + fn write(&mut self, buf: &[u8]) -> io::Result { + let max_allowed_len = (self.size() - self.pos) as usize; + let n = min(buf.len(), max_allowed_len) as i32; + if n <= 0 { + return Ok(0); + } + let rc = unsafe { ffi::sqlite3_blob_write(self.blob, buf.as_ptr() as *mut _, n, self.pos) }; + self.conn + .decode_result(rc) + .map(|_| { + self.pos += n; + n as usize + }) + .map_err(|err| io::Error::new(io::ErrorKind::Other, err)) + } + + #[inline] + fn flush(&mut self) -> io::Result<()> { + Ok(()) + } +} + +impl io::Seek for Blob<'_> { + /// Seek to an offset, in bytes, in BLOB. + #[inline] + fn seek(&mut self, pos: io::SeekFrom) -> io::Result { + let pos = match pos { + io::SeekFrom::Start(offset) => offset as i64, + io::SeekFrom::Current(offset) => i64::from(self.pos) + offset, + io::SeekFrom::End(offset) => i64::from(self.size()) + offset, + }; + + if pos < 0 { + Err(io::Error::new( + io::ErrorKind::InvalidInput, + "invalid seek to negative position", + )) + } else if pos > i64::from(self.size()) { + Err(io::Error::new( + io::ErrorKind::InvalidInput, + "invalid seek to position past end of blob", + )) + } else { + self.pos = pos as i32; + Ok(pos as u64) + } + } +} + +#[allow(unused_must_use)] +impl Drop for Blob<'_> { + #[inline] + fn drop(&mut self) { + self.close_(); + } +} + +/// BLOB of length N that is filled with zeroes. +/// +/// Zeroblobs are intended to serve as placeholders for BLOBs whose content is +/// later written using incremental BLOB I/O routines. +/// +/// A negative value for the zeroblob results in a zero-length BLOB. 
+#[derive(Copy, Clone)] +pub struct ZeroBlob(pub i32); + +impl ToSql for ZeroBlob { + #[inline] + fn to_sql(&self) -> Result> { + let ZeroBlob(length) = *self; + Ok(ToSqlOutput::ZeroBlob(length)) + } +} + +#[cfg(test)] +mod test { + use crate::{Connection, DatabaseName, Result}; + use std::io::{BufRead, BufReader, BufWriter, Read, Seek, SeekFrom, Write}; + + fn db_with_test_blob() -> Result<(Connection, i64)> { + let db = Connection::open_in_memory()?; + let sql = "BEGIN; + CREATE TABLE test (content BLOB); + INSERT INTO test VALUES (ZEROBLOB(10)); + END;"; + db.execute_batch(sql)?; + let rowid = db.last_insert_rowid(); + Ok((db, rowid)) + } + + #[test] + fn test_blob() -> Result<()> { + let (db, rowid) = db_with_test_blob()?; + + let mut blob = db.blob_open(DatabaseName::Main, "test", "content", rowid, false)?; + assert_eq!(4, blob.write(b"Clob").unwrap()); + assert_eq!(6, blob.write(b"567890xxxxxx").unwrap()); // cannot write past 10 + assert_eq!(0, blob.write(b"5678").unwrap()); // still cannot write past 10 + + blob.reopen(rowid)?; + blob.close()?; + + blob = db.blob_open(DatabaseName::Main, "test", "content", rowid, true)?; + let mut bytes = [0u8; 5]; + assert_eq!(5, blob.read(&mut bytes[..]).unwrap()); + assert_eq!(&bytes, b"Clob5"); + assert_eq!(5, blob.read(&mut bytes[..]).unwrap()); + assert_eq!(&bytes, b"67890"); + assert_eq!(0, blob.read(&mut bytes[..]).unwrap()); + + blob.seek(SeekFrom::Start(2)).unwrap(); + assert_eq!(5, blob.read(&mut bytes[..]).unwrap()); + assert_eq!(&bytes, b"ob567"); + + // only first 4 bytes of `bytes` should be read into + blob.seek(SeekFrom::Current(-1)).unwrap(); + assert_eq!(4, blob.read(&mut bytes[..]).unwrap()); + assert_eq!(&bytes, b"78907"); + + blob.seek(SeekFrom::End(-6)).unwrap(); + assert_eq!(5, blob.read(&mut bytes[..]).unwrap()); + assert_eq!(&bytes, b"56789"); + + blob.reopen(rowid)?; + assert_eq!(5, blob.read(&mut bytes[..]).unwrap()); + assert_eq!(&bytes, b"Clob5"); + + // should not be able to seek negative or 
past end + assert!(blob.seek(SeekFrom::Current(-20)).is_err()); + assert!(blob.seek(SeekFrom::End(0)).is_ok()); + assert!(blob.seek(SeekFrom::Current(1)).is_err()); + + // write_all should detect when we return Ok(0) because there is no space left, + // and return a write error + blob.reopen(rowid)?; + assert!(blob.write_all(b"0123456789x").is_err()); + Ok(()) + } + + #[test] + fn test_blob_in_bufreader() -> Result<()> { + let (db, rowid) = db_with_test_blob()?; + + let mut blob = db.blob_open(DatabaseName::Main, "test", "content", rowid, false)?; + assert_eq!(8, blob.write(b"one\ntwo\n").unwrap()); + + blob.reopen(rowid)?; + let mut reader = BufReader::new(blob); + + let mut line = String::new(); + assert_eq!(4, reader.read_line(&mut line).unwrap()); + assert_eq!("one\n", line); + + line.truncate(0); + assert_eq!(4, reader.read_line(&mut line).unwrap()); + assert_eq!("two\n", line); + + line.truncate(0); + assert_eq!(2, reader.read_line(&mut line).unwrap()); + assert_eq!("\0\0", line); + Ok(()) + } + + #[test] + fn test_blob_in_bufwriter() -> Result<()> { + let (db, rowid) = db_with_test_blob()?; + + { + let blob = db.blob_open(DatabaseName::Main, "test", "content", rowid, false)?; + let mut writer = BufWriter::new(blob); + + // trying to write too much and then flush should fail + assert_eq!(8, writer.write(b"01234567").unwrap()); + assert_eq!(8, writer.write(b"01234567").unwrap()); + assert!(writer.flush().is_err()); + } + + { + // ... but it should've written the first 10 bytes + let mut blob = db.blob_open(DatabaseName::Main, "test", "content", rowid, false)?; + let mut bytes = [0u8; 10]; + assert_eq!(10, blob.read(&mut bytes[..]).unwrap()); + assert_eq!(b"0123456701", &bytes); + } + + { + let blob = db.blob_open(DatabaseName::Main, "test", "content", rowid, false)?; + let mut writer = BufWriter::new(blob); + + // trying to write_all too much should fail + writer.write_all(b"aaaaaaaaaabbbbb").unwrap(); + assert!(writer.flush().is_err()); + } + + { + // ... 
but it should've written the first 10 bytes + let mut blob = db.blob_open(DatabaseName::Main, "test", "content", rowid, false)?; + let mut bytes = [0u8; 10]; + assert_eq!(10, blob.read(&mut bytes[..]).unwrap()); + assert_eq!(b"aaaaaaaaaa", &bytes); + Ok(()) + } + } +} diff --git a/src/blob/pos_io.rs b/src/blob/pos_io.rs new file mode 100644 index 0000000..ecc7d65 --- /dev/null +++ b/src/blob/pos_io.rs @@ -0,0 +1,274 @@ +use super::Blob; + +use std::convert::TryFrom; +use std::mem::MaybeUninit; +use std::slice::from_raw_parts_mut; + +use crate::ffi; +use crate::{Error, Result}; + +impl<'conn> Blob<'conn> { + /// Write `buf` to `self` starting at `write_start`, returning an error if + /// `write_start + buf.len()` is past the end of the blob. + /// + /// If an error is returned, no data is written. + /// + /// Note: the blob cannot be resized using this function -- that must be + /// done using SQL (for example, an `UPDATE` statement). + /// + /// Note: This is part of the positional I/O API, and thus takes an absolute + /// position write to, instead of using the internal position that can be + /// manipulated by the `std::io` traits. + /// + /// Unlike the similarly named [`FileExt::write_at`][fext_write_at] function + /// (from `std::os::unix`), it's always an error to perform a "short write". + /// + /// [fext_write_at]: https://doc.rust-lang.org/std/os/unix/fs/trait.FileExt.html#tymethod.write_at + #[inline] + pub fn write_at(&mut self, buf: &[u8], write_start: usize) -> Result<()> { + let len = self.len(); + + if buf.len().saturating_add(write_start) > len { + return Err(Error::BlobSizeError); + } + // We know `len` fits in an `i32`, so either: + // + // 1. `buf.len() + write_start` overflows, in which case we'd hit the + // return above (courtesy of `saturating_add`). + // + // 2. `buf.len() + write_start` doesn't overflow but is larger than len, + // in which case ditto. + // + // 3. `buf.len() + write_start` doesn't overflow but is less than len. 
+ // This means that both `buf.len()` and `write_start` can also be + // losslessly converted to i32, since `len` came from an i32. + // Sanity check the above. + debug_assert!(i32::try_from(write_start).is_ok() && i32::try_from(buf.len()).is_ok()); + self.conn.decode_result(unsafe { + ffi::sqlite3_blob_write( + self.blob, + buf.as_ptr().cast(), + buf.len() as i32, + write_start as i32, + ) + }) + } + + /// An alias for `write_at` provided for compatibility with the conceptually + /// equivalent [`std::os::unix::FileExt::write_all_at`][write_all_at] + /// function from libstd: + /// + /// [write_all_at]: https://doc.rust-lang.org/std/os/unix/fs/trait.FileExt.html#method.write_all_at + #[inline] + pub fn write_all_at(&mut self, buf: &[u8], write_start: usize) -> Result<()> { + self.write_at(buf, write_start) + } + + /// Read as much as possible from `offset` to `offset + buf.len()` out of + /// `self`, writing into `buf`. On success, returns the number of bytes + /// written. + /// + /// If there's insufficient data in `self`, then the returned value will be + /// less than `buf.len()`. + /// + /// See also [`Blob::raw_read_at`], which can take an uninitialized buffer, + /// or [`Blob::read_at_exact`] which returns an error if the entire `buf` is + /// not read. + /// + /// Note: This is part of the positional I/O API, and thus takes an absolute + /// position to read from, instead of using the internal position that can + /// be manipulated by the `std::io` traits. Consequently, it does not change + /// that value either. + #[inline] + pub fn read_at(&self, buf: &mut [u8], read_start: usize) -> Result { + // Safety: this is safe because `raw_read_at` never stores uninitialized + // data into `as_uninit`. 
+ let as_uninit: &mut [MaybeUninit] = + unsafe { from_raw_parts_mut(buf.as_mut_ptr().cast(), buf.len()) }; + self.raw_read_at(as_uninit, read_start).map(|s| s.len()) + } + + /// Read as much as possible from `offset` to `offset + buf.len()` out of + /// `self`, writing into `buf`. On success, returns the portion of `buf` + /// which was initialized by this call. + /// + /// If there's insufficient data in `self`, then the returned value will be + /// shorter than `buf`. + /// + /// See also [`Blob::read_at`], which takes a `&mut [u8]` buffer instead of + /// a slice of `MaybeUninit`. + /// + /// Note: This is part of the positional I/O API, and thus takes an absolute + /// position to read from, instead of using the internal position that can + /// be manipulated by the `std::io` traits. Consequently, it does not change + /// that value either. + #[inline] + pub fn raw_read_at<'a>( + &self, + buf: &'a mut [MaybeUninit], + read_start: usize, + ) -> Result<&'a mut [u8]> { + let len = self.len(); + + let read_len = match len.checked_sub(read_start) { + None | Some(0) => 0, + Some(v) => v.min(buf.len()), + }; + + if read_len == 0 { + // We could return `Ok(&mut [])`, but it seems confusing that the + // pointers don't match, so fabricate a empty slice of u8 with the + // same base pointer as `buf`. + let empty = unsafe { from_raw_parts_mut(buf.as_mut_ptr().cast::(), 0) }; + return Ok(empty); + } + + // At this point we believe `read_start as i32` is lossless because: + // + // 1. `len as i32` is known to be lossless, since it comes from a SQLite + // api returning an i32. + // + // 2. If we got here, `len.checked_sub(read_start)` was Some (or else + // we'd have hit the `if read_len == 0` early return), so `len` must + // be larger than `read_start`, and so it must fit in i32 as well. + debug_assert!(i32::try_from(read_start).is_ok()); + + // We also believe that `read_start + read_len <= len` because: + // + // 1. 
This is equivalent to `read_len <= len - read_start` via algebra. + // 2. We know that `read_len` is `min(len - read_start, buf.len())` + // 3. Expanding, this is `min(len - read_start, buf.len()) <= len - read_start`, + // or `min(A, B) <= A` which is clearly true. + // + // Note that this stuff is in debug_assert so no need to use checked_add + // and such -- we'll always panic on overflow in debug builds. + debug_assert!(read_start + read_len <= len); + + // These follow naturally. + debug_assert!(buf.len() >= read_len); + debug_assert!(i32::try_from(buf.len()).is_ok()); + debug_assert!(i32::try_from(read_len).is_ok()); + + unsafe { + self.conn.decode_result(ffi::sqlite3_blob_read( + self.blob, + buf.as_mut_ptr().cast(), + read_len as i32, + read_start as i32, + ))?; + + Ok(from_raw_parts_mut(buf.as_mut_ptr().cast::(), read_len)) + } + } + + /// Equivalent to [`Blob::read_at`], but returns a `BlobSizeError` if `buf` + /// is not fully initialized. + #[inline] + pub fn read_at_exact(&self, buf: &mut [u8], read_start: usize) -> Result<()> { + let n = self.read_at(buf, read_start)?; + if n != buf.len() { + Err(Error::BlobSizeError) + } else { + Ok(()) + } + } + + /// Equivalent to [`Blob::raw_read_at`], but returns a `BlobSizeError` if + /// `buf` is not fully initialized. 
+ #[inline] + pub fn raw_read_at_exact<'a>( + &self, + buf: &'a mut [MaybeUninit], + read_start: usize, + ) -> Result<&'a mut [u8]> { + let buflen = buf.len(); + let initted = self.raw_read_at(buf, read_start)?; + if initted.len() != buflen { + Err(Error::BlobSizeError) + } else { + Ok(initted) + } + } +} + +#[cfg(test)] +mod test { + use crate::{Connection, DatabaseName, Result}; + // to ensure we don't modify seek pos + use std::io::Seek as _; + + #[test] + fn test_pos_io() -> Result<()> { + let db = Connection::open_in_memory()?; + db.execute_batch("CREATE TABLE test_table(content BLOB);")?; + db.execute("INSERT INTO test_table(content) VALUES (ZEROBLOB(10))", [])?; + + let rowid = db.last_insert_rowid(); + let mut blob = db.blob_open(DatabaseName::Main, "test_table", "content", rowid, false)?; + // modify the seek pos to ensure we aren't using it or modifying it. + blob.seek(std::io::SeekFrom::Start(1)).unwrap(); + + let one2ten: [u8; 10] = [1u8, 2, 3, 4, 5, 6, 7, 8, 9, 10]; + blob.write_at(&one2ten, 0).unwrap(); + + let mut s = [0u8; 10]; + blob.read_at_exact(&mut s, 0).unwrap(); + assert_eq!(&s, &one2ten, "write should go through"); + assert!(blob.read_at_exact(&mut s, 1).is_err()); + + blob.read_at_exact(&mut s, 0).unwrap(); + assert_eq!(&s, &one2ten, "should be unchanged"); + + let mut fives = [0u8; 5]; + blob.read_at_exact(&mut fives, 0).unwrap(); + assert_eq!(&fives, &[1u8, 2, 3, 4, 5]); + + blob.read_at_exact(&mut fives, 5).unwrap(); + assert_eq!(&fives, &[6u8, 7, 8, 9, 10]); + assert!(blob.read_at_exact(&mut fives, 7).is_err()); + assert!(blob.read_at_exact(&mut fives, 12).is_err()); + assert!(blob.read_at_exact(&mut fives, 10).is_err()); + assert!(blob.read_at_exact(&mut fives, i32::MAX as usize).is_err()); + assert!(blob + .read_at_exact(&mut fives, i32::MAX as usize + 1) + .is_err()); + + // zero length writes are fine if in bounds + blob.read_at_exact(&mut [], 10).unwrap(); + blob.read_at_exact(&mut [], 0).unwrap(); + blob.read_at_exact(&mut [], 
5).unwrap(); + + blob.write_all_at(&[16, 17, 18, 19, 20], 5).unwrap(); + blob.read_at_exact(&mut s, 0).unwrap(); + assert_eq!(&s, &[1u8, 2, 3, 4, 5, 16, 17, 18, 19, 20]); + + assert!(blob.write_at(&[100, 99, 98, 97, 96], 6).is_err()); + assert!(blob + .write_at(&[100, 99, 98, 97, 96], i32::MAX as usize) + .is_err()); + assert!(blob + .write_at(&[100, 99, 98, 97, 96], i32::MAX as usize + 1) + .is_err()); + + blob.read_at_exact(&mut s, 0).unwrap(); + assert_eq!(&s, &[1u8, 2, 3, 4, 5, 16, 17, 18, 19, 20]); + + let mut s2: [std::mem::MaybeUninit; 10] = [std::mem::MaybeUninit::uninit(); 10]; + { + let read = blob.raw_read_at_exact(&mut s2, 0).unwrap(); + assert_eq!(read, &s); + assert!(std::ptr::eq(read.as_ptr(), s2.as_ptr().cast())); + } + + let mut empty = []; + assert!(std::ptr::eq( + blob.raw_read_at_exact(&mut empty, 0).unwrap().as_ptr(), + empty.as_ptr().cast(), + )); + assert!(blob.raw_read_at_exact(&mut s2, 5).is_err()); + + let end_pos = blob.seek(std::io::SeekFrom::Current(0)).unwrap(); + assert_eq!(end_pos, 1); + Ok(()) + } +} diff --git a/src/busy.rs b/src/busy.rs new file mode 100644 index 0000000..7297f20 --- /dev/null +++ b/src/busy.rs @@ -0,0 +1,174 @@ +///! Busy handler (when the database is locked) +use std::convert::TryInto; +use std::mem; +use std::os::raw::{c_int, c_void}; +use std::panic::catch_unwind; +use std::ptr; +use std::time::Duration; + +use crate::ffi; +use crate::{Connection, InnerConnection, Result}; + +impl Connection { + /// Set a busy handler that sleeps for a specified amount of time when a + /// table is locked. The handler will sleep multiple times until at + /// least "ms" milliseconds of sleeping have accumulated. + /// + /// Calling this routine with an argument equal to zero turns off all busy + /// handlers. + /// + /// There can only be a single busy handler for a particular database + /// connection at any given moment. 
If another busy handler was defined + /// (using [`busy_handler`](Connection::busy_handler)) prior to calling this + /// routine, that other busy handler is cleared. + /// + /// Newly created connections currently have a default busy timeout of + /// 5000ms, but this may be subject to change. + pub fn busy_timeout(&self, timeout: Duration) -> Result<()> { + let ms: i32 = timeout + .as_secs() + .checked_mul(1000) + .and_then(|t| t.checked_add(timeout.subsec_millis().into())) + .and_then(|t| t.try_into().ok()) + .expect("too big"); + self.db.borrow_mut().busy_timeout(ms) + } + + /// Register a callback to handle `SQLITE_BUSY` errors. + /// + /// If the busy callback is `None`, then `SQLITE_BUSY` is returned + /// immediately upon encountering the lock. The argument to the busy + /// handler callback is the number of times that the + /// busy handler has been invoked previously for the + /// same locking event. If the busy callback returns `false`, then no + /// additional attempts are made to access the + /// database and `SQLITE_BUSY` is returned to the + /// application. If the callback returns `true`, then another attempt + /// is made to access the database and the cycle repeats. + /// + /// There can only be a single busy handler defined for each database + /// connection. Setting a new busy handler clears any previously set + /// handler. Note that calling [`busy_timeout()`](Connection::busy_timeout) + /// or evaluating `PRAGMA busy_timeout=N` will change the busy handler + /// and thus clear any previously set busy handler. + /// + /// Newly created connections default to a + /// [`busy_timeout()`](Connection::busy_timeout) handler with a timeout + /// of 5000ms, although this is subject to change. 
+ pub fn busy_handler(&self, callback: Option bool>) -> Result<()> { + unsafe extern "C" fn busy_handler_callback(p_arg: *mut c_void, count: c_int) -> c_int { + let handler_fn: fn(i32) -> bool = mem::transmute(p_arg); + if let Ok(true) = catch_unwind(|| handler_fn(count)) { + 1 + } else { + 0 + } + } + let c = self.db.borrow_mut(); + let r = match callback { + Some(f) => unsafe { + ffi::sqlite3_busy_handler(c.db(), Some(busy_handler_callback), f as *mut c_void) + }, + None => unsafe { ffi::sqlite3_busy_handler(c.db(), None, ptr::null_mut()) }, + }; + c.decode_result(r) + } +} + +impl InnerConnection { + #[inline] + fn busy_timeout(&mut self, timeout: c_int) -> Result<()> { + let r = unsafe { ffi::sqlite3_busy_timeout(self.db, timeout) }; + self.decode_result(r) + } +} + +#[cfg(test)] +mod test { + use std::sync::atomic::{AtomicBool, Ordering}; + use std::sync::mpsc::sync_channel; + use std::thread; + use std::time::Duration; + + use crate::{Connection, ErrorCode, Result, TransactionBehavior}; + + #[test] + fn test_default_busy() -> Result<()> { + let temp_dir = tempfile::tempdir().unwrap(); + let path = temp_dir.path().join("test.db3"); + + let mut db1 = Connection::open(&path)?; + let tx1 = db1.transaction_with_behavior(TransactionBehavior::Exclusive)?; + let db2 = Connection::open(&path)?; + let r: Result<()> = db2.query_row("PRAGMA schema_version", [], |_| unreachable!()); + assert_eq!( + r.unwrap_err().sqlite_error_code(), + Some(ErrorCode::DatabaseBusy) + ); + tx1.rollback() + } + + #[test] + #[ignore] // FIXME: unstable + fn test_busy_timeout() { + let temp_dir = tempfile::tempdir().unwrap(); + let path = temp_dir.path().join("test.db3"); + + let db2 = Connection::open(&path).unwrap(); + db2.busy_timeout(Duration::from_secs(1)).unwrap(); + + let (rx, tx) = sync_channel(0); + let child = thread::spawn(move || { + let mut db1 = Connection::open(&path).unwrap(); + let tx1 = db1 + .transaction_with_behavior(TransactionBehavior::Exclusive) + .unwrap(); + 
rx.send(1).unwrap(); + thread::sleep(Duration::from_millis(100)); + tx1.rollback().unwrap(); + }); + + assert_eq!(tx.recv().unwrap(), 1); + let _ = db2 + .query_row("PRAGMA schema_version", [], |row| row.get::<_, i32>(0)) + .expect("unexpected error"); + + child.join().unwrap(); + } + + #[test] + #[ignore] // FIXME: unstable + fn test_busy_handler() { + static CALLED: AtomicBool = AtomicBool::new(false); + fn busy_handler(_: i32) -> bool { + CALLED.store(true, Ordering::Relaxed); + thread::sleep(Duration::from_millis(100)); + true + } + + let temp_dir = tempfile::tempdir().unwrap(); + let path = temp_dir.path().join("test.db3"); + + let db2 = Connection::open(&path).unwrap(); + db2.busy_handler(Some(busy_handler)).unwrap(); + + let (rx, tx) = sync_channel(0); + let child = thread::spawn(move || { + let mut db1 = Connection::open(&path).unwrap(); + let tx1 = db1 + .transaction_with_behavior(TransactionBehavior::Exclusive) + .unwrap(); + rx.send(1).unwrap(); + thread::sleep(Duration::from_millis(100)); + tx1.rollback().unwrap(); + }); + + assert_eq!(tx.recv().unwrap(), 1); + let _ = db2 + .query_row("PRAGMA schema_version", [], |row| row.get::<_, i32>(0)) + .expect("unexpected error"); + assert!(CALLED.load(Ordering::Relaxed)); + + child.join().unwrap(); + } +} diff --git a/src/cache.rs b/src/cache.rs new file mode 100644 index 0000000..c80a708 --- /dev/null +++ b/src/cache.rs @@ -0,0 +1,350 @@ +//! Prepared statements cache for faster execution. + +use crate::raw_statement::RawStatement; +use crate::{Connection, Result, Statement}; +use hashlink::LruCache; +use std::cell::RefCell; +use std::ops::{Deref, DerefMut}; +use std::sync::Arc; + +impl Connection { + /// Prepare a SQL statement for execution, returning a previously prepared + /// (but not currently in-use) statement if one is available. The + /// returned statement will be cached for reuse by future calls to + /// [`prepare_cached`](Connection::prepare_cached) once it is dropped. 
+ /// + /// ```rust,no_run + /// # use rusqlite::{Connection, Result}; + /// fn insert_new_people(conn: &Connection) -> Result<()> { + /// { + /// let mut stmt = conn.prepare_cached("INSERT INTO People (name) VALUES (?)")?; + /// stmt.execute(["Joe Smith"])?; + /// } + /// { + /// // This will return the same underlying SQLite statement handle without + /// // having to prepare it again. + /// let mut stmt = conn.prepare_cached("INSERT INTO People (name) VALUES (?)")?; + /// stmt.execute(["Bob Jones"])?; + /// } + /// Ok(()) + /// } + /// ``` + /// + /// # Failure + /// + /// Will return `Err` if `sql` cannot be converted to a C-compatible string + /// or if the underlying SQLite call fails. + #[inline] + pub fn prepare_cached(&self, sql: &str) -> Result> { + self.cache.get(self, sql) + } + + /// Set the maximum number of cached prepared statements this connection + /// will hold. By default, a connection will hold a relatively small + /// number of cached statements. If you need more, or know that you + /// will not use cached statements, you + /// can set the capacity manually using this method. + #[inline] + pub fn set_prepared_statement_cache_capacity(&self, capacity: usize) { + self.cache.set_capacity(capacity); + } + + /// Remove/finalize all prepared statements currently in the cache. + #[inline] + pub fn flush_prepared_statement_cache(&self) { + self.cache.flush(); + } +} + +/// Prepared statements LRU cache. +// #[derive(Debug)] // FIXME: https://github.com/kyren/hashlink/pull/4 +pub struct StatementCache(RefCell, RawStatement>>); + +#[allow(clippy::non_send_fields_in_send_ty)] +unsafe impl Send for StatementCache {} + +/// Cacheable statement. +/// +/// Statement will return automatically to the cache by default. +/// If you want the statement to be discarded, call +/// [`discard()`](CachedStatement::discard) on it. 
+pub struct CachedStatement<'conn> { + stmt: Option>, + cache: &'conn StatementCache, +} + +impl<'conn> Deref for CachedStatement<'conn> { + type Target = Statement<'conn>; + + #[inline] + fn deref(&self) -> &Statement<'conn> { + self.stmt.as_ref().unwrap() + } +} + +impl<'conn> DerefMut for CachedStatement<'conn> { + #[inline] + fn deref_mut(&mut self) -> &mut Statement<'conn> { + self.stmt.as_mut().unwrap() + } +} + +impl Drop for CachedStatement<'_> { + #[allow(unused_must_use)] + #[inline] + fn drop(&mut self) { + if let Some(stmt) = self.stmt.take() { + self.cache.cache_stmt(unsafe { stmt.into_raw() }); + } + } +} + +impl CachedStatement<'_> { + #[inline] + fn new<'conn>(stmt: Statement<'conn>, cache: &'conn StatementCache) -> CachedStatement<'conn> { + CachedStatement { + stmt: Some(stmt), + cache, + } + } + + /// Discard the statement, preventing it from being returned to its + /// [`Connection`]'s collection of cached statements. + #[inline] + pub fn discard(mut self) { + self.stmt = None; + } +} + +impl StatementCache { + /// Create a statement cache. + #[inline] + pub fn with_capacity(capacity: usize) -> StatementCache { + StatementCache(RefCell::new(LruCache::new(capacity))) + } + + #[inline] + fn set_capacity(&self, capacity: usize) { + self.0.borrow_mut().set_capacity(capacity); + } + + // Search the cache for a prepared-statement object that implements `sql`. + // If no such prepared-statement can be found, allocate and prepare a new one. + // + // # Failure + // + // Will return `Err` if no cached statement can be found and the underlying + // SQLite prepare call fails. 
+ fn get<'conn>( + &'conn self, + conn: &'conn Connection, + sql: &str, + ) -> Result> { + let trimmed = sql.trim(); + let mut cache = self.0.borrow_mut(); + let stmt = match cache.remove(trimmed) { + Some(raw_stmt) => Ok(Statement::new(conn, raw_stmt)), + None => conn.prepare(trimmed), + }; + stmt.map(|mut stmt| { + stmt.stmt.set_statement_cache_key(trimmed); + CachedStatement::new(stmt, self) + }) + } + + // Return a statement to the cache. + fn cache_stmt(&self, stmt: RawStatement) { + if stmt.is_null() { + return; + } + let mut cache = self.0.borrow_mut(); + stmt.clear_bindings(); + if let Some(sql) = stmt.statement_cache_key() { + cache.insert(sql, stmt); + } else { + debug_assert!( + false, + "bug in statement cache code, statement returned to cache that without key" + ); + } + } + + #[inline] + fn flush(&self) { + let mut cache = self.0.borrow_mut(); + cache.clear(); + } +} + +#[cfg(test)] +mod test { + use super::StatementCache; + use crate::{Connection, Result}; + use fallible_iterator::FallibleIterator; + + impl StatementCache { + fn clear(&self) { + self.0.borrow_mut().clear(); + } + + fn len(&self) -> usize { + self.0.borrow().len() + } + + fn capacity(&self) -> usize { + self.0.borrow().capacity() + } + } + + #[test] + fn test_cache() -> Result<()> { + let db = Connection::open_in_memory()?; + let cache = &db.cache; + let initial_capacity = cache.capacity(); + assert_eq!(0, cache.len()); + assert!(initial_capacity > 0); + + let sql = "PRAGMA schema_version"; + { + let mut stmt = db.prepare_cached(sql)?; + assert_eq!(0, cache.len()); + assert_eq!(0, stmt.query_row([], |r| r.get::<_, i64>(0))?); + } + assert_eq!(1, cache.len()); + + { + let mut stmt = db.prepare_cached(sql)?; + assert_eq!(0, cache.len()); + assert_eq!(0, stmt.query_row([], |r| r.get::<_, i64>(0))?); + } + assert_eq!(1, cache.len()); + + cache.clear(); + assert_eq!(0, cache.len()); + assert_eq!(initial_capacity, cache.capacity()); + Ok(()) + } + + #[test] + fn test_set_capacity() -> 
Result<()> { + let db = Connection::open_in_memory()?; + let cache = &db.cache; + + let sql = "PRAGMA schema_version"; + { + let mut stmt = db.prepare_cached(sql)?; + assert_eq!(0, cache.len()); + assert_eq!(0, stmt.query_row([], |r| r.get::<_, i64>(0))?); + } + assert_eq!(1, cache.len()); + + db.set_prepared_statement_cache_capacity(0); + assert_eq!(0, cache.len()); + + { + let mut stmt = db.prepare_cached(sql)?; + assert_eq!(0, cache.len()); + assert_eq!(0, stmt.query_row([], |r| r.get::<_, i64>(0))?); + } + assert_eq!(0, cache.len()); + + db.set_prepared_statement_cache_capacity(8); + { + let mut stmt = db.prepare_cached(sql)?; + assert_eq!(0, cache.len()); + assert_eq!(0, stmt.query_row([], |r| r.get::<_, i64>(0))?); + } + assert_eq!(1, cache.len()); + Ok(()) + } + + #[test] + fn test_discard() -> Result<()> { + let db = Connection::open_in_memory()?; + let cache = &db.cache; + + let sql = "PRAGMA schema_version"; + { + let mut stmt = db.prepare_cached(sql)?; + assert_eq!(0, cache.len()); + assert_eq!(0, stmt.query_row([], |r| r.get::<_, i64>(0))?); + stmt.discard(); + } + assert_eq!(0, cache.len()); + Ok(()) + } + + #[test] + fn test_ddl() -> Result<()> { + let db = Connection::open_in_memory()?; + db.execute_batch( + r#" + CREATE TABLE foo (x INT); + INSERT INTO foo VALUES (1); + "#, + )?; + + let sql = "SELECT * FROM foo"; + + { + let mut stmt = db.prepare_cached(sql)?; + assert_eq!(Ok(Some(1i32)), stmt.query([])?.map(|r| r.get(0)).next()); + } + + db.execute_batch( + r#" + ALTER TABLE foo ADD COLUMN y INT; + UPDATE foo SET y = 2; + "#, + )?; + + { + let mut stmt = db.prepare_cached(sql)?; + assert_eq!( + Ok(Some((1i32, 2i32))), + stmt.query([])?.map(|r| Ok((r.get(0)?, r.get(1)?))).next() + ); + } + Ok(()) + } + + #[test] + fn test_connection_close() -> Result<()> { + let conn = Connection::open_in_memory()?; + conn.prepare_cached("SELECT * FROM sqlite_master;")?; + + conn.close().expect("connection not closed"); + Ok(()) + } + + #[test] + fn 
test_cache_key() -> Result<()> { + let db = Connection::open_in_memory()?; + let cache = &db.cache; + assert_eq!(0, cache.len()); + + //let sql = " PRAGMA schema_version; -- comment"; + let sql = "PRAGMA schema_version; "; + { + let mut stmt = db.prepare_cached(sql)?; + assert_eq!(0, cache.len()); + assert_eq!(0, stmt.query_row([], |r| r.get::<_, i64>(0))?); + } + assert_eq!(1, cache.len()); + + { + let mut stmt = db.prepare_cached(sql)?; + assert_eq!(0, cache.len()); + assert_eq!(0, stmt.query_row([], |r| r.get::<_, i64>(0))?); + } + assert_eq!(1, cache.len()); + Ok(()) + } + + #[test] + fn test_empty_stmt() -> Result<()> { + let conn = Connection::open_in_memory()?; + conn.prepare_cached("")?; + Ok(()) + } +} diff --git a/src/collation.rs b/src/collation.rs new file mode 100644 index 0000000..c1fe3f7 --- /dev/null +++ b/src/collation.rs @@ -0,0 +1,215 @@ +//! Add, remove, or modify a collation +use std::cmp::Ordering; +use std::os::raw::{c_char, c_int, c_void}; +use std::panic::{catch_unwind, UnwindSafe}; +use std::ptr; +use std::slice; + +use crate::ffi; +use crate::{str_to_cstring, Connection, InnerConnection, Result}; + +// FIXME copy/paste from function.rs +unsafe extern "C" fn free_boxed_value(p: *mut c_void) { + drop(Box::from_raw(p.cast::())); +} + +impl Connection { + /// Add or modify a collation. + #[inline] + pub fn create_collation(&self, collation_name: &str, x_compare: C) -> Result<()> + where + C: Fn(&str, &str) -> Ordering + Send + UnwindSafe + 'static, + { + self.db + .borrow_mut() + .create_collation(collation_name, x_compare) + } + + /// Collation needed callback + #[inline] + pub fn collation_needed( + &self, + x_coll_needed: fn(&Connection, &str) -> Result<()>, + ) -> Result<()> { + self.db.borrow_mut().collation_needed(x_coll_needed) + } + + /// Remove collation. 
+ #[inline] + pub fn remove_collation(&self, collation_name: &str) -> Result<()> { + self.db.borrow_mut().remove_collation(collation_name) + } +} + +impl InnerConnection { + fn create_collation(&mut self, collation_name: &str, x_compare: C) -> Result<()> + where + C: Fn(&str, &str) -> Ordering + Send + UnwindSafe + 'static, + { + unsafe extern "C" fn call_boxed_closure( + arg1: *mut c_void, + arg2: c_int, + arg3: *const c_void, + arg4: c_int, + arg5: *const c_void, + ) -> c_int + where + C: Fn(&str, &str) -> Ordering, + { + let r = catch_unwind(|| { + let boxed_f: *mut C = arg1.cast::(); + assert!(!boxed_f.is_null(), "Internal error - null function pointer"); + let s1 = { + let c_slice = slice::from_raw_parts(arg3.cast::(), arg2 as usize); + String::from_utf8_lossy(c_slice) + }; + let s2 = { + let c_slice = slice::from_raw_parts(arg5.cast::(), arg4 as usize); + String::from_utf8_lossy(c_slice) + }; + (*boxed_f)(s1.as_ref(), s2.as_ref()) + }); + let t = match r { + Err(_) => { + return -1; // FIXME How ? + } + Ok(r) => r, + }; + + match t { + Ordering::Less => -1, + Ordering::Equal => 0, + Ordering::Greater => 1, + } + } + + let boxed_f: *mut C = Box::into_raw(Box::new(x_compare)); + let c_name = str_to_cstring(collation_name)?; + let flags = ffi::SQLITE_UTF8; + let r = unsafe { + ffi::sqlite3_create_collation_v2( + self.db(), + c_name.as_ptr(), + flags, + boxed_f.cast::(), + Some(call_boxed_closure::), + Some(free_boxed_value::), + ) + }; + let res = self.decode_result(r); + // The xDestroy callback is not called if the sqlite3_create_collation_v2() + // function fails. 
+ if res.is_err() { + drop(unsafe { Box::from_raw(boxed_f) }); + } + res + } + + fn collation_needed( + &mut self, + x_coll_needed: fn(&Connection, &str) -> Result<()>, + ) -> Result<()> { + use std::mem; + #[allow(clippy::needless_return)] + unsafe extern "C" fn collation_needed_callback( + arg1: *mut c_void, + arg2: *mut ffi::sqlite3, + e_text_rep: c_int, + arg3: *const c_char, + ) { + use std::ffi::CStr; + use std::str; + + if e_text_rep != ffi::SQLITE_UTF8 { + // TODO: validate + return; + } + + let callback: fn(&Connection, &str) -> Result<()> = mem::transmute(arg1); + let res = catch_unwind(|| { + let conn = Connection::from_handle(arg2).unwrap(); + let collation_name = { + let c_slice = CStr::from_ptr(arg3).to_bytes(); + str::from_utf8(c_slice).expect("illegal collation sequence name") + }; + callback(&conn, collation_name) + }); + if res.is_err() { + return; // FIXME How ? + } + } + + let r = unsafe { + ffi::sqlite3_collation_needed( + self.db(), + x_coll_needed as *mut c_void, + Some(collation_needed_callback), + ) + }; + self.decode_result(r) + } + + #[inline] + fn remove_collation(&mut self, collation_name: &str) -> Result<()> { + let c_name = str_to_cstring(collation_name)?; + let r = unsafe { + ffi::sqlite3_create_collation_v2( + self.db(), + c_name.as_ptr(), + ffi::SQLITE_UTF8, + ptr::null_mut(), + None, + None, + ) + }; + self.decode_result(r) + } +} + +#[cfg(test)] +mod test { + use crate::{Connection, Result}; + use fallible_streaming_iterator::FallibleStreamingIterator; + use std::cmp::Ordering; + use unicase::UniCase; + + fn unicase_compare(s1: &str, s2: &str) -> Ordering { + UniCase::new(s1).cmp(&UniCase::new(s2)) + } + + #[test] + fn test_unicase() -> Result<()> { + let db = Connection::open_in_memory()?; + + db.create_collation("unicase", unicase_compare)?; + + collate(db) + } + + fn collate(db: Connection) -> Result<()> { + db.execute_batch( + "CREATE TABLE foo (bar); + INSERT INTO foo (bar) VALUES ('Maße'); + INSERT INTO foo (bar) VALUES 
('MASSE');", + )?; + let mut stmt = db.prepare("SELECT DISTINCT bar COLLATE unicase FROM foo ORDER BY 1")?; + let rows = stmt.query([])?; + assert_eq!(rows.count()?, 1); + Ok(()) + } + + fn collation_needed(db: &Connection, collation_name: &str) -> Result<()> { + if "unicase" == collation_name { + db.create_collation(collation_name, unicase_compare) + } else { + Ok(()) + } + } + + #[test] + fn test_collation_needed() -> Result<()> { + let db = Connection::open_in_memory()?; + db.collation_needed(collation_needed)?; + collate(db) + } +} diff --git a/src/column.rs b/src/column.rs new file mode 100644 index 0000000..aa1f5f7 --- /dev/null +++ b/src/column.rs @@ -0,0 +1,241 @@ +use std::str; + +use crate::{Error, Result, Statement}; + +/// Information about a column of a SQLite query. +#[derive(Debug)] +pub struct Column<'stmt> { + name: &'stmt str, + decl_type: Option<&'stmt str>, +} + +impl Column<'_> { + /// Returns the name of the column. + #[inline] + #[must_use] + pub fn name(&self) -> &str { + self.name + } + + /// Returns the type of the column (`None` for expression). + #[inline] + #[must_use] + pub fn decl_type(&self) -> Option<&str> { + self.decl_type + } +} + +impl Statement<'_> { + /// Get all the column names in the result set of the prepared statement. + /// + /// If associated DB schema can be altered concurrently, you should make + /// sure that current statement has already been stepped once before + /// calling this method. + pub fn column_names(&self) -> Vec<&str> { + let n = self.column_count(); + let mut cols = Vec::with_capacity(n as usize); + for i in 0..n { + let s = self.column_name_unwrap(i); + cols.push(s); + } + cols + } + + /// Return the number of columns in the result set returned by the prepared + /// statement. + /// + /// If associated DB schema can be altered concurrently, you should make + /// sure that current statement has already been stepped once before + /// calling this method. 
+ #[inline] + pub fn column_count(&self) -> usize { + self.stmt.column_count() + } + + /// Check that column name reference lifetime is limited: + /// https://www.sqlite.org/c3ref/column_name.html + /// > The returned string pointer is valid... + /// + /// `column_name` reference can become invalid if `stmt` is reprepared + /// (because of schema change) when `query_row` is called. So we assert + /// that a compilation error happens if this reference is kept alive: + /// ```compile_fail + /// use rusqlite::{Connection, Result}; + /// fn main() -> Result<()> { + /// let db = Connection::open_in_memory()?; + /// let mut stmt = db.prepare("SELECT 1 as x")?; + /// let column_name = stmt.column_name(0)?; + /// let x = stmt.query_row([], |r| r.get::<_, i64>(0))?; // E0502 + /// assert_eq!(1, x); + /// assert_eq!("x", column_name); + /// Ok(()) + /// } + /// ``` + #[inline] + pub(super) fn column_name_unwrap(&self, col: usize) -> &str { + // Just panic if the bounds are wrong for now, we never call this + // without checking first. + self.column_name(col).expect("Column out of bounds") + } + + /// Returns the name assigned to a particular column in the result set + /// returned by the prepared statement. + /// + /// If associated DB schema can be altered concurrently, you should make + /// sure that current statement has already been stepped once before + /// calling this method. + /// + /// ## Failure + /// + /// Returns an `Error::InvalidColumnIndex` if `idx` is outside the valid + /// column range for this row. + /// + /// Panics when column name is not valid UTF-8. + #[inline] + pub fn column_name(&self, col: usize) -> Result<&str> { + self.stmt + .column_name(col) + .ok_or(Error::InvalidColumnIndex(col)) + .map(|slice| { + str::from_utf8(slice.to_bytes()).expect("Invalid UTF-8 sequence in column name") + }) + } + + /// Returns the column index in the result set for a given column name. 
+    ///
+    /// If there is no AS clause then the name of the column is unspecified and
+    /// may change from one release of SQLite to the next.
+    ///
+    /// If associated DB schema can be altered concurrently, you should make
+    /// sure that current statement has already been stepped once before
+    /// calling this method.
+    ///
+    /// # Failure
+    ///
+    /// Will return an `Error::InvalidColumnName` when there is no column with
+    /// the specified `name`.
+    #[inline]
+    pub fn column_index(&self, name: &str) -> Result<usize> {
+        let bytes = name.as_bytes();
+        let n = self.column_count();
+        for i in 0..n {
+            // Note: `column_name` is only fallible if `i` is out of bounds,
+            // which we've already checked.
+            if bytes.eq_ignore_ascii_case(self.stmt.column_name(i).unwrap().to_bytes()) {
+                return Ok(i);
+            }
+        }
+        Err(Error::InvalidColumnName(String::from(name)))
+    }
+
+    /// Returns a `Vec` describing the columns of the result of the query.
+    ///
+    /// If associated DB schema can be altered concurrently, you should make
+    /// sure that current statement has already been stepped once before
+    /// calling this method.
+    #[cfg(feature = "column_decltype")]
+    #[cfg_attr(docsrs, doc(cfg(feature = "column_decltype")))]
+    pub fn columns(&self) -> Vec<Column> {
+        let n = self.column_count();
+        let mut cols = Vec::with_capacity(n as usize);
+        for i in 0..n {
+            let name = self.column_name_unwrap(i);
+            let slice = self.stmt.column_decltype(i);
+            let decl_type = slice.map(|s| {
+                str::from_utf8(s.to_bytes()).expect("Invalid UTF-8 sequence in column declaration")
+            });
+            cols.push(Column { name, decl_type });
+        }
+        cols
+    }
+}
+
+#[cfg(test)]
+mod test {
+    use crate::{Connection, Result};
+
+    #[test]
+    #[cfg(feature = "column_decltype")]
+    fn test_columns() -> Result<()> {
+        use super::Column;
+
+        let db = Connection::open_in_memory()?;
+        let query = db.prepare("SELECT * FROM sqlite_master")?;
+        let columns = query.columns();
+        let column_names: Vec<&str> = columns.iter().map(Column::name).collect();
+        assert_eq!(
+            column_names.as_slice(),
+            &["type", "name", "tbl_name", "rootpage", "sql"]
+        );
+        let column_types: Vec<Option<String>> = columns
+            .iter()
+            .map(|col| col.decl_type().map(str::to_lowercase))
+            .collect();
+        assert_eq!(
+            &column_types[..3],
+            &[
+                Some("text".to_owned()),
+                Some("text".to_owned()),
+                Some("text".to_owned()),
+            ]
+        );
+        Ok(())
+    }
+
+    #[test]
+    fn test_column_name_in_error() -> Result<()> {
+        use crate::{types::Type, Error};
+        let db = Connection::open_in_memory()?;
+        db.execute_batch(
+            "BEGIN;
+             CREATE TABLE foo(x INTEGER, y TEXT);
+             INSERT INTO foo VALUES(4, NULL);
+             END;",
+        )?;
+        let mut stmt = db.prepare("SELECT x as renamed, y FROM foo")?;
+        let mut rows = stmt.query([])?;
+        let row = rows.next()?.unwrap();
+        match row.get::<_, String>(0).unwrap_err() {
+            Error::InvalidColumnType(idx, name, ty) => {
+                assert_eq!(idx, 0);
+                assert_eq!(name, "renamed");
+                assert_eq!(ty, Type::Integer);
+            }
+            e => {
+                panic!("Unexpected error type: {:?}", e);
+            }
+        }
+        match row.get::<_, String>("y").unwrap_err() {
+            Error::InvalidColumnType(idx, name, ty) => {
+                assert_eq!(idx, 1);
+                assert_eq!(name, "y");
+                assert_eq!(ty, Type::Null);
+            }
+            e => {
+                panic!("Unexpected error type: {:?}", e);
+            }
+        }
+        Ok(())
+    }
+
+    /// `column_name` reference should stay valid until `stmt` is reprepared (or
+    /// reset) even if DB schema is altered (SQLite documentation is
+    /// ambiguous here because it says reference "is valid until (...) the next
+    /// call to sqlite3_column_name() or sqlite3_column_name16() on the same
+    /// column.". We assume that reference is valid if only
+    /// `sqlite3_column_name()` is used):
+    #[test]
+    #[cfg(feature = "modern_sqlite")]
+    fn test_column_name_reference() -> Result<()> {
+        let db = Connection::open_in_memory()?;
+        db.execute_batch("CREATE TABLE y (x);")?;
+        let stmt = db.prepare("SELECT x FROM y;")?;
+        let column_name = stmt.column_name(0)?;
+        assert_eq!("x", column_name);
+        db.execute_batch("ALTER TABLE y RENAME COLUMN x TO z;")?;
+        // column name is not refreshed until statement is re-prepared
+        let same_column_name = stmt.column_name(0)?;
+        assert_eq!(same_column_name, column_name);
+        Ok(())
+    }
+}
diff --git a/src/config.rs b/src/config.rs
new file mode 100644
index 0000000..b295d97
--- /dev/null
+++ b/src/config.rs
@@ -0,0 +1,156 @@
+//! Configure database connections
+
+use std::os::raw::c_int;
+
+use crate::error::check;
+use crate::ffi;
+use crate::{Connection, Result};
+
+/// Database Connection Configuration Options
+/// See [Database Connection Configuration Options](https://sqlite.org/c3ref/c_dbconfig_enable_fkey.html) for details.
+#[repr(i32)]
+#[allow(non_snake_case, non_camel_case_types)]
+#[non_exhaustive]
+#[allow(clippy::upper_case_acronyms)]
+pub enum DbConfig {
+    //SQLITE_DBCONFIG_MAINDBNAME = 1000, /* const char* */
+    //SQLITE_DBCONFIG_LOOKASIDE = 1001, /* void* int int */
+    /// Enable or disable the enforcement of foreign key constraints.
+    SQLITE_DBCONFIG_ENABLE_FKEY = 1002,
+    /// Enable or disable triggers.
+    SQLITE_DBCONFIG_ENABLE_TRIGGER = 1003,
+    /// Enable or disable the fts3_tokenizer() function which is part of the
+    /// FTS3 full-text search engine extension.
+    SQLITE_DBCONFIG_ENABLE_FTS3_TOKENIZER = 1004, // 3.12.0
+    //SQLITE_DBCONFIG_ENABLE_LOAD_EXTENSION = 1005,
+    /// In WAL mode, enable or disable the checkpoint operation before closing
+    /// the connection.
+    SQLITE_DBCONFIG_NO_CKPT_ON_CLOSE = 1006, // 3.16.2
+    /// Activates or deactivates the query planner stability guarantee (QPSG).
+    SQLITE_DBCONFIG_ENABLE_QPSG = 1007, // 3.20.0
+    /// Includes or excludes output for any operations performed by trigger
+    /// programs from the output of EXPLAIN QUERY PLAN commands.
+    SQLITE_DBCONFIG_TRIGGER_EQP = 1008, // 3.22.0
+    /// Activates or deactivates the "reset" flag for a database connection.
+    /// Run VACUUM with this flag set to reset the database.
+    SQLITE_DBCONFIG_RESET_DATABASE = 1009, // 3.24.0
+    /// Activates or deactivates the "defensive" flag for a database connection.
+    SQLITE_DBCONFIG_DEFENSIVE = 1010, // 3.26.0
+    /// Activates or deactivates the "writable_schema" flag.
+    #[cfg(feature = "modern_sqlite")]
+    SQLITE_DBCONFIG_WRITABLE_SCHEMA = 1011, // 3.28.0
+    /// Activates or deactivates the legacy behavior of the ALTER TABLE RENAME
+    /// command.
+    #[cfg(feature = "modern_sqlite")]
+    SQLITE_DBCONFIG_LEGACY_ALTER_TABLE = 1012, // 3.29.0
+    /// Activates or deactivates the legacy double-quoted string literal
+    /// misfeature for DML statements only.
+    #[cfg(feature = "modern_sqlite")]
+    SQLITE_DBCONFIG_DQS_DML = 1013, // 3.29.0
+    /// Activates or deactivates the legacy double-quoted string literal
+    /// misfeature for DDL statements.
+    #[cfg(feature = "modern_sqlite")]
+    SQLITE_DBCONFIG_DQS_DDL = 1014, // 3.29.0
+    /// Enable or disable views.
+    #[cfg(feature = "modern_sqlite")]
+    SQLITE_DBCONFIG_ENABLE_VIEW = 1015, // 3.30.0
+    /// Activates or deactivates the legacy file format flag.
+    #[cfg(feature = "modern_sqlite")]
+    SQLITE_DBCONFIG_LEGACY_FILE_FORMAT = 1016, // 3.31.0
+    /// Tells SQLite to assume that database schemas (the contents of the
+    /// sqlite_master tables) are untainted by malicious content.
+    #[cfg(feature = "modern_sqlite")]
+    SQLITE_DBCONFIG_TRUSTED_SCHEMA = 1017, // 3.31.0
+}
+
+impl Connection {
+    /// Returns the current value of a `config`.
+    ///
+    /// - `SQLITE_DBCONFIG_ENABLE_FKEY`: return `false` or `true` to indicate
+    ///   whether FK enforcement is off or on
+    /// - `SQLITE_DBCONFIG_ENABLE_TRIGGER`: return `false` or `true` to indicate
+    ///   whether triggers are disabled or enabled
+    /// - `SQLITE_DBCONFIG_ENABLE_FTS3_TOKENIZER`: return `false` or `true` to
+    ///   indicate whether `fts3_tokenizer()` is disabled or enabled
+    /// - `SQLITE_DBCONFIG_NO_CKPT_ON_CLOSE`: return `false` to indicate
+    ///   checkpoints-on-close are not disabled or `true` if they are
+    /// - `SQLITE_DBCONFIG_ENABLE_QPSG`: return `false` or `true` to indicate
+    ///   whether the QPSG is disabled or enabled
+    /// - `SQLITE_DBCONFIG_TRIGGER_EQP`: return `false` or `true` to indicate
+    ///   whether output for trigger programs is disabled or enabled
+    #[inline]
+    pub fn db_config(&self, config: DbConfig) -> Result<bool> {
+        let c = self.db.borrow();
+        unsafe {
+            let mut val = 0;
+            check(ffi::sqlite3_db_config(
+                c.db(),
+                config as c_int,
+                -1,
+                &mut val,
+            ))?;
+            Ok(val != 0)
+        }
+    }
+
+    /// Make configuration changes to a database connection
+    ///
+    /// - `SQLITE_DBCONFIG_ENABLE_FKEY`: `false` to disable FK enforcement,
+    ///   `true` to enable FK enforcement
+    /// - `SQLITE_DBCONFIG_ENABLE_TRIGGER`: `false` to disable triggers, `true`
+    ///   to enable triggers
+    /// - `SQLITE_DBCONFIG_ENABLE_FTS3_TOKENIZER`: `false` to disable
+    ///   `fts3_tokenizer()`, `true` to enable `fts3_tokenizer()`
+    /// - `SQLITE_DBCONFIG_NO_CKPT_ON_CLOSE`: `false` (the default) to enable
+    ///   checkpoints-on-close, `true` to disable them
+    /// - `SQLITE_DBCONFIG_ENABLE_QPSG`: `false` to disable the QPSG, `true` to
+    ///   enable QPSG
+    /// - `SQLITE_DBCONFIG_TRIGGER_EQP`: `false` to disable output for trigger
+    ///   programs, `true` to enable it
+    #[inline]
+    pub fn set_db_config(&self, config: DbConfig, new_val: bool) -> Result<bool> {
+        let c = self.db.borrow_mut();
+        unsafe {
+            let mut val = 0;
+            check(ffi::sqlite3_db_config(
+                c.db(),
+                config as c_int,
+                if new_val { 1 } else { 0 },
+                &mut val,
+            ))?;
+            Ok(val != 0)
+        }
+    }
+}
+
+#[cfg(test)]
+mod test {
+    use super::DbConfig;
+    use crate::{Connection, Result};
+
+    #[test]
+    fn test_db_config() -> Result<()> {
+        let db = Connection::open_in_memory()?;
+
+        let opposite = !db.db_config(DbConfig::SQLITE_DBCONFIG_ENABLE_FKEY)?;
+        assert_eq!(
+            db.set_db_config(DbConfig::SQLITE_DBCONFIG_ENABLE_FKEY, opposite),
+            Ok(opposite)
+        );
+        assert_eq!(
+            db.db_config(DbConfig::SQLITE_DBCONFIG_ENABLE_FKEY),
+            Ok(opposite)
+        );
+
+        let opposite = !db.db_config(DbConfig::SQLITE_DBCONFIG_ENABLE_TRIGGER)?;
+        assert_eq!(
+            db.set_db_config(DbConfig::SQLITE_DBCONFIG_ENABLE_TRIGGER, opposite),
+            Ok(opposite)
+        );
+        assert_eq!(
+            db.db_config(DbConfig::SQLITE_DBCONFIG_ENABLE_TRIGGER),
+            Ok(opposite)
+        );
+        Ok(())
+    }
+}
diff --git a/src/context.rs b/src/context.rs
new file mode 100644
index 0000000..bcaefc9
--- /dev/null
+++ b/src/context.rs
@@ -0,0 +1,75 @@
+//! Code related to `sqlite3_context` common to `functions` and `vtab` modules.
+
+use std::os::raw::{c_int, c_void};
+#[cfg(feature = "array")]
+use std::rc::Rc;
+
+use crate::ffi;
+use crate::ffi::sqlite3_context;
+
+use crate::str_for_sqlite;
+use crate::types::{ToSqlOutput, ValueRef};
+#[cfg(feature = "array")]
+use crate::vtab::array::{free_array, ARRAY_TYPE};
+
+// This function is inline despite its size because what's in the ToSqlOutput
+// is often known to the compiler, and thus const prop/DCE can substantially
+// simplify the function.
+#[inline]
+pub(super) unsafe fn set_result(ctx: *mut sqlite3_context, result: &ToSqlOutput<'_>) {
+    let value = match *result {
+        ToSqlOutput::Borrowed(v) => v,
+        ToSqlOutput::Owned(ref v) => ValueRef::from(v),
+
+        #[cfg(feature = "blob")]
+        ToSqlOutput::ZeroBlob(len) => {
+            // TODO sqlite3_result_zeroblob64 // 3.8.11
+            return ffi::sqlite3_result_zeroblob(ctx, len);
+        }
+        #[cfg(feature = "array")]
+        ToSqlOutput::Array(ref a) => {
+            return ffi::sqlite3_result_pointer(
+                ctx,
+                Rc::into_raw(a.clone()) as *mut c_void,
+                ARRAY_TYPE,
+                Some(free_array),
+            );
+        }
+    };
+
+    match value {
+        ValueRef::Null => ffi::sqlite3_result_null(ctx),
+        ValueRef::Integer(i) => ffi::sqlite3_result_int64(ctx, i),
+        ValueRef::Real(r) => ffi::sqlite3_result_double(ctx, r),
+        ValueRef::Text(s) => {
+            let length = s.len();
+            if length > c_int::MAX as usize {
+                ffi::sqlite3_result_error_toobig(ctx);
+            } else {
+                let (c_str, len, destructor) = match str_for_sqlite(s) {
+                    Ok(c_str) => c_str,
+                    // TODO sqlite3_result_error
+                    Err(_) => return ffi::sqlite3_result_error_code(ctx, ffi::SQLITE_MISUSE),
+                };
+                // TODO sqlite3_result_text64 // 3.8.7
+                ffi::sqlite3_result_text(ctx, c_str, len, destructor);
+            }
+        }
+        ValueRef::Blob(b) => {
+            let length = b.len();
+            if length > c_int::MAX as usize {
+                ffi::sqlite3_result_error_toobig(ctx);
+            } else if length == 0 {
+                ffi::sqlite3_result_zeroblob(ctx, 0);
+            } else {
+                // TODO sqlite3_result_blob64 // 3.8.7
+                ffi::sqlite3_result_blob(
+                    ctx,
+                    b.as_ptr().cast::<c_void>(),
+                    length as c_int,
+                    ffi::SQLITE_TRANSIENT(),
+                );
+            }
+        }
+    }
+}
diff --git a/src/error.rs b/src/error.rs
new file mode 100644
index 0000000..3c264d3
--- /dev/null
+++ b/src/error.rs
@@ -0,0 +1,446 @@
+use crate::types::FromSqlError;
+use crate::types::Type;
+use crate::{errmsg_to_string, ffi, Result};
+use std::error;
+use std::fmt;
+use std::os::raw::c_int;
+use std::path::PathBuf;
+use std::str;
+
+/// Enum listing possible errors from rusqlite.
+#[derive(Debug)]
+#[allow(clippy::enum_variant_names)]
+#[non_exhaustive]
+pub enum Error {
+    /// An error from an underlying SQLite call.
+    SqliteFailure(ffi::Error, Option<String>),
+
+    /// Error reported when attempting to open a connection when SQLite was
+    /// configured to allow single-threaded use only.
+    SqliteSingleThreadedMode,
+
+    /// Error when the value of a particular column is requested, but it cannot
+    /// be converted to the requested Rust type.
+    FromSqlConversionFailure(usize, Type, Box<dyn error::Error + Send + Sync + 'static>),
+
+    /// Error when SQLite gives us an integral value outside the range of the
+    /// requested type (e.g., trying to get the value 1000 into a `u8`).
+    /// The associated `usize` is the column index,
+    /// and the associated `i64` is the value returned by SQLite.
+    IntegralValueOutOfRange(usize, i64),
+
+    /// Error converting a string to UTF-8.
+    Utf8Error(str::Utf8Error),
+
+    /// Error converting a string to a C-compatible string because it contained
+    /// an embedded nul.
+    NulError(std::ffi::NulError),
+
+    /// Error when using SQL named parameters and passing a parameter name not
+    /// present in the SQL.
+    InvalidParameterName(String),
+
+    /// Error converting a file path to a string.
+    InvalidPath(PathBuf),
+
+    /// Error returned when an [`execute`](crate::Connection::execute) call
+    /// returns rows.
+    ExecuteReturnedResults,
+
+    /// Error when a query that was expected to return at least one row (e.g.,
+    /// for [`query_row`](crate::Connection::query_row)) did not return any.
+    QueryReturnedNoRows,
+
+    /// Error when the value of a particular column is requested, but the index
+    /// is out of range for the statement.
+    InvalidColumnIndex(usize),
+
+    /// Error when the value of a named column is requested, but no column
+    /// matches the name for the statement.
+    InvalidColumnName(String),
+
+    /// Error when the value of a particular column is requested, but the type
+    /// of the result in that column cannot be converted to the requested
+    /// Rust type.
+    InvalidColumnType(usize, String, Type),
+
+    /// Error when a query that was expected to insert one row did not insert
+    /// any or inserted many.
+    StatementChangedRows(usize),
+
+    /// Error returned by
+    /// [`functions::Context::get`](crate::functions::Context::get) when the
+    /// function argument cannot be converted to the requested type.
+    #[cfg(feature = "functions")]
+    #[cfg_attr(docsrs, doc(cfg(feature = "functions")))]
+    InvalidFunctionParameterType(usize, Type),
+    /// Error returned by [`vtab::Values::get`](crate::vtab::Values::get) when
+    /// the filter argument cannot be converted to the requested type.
+    #[cfg(feature = "vtab")]
+    #[cfg_attr(docsrs, doc(cfg(feature = "vtab")))]
+    InvalidFilterParameterType(usize, Type),
+
+    /// An error case available for implementors of custom user functions (e.g.,
+    /// [`create_scalar_function`](crate::Connection::create_scalar_function)).
+    #[cfg(feature = "functions")]
+    #[cfg_attr(docsrs, doc(cfg(feature = "functions")))]
+    #[allow(dead_code)]
+    UserFunctionError(Box<dyn error::Error + Send + Sync + 'static>),
+
+    /// Error available for the implementors of the
+    /// [`ToSql`](crate::types::ToSql) trait.
+    ToSqlConversionFailure(Box<dyn error::Error + Send + Sync + 'static>),
+
+    /// Error when the SQL is not a `SELECT`, is not read-only.
+    InvalidQuery,
+
+    /// An error case available for implementors of custom modules (e.g.,
+    /// [`create_module`](crate::Connection::create_module)).
+    #[cfg(feature = "vtab")]
+    #[cfg_attr(docsrs, doc(cfg(feature = "vtab")))]
+    #[allow(dead_code)]
+    ModuleError(String),
+
+    /// An unwinding panic occurs in a UDF (user-defined function).
+    #[cfg(feature = "functions")]
+    #[cfg_attr(docsrs, doc(cfg(feature = "functions")))]
+    UnwindingPanic,
+
+    /// An error returned when
+    /// [`Context::get_aux`](crate::functions::Context::get_aux) attempts to
+    /// retrieve data of a different type than what had been stored using
+    /// [`Context::set_aux`](crate::functions::Context::set_aux).
+    #[cfg(feature = "functions")]
+    #[cfg_attr(docsrs, doc(cfg(feature = "functions")))]
+    GetAuxWrongType,
+
+    /// Error when the SQL contains multiple statements.
+    MultipleStatement,
+    /// Error when the number of bound parameters does not match the number of
+    /// parameters in the query. The first `usize` is how many parameters were
+    /// given, the 2nd is how many were expected.
+    InvalidParameterCount(usize, usize),
+
+    /// Returned from various functions in the Blob IO positional API. For
+    /// example,
+    /// [`Blob::raw_read_at_exact`](crate::blob::Blob::raw_read_at_exact) will
+    /// return it if the blob has insufficient data.
+    #[cfg(feature = "blob")]
+    #[cfg_attr(docsrs, doc(cfg(feature = "blob")))]
+    BlobSizeError,
+    /// Error referencing a specific token in the input SQL
+    #[cfg(feature = "modern_sqlite")] // 3.38.0
+    #[cfg_attr(docsrs, doc(cfg(feature = "modern_sqlite")))]
+    SqlInputError {
+        /// error code
+        error: ffi::Error,
+        /// error message
+        msg: String,
+        /// SQL input
+        sql: String,
+        /// byte offset of the start of invalid token
+        offset: c_int,
+    },
+}
+
+impl PartialEq for Error {
+    fn eq(&self, other: &Error) -> bool {
+        match (self, other) {
+            (Error::SqliteFailure(e1, s1), Error::SqliteFailure(e2, s2)) => e1 == e2 && s1 == s2,
+            (Error::SqliteSingleThreadedMode, Error::SqliteSingleThreadedMode) => true,
+            (Error::IntegralValueOutOfRange(i1, n1), Error::IntegralValueOutOfRange(i2, n2)) => {
+                i1 == i2 && n1 == n2
+            }
+            (Error::Utf8Error(e1), Error::Utf8Error(e2)) => e1 == e2,
+            (Error::NulError(e1), Error::NulError(e2)) => e1 == e2,
+            (Error::InvalidParameterName(n1), Error::InvalidParameterName(n2)) => n1 == n2,
+            (Error::InvalidPath(p1), Error::InvalidPath(p2)) => p1 == p2,
+            (Error::ExecuteReturnedResults, Error::ExecuteReturnedResults) => true,
+            (Error::QueryReturnedNoRows, Error::QueryReturnedNoRows) => true,
+            (Error::InvalidColumnIndex(i1), Error::InvalidColumnIndex(i2)) => i1 == i2,
+            (Error::InvalidColumnName(n1), Error::InvalidColumnName(n2)) => n1 == n2,
+            (Error::InvalidColumnType(i1, n1, t1), Error::InvalidColumnType(i2, n2, t2)) => {
+                i1 == i2 && t1 == t2 && n1 == n2
+            }
+            (Error::StatementChangedRows(n1), Error::StatementChangedRows(n2)) => n1 == n2,
+            #[cfg(feature = "functions")]
+            (
+                Error::InvalidFunctionParameterType(i1, t1),
+                Error::InvalidFunctionParameterType(i2, t2),
+            ) => i1 == i2 && t1 == t2,
+            #[cfg(feature = "vtab")]
+            (
+                Error::InvalidFilterParameterType(i1, t1),
+                Error::InvalidFilterParameterType(i2, t2),
+            ) => i1 == i2 && t1 == t2,
+            (Error::InvalidQuery, Error::InvalidQuery) => true,
+            (Error::MultipleStatement, Error::MultipleStatement) => true,
+            #[cfg(feature = "vtab")]
+            (Error::ModuleError(s1), Error::ModuleError(s2)) => s1 == s2,
+            #[cfg(feature = "functions")]
+            (Error::UnwindingPanic, Error::UnwindingPanic) => true,
+            #[cfg(feature = "functions")]
+            (Error::GetAuxWrongType, Error::GetAuxWrongType) => true,
+            (Error::InvalidParameterCount(i1, n1), Error::InvalidParameterCount(i2, n2)) => {
+                i1 == i2 && n1 == n2
+            }
+            #[cfg(feature = "blob")]
+            (Error::BlobSizeError, Error::BlobSizeError) => true,
+            #[cfg(feature = "modern_sqlite")]
+            (
+                Error::SqlInputError {
+                    error: e1,
+                    msg: m1,
+                    sql: s1,
+                    offset: o1,
+                },
+                Error::SqlInputError {
+                    error: e2,
+                    msg: m2,
+                    sql: s2,
+                    offset: o2,
+                },
+            ) => e1 == e2 && m1 == m2 && s1 == s2 && o1 == o2,
+            (..) => false, // different variants, or variants holding a boxed `dyn Error`
+        }
+    }
+}
+
+impl From<str::Utf8Error> for Error {
+    #[cold]
+    fn from(err: str::Utf8Error) -> Error {
+        Error::Utf8Error(err)
+    }
+}
+
+impl From<std::ffi::NulError> for Error {
+    #[cold]
+    fn from(err: std::ffi::NulError) -> Error {
+        Error::NulError(err)
+    }
+}
+
+const UNKNOWN_COLUMN: usize = usize::MAX;
+
+/// The conversion isn't precise, but it's convenient to have it
+/// to allow use of `get_raw(…).as_…()?` in callbacks that take `Error`.
+impl From<FromSqlError> for Error {
+    #[cold]
+    fn from(err: FromSqlError) -> Error {
+        // The error type requires index and type fields, but they aren't known in this
+        // context.
+        match err {
+            FromSqlError::OutOfRange(val) => Error::IntegralValueOutOfRange(UNKNOWN_COLUMN, val),
+            FromSqlError::InvalidBlobSize { .. } => {
+                Error::FromSqlConversionFailure(UNKNOWN_COLUMN, Type::Blob, Box::new(err))
+            }
+            FromSqlError::Other(source) => {
+                Error::FromSqlConversionFailure(UNKNOWN_COLUMN, Type::Null, source)
+            }
+            _ => Error::FromSqlConversionFailure(UNKNOWN_COLUMN, Type::Null, Box::new(err)),
+        }
+    }
+}
+
+impl fmt::Display for Error {
+    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+        match *self {
+            Error::SqliteFailure(ref err, None) => err.fmt(f),
+            Error::SqliteFailure(_, Some(ref s)) => write!(f, "{}", s),
+            Error::SqliteSingleThreadedMode => write!(
+                f,
+                "SQLite was compiled or configured for single-threaded use only"
+            ),
+            Error::FromSqlConversionFailure(i, ref t, ref err) => {
+                if i != UNKNOWN_COLUMN {
+                    write!(
+                        f,
+                        "Conversion error from type {} at index: {}, {}",
+                        t, i, err
+                    )
+                } else {
+                    err.fmt(f)
+                }
+            }
+            Error::IntegralValueOutOfRange(col, val) => {
+                if col != UNKNOWN_COLUMN {
+                    write!(f, "Integer {} out of range at index {}", val, col)
+                } else {
+                    write!(f, "Integer {} out of range", val)
+                }
+            }
+            Error::Utf8Error(ref err) => err.fmt(f),
+            Error::NulError(ref err) => err.fmt(f),
+            Error::InvalidParameterName(ref name) => write!(f, "Invalid parameter name: {}", name),
+            Error::InvalidPath(ref p) => write!(f, "Invalid path: {}", p.to_string_lossy()),
+            Error::ExecuteReturnedResults => {
+                write!(f, "Execute returned results - did you mean to call query?")
+            }
+            Error::QueryReturnedNoRows => write!(f, "Query returned no rows"),
+            Error::InvalidColumnIndex(i) => write!(f, "Invalid column index: {}", i),
+            Error::InvalidColumnName(ref name) => write!(f, "Invalid column name: {}", name),
+            Error::InvalidColumnType(i, ref name, ref t) => write!(
+                f,
+                "Invalid column type {} at index: {}, name: {}",
+                t, i, name
+            ),
+            Error::InvalidParameterCount(i1, n1) => write!(
+                f,
+                "Wrong number of parameters passed to query. Got {}, needed {}",
+                i1, n1
+            ),
+            Error::StatementChangedRows(i) => write!(f, "Query changed {} rows", i),
+
+            #[cfg(feature = "functions")]
+            Error::InvalidFunctionParameterType(i, ref t) => {
+                write!(f, "Invalid function parameter type {} at index {}", t, i)
+            }
+            #[cfg(feature = "vtab")]
+            Error::InvalidFilterParameterType(i, ref t) => {
+                write!(f, "Invalid filter parameter type {} at index {}", t, i)
+            }
+            #[cfg(feature = "functions")]
+            Error::UserFunctionError(ref err) => err.fmt(f),
+            Error::ToSqlConversionFailure(ref err) => err.fmt(f),
+            Error::InvalidQuery => write!(f, "Query is not read-only"),
+            #[cfg(feature = "vtab")]
+            Error::ModuleError(ref desc) => write!(f, "{}", desc),
+            #[cfg(feature = "functions")]
+            Error::UnwindingPanic => write!(f, "unwinding panic"),
+            #[cfg(feature = "functions")]
+            Error::GetAuxWrongType => write!(f, "get_aux called with wrong type"),
+            Error::MultipleStatement => write!(f, "Multiple statements provided"),
+            #[cfg(feature = "blob")]
+            Error::BlobSizeError => "Blob size is insufficient".fmt(f),
+            #[cfg(feature = "modern_sqlite")]
+            Error::SqlInputError {
+                ref msg,
+                offset,
+                ref sql,
+                ..
+            } => write!(f, "{} in {} at offset {}", msg, sql, offset),
+        }
+    }
+}
+
+impl error::Error for Error {
+    fn source(&self) -> Option<&(dyn error::Error + 'static)> {
+        match *self {
+            Error::SqliteFailure(ref err, _) => Some(err),
+            Error::Utf8Error(ref err) => Some(err),
+            Error::NulError(ref err) => Some(err),
+
+            Error::IntegralValueOutOfRange(..)
+            | Error::SqliteSingleThreadedMode
+            | Error::InvalidParameterName(_)
+            | Error::ExecuteReturnedResults
+            | Error::QueryReturnedNoRows
+            | Error::InvalidColumnIndex(_)
+            | Error::InvalidColumnName(_)
+            | Error::InvalidColumnType(..)
+            | Error::InvalidPath(_)
+            | Error::InvalidParameterCount(..)
+            | Error::StatementChangedRows(_)
+            | Error::InvalidQuery
+            | Error::MultipleStatement => None,
+
+            #[cfg(feature = "functions")]
+            Error::InvalidFunctionParameterType(..) => None,
+            #[cfg(feature = "vtab")]
+            Error::InvalidFilterParameterType(..) => None,
+
+            #[cfg(feature = "functions")]
+            Error::UserFunctionError(ref err) => Some(&**err),
+
+            Error::FromSqlConversionFailure(_, _, ref err)
+            | Error::ToSqlConversionFailure(ref err) => Some(&**err),
+
+            #[cfg(feature = "vtab")]
+            Error::ModuleError(_) => None,
+
+            #[cfg(feature = "functions")]
+            Error::UnwindingPanic => None,
+
+            #[cfg(feature = "functions")]
+            Error::GetAuxWrongType => None,
+
+            #[cfg(feature = "blob")]
+            Error::BlobSizeError => None,
+            #[cfg(feature = "modern_sqlite")]
+            Error::SqlInputError { ref error, .. } => Some(error),
+        }
+    }
+}
+
+impl Error {
+    /// Returns the underlying SQLite error if this is [`Error::SqliteFailure`].
+    #[inline]
+    pub fn sqlite_error(&self) -> Option<&ffi::Error> {
+        match self {
+            Self::SqliteFailure(error, _) => Some(error),
+            _ => None,
+        }
+    }
+
+    /// Returns the underlying SQLite error code if this is
+    /// [`Error::SqliteFailure`].
+    #[inline]
+    pub fn sqlite_error_code(&self) -> Option<ffi::ErrorCode> {
+        self.sqlite_error().map(|error| error.code)
+    }
+}
+
+// These are public but not re-exported by lib.rs, so only visible within crate.
+
+#[cold]
+pub fn error_from_sqlite_code(code: c_int, message: Option<String>) -> Error {
+    // TODO sqlite3_error_offset // 3.38.0, #1130
+    Error::SqliteFailure(ffi::Error::new(code), message)
+}
+
+#[cold]
+pub unsafe fn error_from_handle(db: *mut ffi::sqlite3, code: c_int) -> Error {
+    let message = if db.is_null() {
+        None
+    } else {
+        Some(errmsg_to_string(ffi::sqlite3_errmsg(db)))
+    };
+    error_from_sqlite_code(code, message)
+}
+
+#[cold]
+#[cfg(not(all(feature = "modern_sqlite", not(feature = "bundled-sqlcipher"))))] // SQLite >= 3.38.0
+pub unsafe fn error_with_offset(db: *mut ffi::sqlite3, code: c_int, _sql: &str) -> Error {
+    error_from_handle(db, code)
+}
+
+#[cold]
+#[cfg(all(feature = "modern_sqlite", not(feature = "bundled-sqlcipher")))] // SQLite >= 3.38.0
+pub unsafe fn error_with_offset(db: *mut ffi::sqlite3, code: c_int, sql: &str) -> Error {
+    if db.is_null() {
+        error_from_sqlite_code(code, None)
+    } else {
+        let error = ffi::Error::new(code);
+        let msg = errmsg_to_string(ffi::sqlite3_errmsg(db));
+        if ffi::ErrorCode::Unknown == error.code {
+            let offset = ffi::sqlite3_error_offset(db);
+            if offset >= 0 {
+                return Error::SqlInputError {
+                    error,
+                    msg,
+                    sql: sql.to_owned(),
+                    offset,
+                };
+            }
+        }
+        Error::SqliteFailure(error, Some(msg))
+    }
+}
+
+pub fn check(code: c_int) -> Result<()> {
+    if code != crate::ffi::SQLITE_OK {
+        Err(error_from_sqlite_code(code, None))
+    } else {
+        Ok(())
+    }
+}
diff --git a/src/functions.rs b/src/functions.rs
new file mode 100644
index 0000000..138baac
--- /dev/null
+++ b/src/functions.rs
@@ -0,0 +1,1099 @@
+//! Create or redefine SQL functions.
+//!
+//! # Example
+//!
+//! Adding a `regexp` function to a connection in which compiled regular
+//! expressions are cached in a `HashMap`. For an alternative implementation
+//! that uses SQLite's [Function Auxiliary Data](https://www.sqlite.org/c3ref/get_auxdata.html) interface
+//! to avoid recompiling regular expressions, see the unit tests for this
+//! module.
+//! ```rust
+//! use regex::Regex;
+//! use rusqlite::functions::FunctionFlags;
+//! use rusqlite::{Connection, Error, Result};
+//! use std::sync::Arc;
+//! type BoxError = Box<dyn std::error::Error + Send + Sync + 'static>;
+//!
+//! fn add_regexp_function(db: &Connection) -> Result<()> {
+//!     db.create_scalar_function(
+//!         "regexp",
+//!         2,
+//!         FunctionFlags::SQLITE_UTF8 | FunctionFlags::SQLITE_DETERMINISTIC,
+//!         move |ctx| {
+//!             assert_eq!(ctx.len(), 2, "called with unexpected number of arguments");
+//!             let regexp: Arc<Regex> = ctx.get_or_create_aux(0, |vr| -> Result<_, BoxError> {
+//!                 Ok(Regex::new(vr.as_str()?)?)
+//!             })?;
+//!             let is_match = {
+//!                 let text = ctx
+//!                     .get_raw(1)
+//!                     .as_str()
+//!                     .map_err(|e| Error::UserFunctionError(e.into()))?;
+//!
+//!                 regexp.is_match(text)
+//!             };
+//!
+//!             Ok(is_match)
+//!         },
+//!     )
+//! }
+//!
+//! fn main() -> Result<()> {
+//!     let db = Connection::open_in_memory()?;
+//!     add_regexp_function(&db)?;
+//!
+//!     let is_match: bool =
+//!         db.query_row("SELECT regexp('[aeiou]*', 'aaaaeeeiii')", [], |row| {
+//!             row.get(0)
+//!         })?;
+//!
+//!     assert!(is_match);
+//!     Ok(())
+//! }
+//! ```
+use std::any::Any;
+use std::marker::PhantomData;
+use std::ops::Deref;
+use std::os::raw::{c_int, c_void};
+use std::panic::{catch_unwind, RefUnwindSafe, UnwindSafe};
+use std::ptr;
+use std::slice;
+use std::sync::Arc;
+
+use crate::ffi;
+use crate::ffi::sqlite3_context;
+use crate::ffi::sqlite3_value;
+
+use crate::context::set_result;
+use crate::types::{FromSql, FromSqlError, ToSql, ValueRef};
+
+use crate::{str_to_cstring, Connection, Error, InnerConnection, Result};
+
+unsafe fn report_error(ctx: *mut sqlite3_context, err: &Error) {
+    // Extended constraint error codes were added in SQLite 3.7.16. We don't have
+    // an explicit feature check for that, and this doesn't really warrant one.
+    // We use the extended code when the `modern_sqlite` feature is enabled (see
+    // the `cfg`s below) and the plain constraint error code otherwise.
+    #[cfg(feature = "modern_sqlite")]
+    fn constraint_error_code() -> i32 {
+        ffi::SQLITE_CONSTRAINT_FUNCTION
+    }
+    #[cfg(not(feature = "modern_sqlite"))]
+    fn constraint_error_code() -> i32 {
+        ffi::SQLITE_CONSTRAINT
+    }
+
+    if let Error::SqliteFailure(ref err, ref s) = *err {
+        ffi::sqlite3_result_error_code(ctx, err.extended_code);
+        if let Some(Ok(cstr)) = s.as_ref().map(|s| str_to_cstring(s)) {
+            ffi::sqlite3_result_error(ctx, cstr.as_ptr(), -1);
+        }
+    } else {
+        ffi::sqlite3_result_error_code(ctx, constraint_error_code());
+        if let Ok(cstr) = str_to_cstring(&err.to_string()) {
+            ffi::sqlite3_result_error(ctx, cstr.as_ptr(), -1);
+        }
+    }
+}
+
+unsafe extern "C" fn free_boxed_value<T>(p: *mut c_void) {
+    drop(Box::from_raw(p.cast::<T>()));
+}
+
+/// Context is a wrapper for the SQLite function
+/// evaluation context.
+pub struct Context<'a> {
+    ctx: *mut sqlite3_context,
+    args: &'a [*mut sqlite3_value],
+}
+
+impl Context<'_> {
+    /// Returns the number of arguments to the function.
+    #[inline]
+    #[must_use]
+    pub fn len(&self) -> usize {
+        self.args.len()
+    }
+
+    /// Returns `true` when there is no argument.
+    #[inline]
+    #[must_use]
+    pub fn is_empty(&self) -> bool {
+        self.args.is_empty()
+    }
+
+    /// Returns the `idx`th argument as a `T`.
+    ///
+    /// # Failure
+    ///
+    /// Will panic if `idx` is greater than or equal to
+    /// [`self.len()`](Context::len).
+    ///
+    /// Will return `Err` if the underlying SQLite type cannot be converted to a
+    /// `T`.
+    pub fn get<T: FromSql>(&self, idx: usize) -> Result<T> {
+        let arg = self.args[idx];
+        let value = unsafe { ValueRef::from_value(arg) };
+        FromSql::column_result(value).map_err(|err| match err {
+            FromSqlError::InvalidType => {
+                Error::InvalidFunctionParameterType(idx, value.data_type())
+            }
+            FromSqlError::OutOfRange(i) => Error::IntegralValueOutOfRange(idx, i),
+            FromSqlError::Other(err) => {
+                Error::FromSqlConversionFailure(idx, value.data_type(), err)
+            }
+            FromSqlError::InvalidBlobSize { .. } => {
+                Error::FromSqlConversionFailure(idx, value.data_type(), Box::new(err))
+            }
+        })
+    }
+
+    /// Returns the `idx`th argument as a `ValueRef`.
+    ///
+    /// # Failure
+    ///
+    /// Will panic if `idx` is greater than or equal to
+    /// [`self.len()`](Context::len).
+    #[inline]
+    #[must_use]
+    pub fn get_raw(&self, idx: usize) -> ValueRef<'_> {
+        let arg = self.args[idx];
+        unsafe { ValueRef::from_value(arg) }
+    }
+
+    /// Returns the subtype of `idx`th argument.
+    ///
+    /// # Failure
+    ///
+    /// Will panic if `idx` is greater than or equal to
+    /// [`self.len()`](Context::len).
+    #[cfg(feature = "modern_sqlite")] // 3.9.0
+    #[cfg_attr(docsrs, doc(cfg(feature = "modern_sqlite")))]
+    pub fn get_subtype(&self, idx: usize) -> std::os::raw::c_uint {
+        let arg = self.args[idx];
+        unsafe { ffi::sqlite3_value_subtype(arg) }
+    }
+
+    /// Fetch or insert the auxiliary data associated with a particular
+    /// parameter. This is intended to be an easier-to-use way of fetching it
+    /// compared to calling [`get_aux`](Context::get_aux) and
+    /// [`set_aux`](Context::set_aux) separately.
+    ///
+    /// See `https://www.sqlite.org/c3ref/get_auxdata.html` for a discussion of
+    /// this feature, or the unit tests of this module for an example.
+    pub fn get_or_create_aux<T, E, F>(&self, arg: c_int, func: F) -> Result<Arc<T>>
+    where
+        T: Send + Sync + 'static,
+        E: Into<Box<dyn std::error::Error + Send + Sync + 'static>>,
+        F: FnOnce(ValueRef<'_>) -> Result<T, E>,
+    {
+        if let Some(v) = self.get_aux(arg)? {
+            Ok(v)
+        } else {
+            let vr = self.get_raw(arg as usize);
+            self.set_aux(
+                arg,
+                func(vr).map_err(|e| Error::UserFunctionError(e.into()))?,
+            )
+        }
+    }
+
+    /// Sets the auxiliary data associated with a particular parameter. See
+    /// `https://www.sqlite.org/c3ref/get_auxdata.html` for a discussion of
+    /// this feature, or the unit tests of this module for an example.
+    pub fn set_aux<T: Send + Sync + 'static>(&self, arg: c_int, value: T) -> Result<Arc<T>> {
+        let orig: Arc<T> = Arc::new(value);
+        let inner: AuxInner = orig.clone();
+        let outer = Box::new(inner);
+        let raw: *mut AuxInner = Box::into_raw(outer);
+        unsafe {
+            ffi::sqlite3_set_auxdata(
+                self.ctx,
+                arg,
+                raw.cast(),
+                Some(free_boxed_value::<AuxInner>),
+            );
+        };
+        Ok(orig)
+    }
+
+    /// Gets the auxiliary data that was associated with a given parameter via
+    /// [`set_aux`](Context::set_aux). Returns `Ok(None)` if no data has been
+    /// associated, and `Ok(Some(v))` if it has. Returns an error if the
+    /// requested type does not match.
+    pub fn get_aux<T: Send + Sync + 'static>(&self, arg: c_int) -> Result<Option<Arc<T>>> {
+        let p = unsafe { ffi::sqlite3_get_auxdata(self.ctx, arg) as *const AuxInner };
+        if p.is_null() {
+            Ok(None)
+        } else {
+            let v: AuxInner = AuxInner::clone(unsafe { &*p });
+            v.downcast::<T>()
+                .map(Some)
+                .map_err(|_| Error::GetAuxWrongType)
+        }
+    }
+
+    /// Get the db connection handle via [sqlite3_context_db_handle](https://www.sqlite.org/c3ref/context_db_handle.html)
+    ///
+    /// # Safety
+    ///
+    /// This function is marked unsafe because there is a potential for other
+    /// references to the connection to be sent across threads, [see this comment](https://github.com/rusqlite/rusqlite/issues/643#issuecomment-640181213).
+    pub unsafe fn get_connection(&self) -> Result<ConnectionRef<'_>> {
+        let handle = ffi::sqlite3_context_db_handle(self.ctx);
+        Ok(ConnectionRef {
+            conn: Connection::from_handle(handle)?,
+            phantom: PhantomData,
+        })
+    }
+
+    /// Sets the subtype of this SQL function's result (`sqlite3_result_subtype`).
+    #[cfg(feature = "modern_sqlite")] // 3.9.0
+    #[cfg_attr(docsrs, doc(cfg(feature = "modern_sqlite")))]
+    pub fn set_result_subtype(&self, sub_type: std::os::raw::c_uint) {
+        unsafe { ffi::sqlite3_result_subtype(self.ctx, sub_type) };
+    }
+}
+
+/// A reference to a connection handle with a lifetime bound to something.
+pub struct ConnectionRef<'ctx> {
+    // comes from Connection::from_handle(sqlite3_context_db_handle(...))
+    // and is non-owning
+    conn: Connection,
+    phantom: PhantomData<&'ctx Context<'ctx>>,
+}
+
+impl Deref for ConnectionRef<'_> {
+    type Target = Connection;
+
+    #[inline]
+    fn deref(&self) -> &Connection {
+        &self.conn
+    }
+}
+
+type AuxInner = Arc<dyn Any + Send + Sync + 'static>; // type-erased; `get_aux` downcasts to `Arc<T>`
+
+/// Aggregate is the callback interface for user-defined
+/// aggregate function.
+///
+/// `A` is the type of the aggregation context and `T` is the type of the final
+/// result. Implementations should be stateless.
+pub trait Aggregate<A, T>
+where
+    A: RefUnwindSafe + UnwindSafe,
+    T: ToSql,
+{
+    /// Initializes the aggregation context. Will be called prior to the first
+    /// call to [`step()`](Aggregate::step) to set up the context for an
+    /// invocation of the function. (Note: `init()` will not be called if
+    /// there are no rows.)
+    fn init(&self, _: &mut Context<'_>) -> Result<A>;
+
+    /// "step" function called once for each row in an aggregate group. May be
+    /// called 0 times if there are no rows.
+    fn step(&self, _: &mut Context<'_>, _: &mut A) -> Result<()>;
+
+    /// Computes and returns the final result. Will be called exactly once for
+    /// each invocation of the function. If [`step()`](Aggregate::step) was
+    /// called at least once, will be given `Some(A)` (the same `A` as was
+    /// created by [`init`](Aggregate::init) and given to
+    /// [`step`](Aggregate::step)); if [`step()`](Aggregate::step) was not
+    /// called (because the function is running against 0 rows), will be
+    /// given `None`.
+    ///
+    /// The passed context will have no arguments.
+    fn finalize(&self, _: &mut Context<'_>, _: Option<A>) -> Result<T>;
+}
+
+/// `WindowAggregate` is the callback interface for
+/// user-defined aggregate window function.
+#[cfg(feature = "window")]
+#[cfg_attr(docsrs, doc(cfg(feature = "window")))]
+pub trait WindowAggregate<A, T>: Aggregate<A, T>
+where
+    A: RefUnwindSafe + UnwindSafe,
+    T: ToSql,
+{
+    /// Returns the current value of the aggregate. Unlike xFinal, the
+    /// implementation should not delete any context.
+    fn value(&self, _: Option<&A>) -> Result<T>;
+
+    /// Removes a row from the current window.
+    fn inverse(&self, _: &mut Context<'_>, _: &mut A) -> Result<()>;
+}
+
+bitflags::bitflags! {
+    /// Function Flags.
+    /// See [sqlite3_create_function](https://sqlite.org/c3ref/create_function.html)
+    /// and [Function Flags](https://sqlite.org/c3ref/c_deterministic.html) for details.
+    #[repr(C)]
+    pub struct FunctionFlags: ::std::os::raw::c_int {
+        /// Specifies UTF-8 as the text encoding this SQL function prefers for its parameters.
+        const SQLITE_UTF8 = ffi::SQLITE_UTF8;
+        /// Specifies UTF-16 using little-endian byte order as the text encoding this SQL function prefers for its parameters.
+        const SQLITE_UTF16LE = ffi::SQLITE_UTF16LE;
+        /// Specifies UTF-16 using big-endian byte order as the text encoding this SQL function prefers for its parameters.
+        const SQLITE_UTF16BE = ffi::SQLITE_UTF16BE;
+        /// Specifies UTF-16 using native byte order as the text encoding this SQL function prefers for its parameters.
+        const SQLITE_UTF16 = ffi::SQLITE_UTF16;
+        /// Means that the function always gives the same output when the input parameters are the same.
+        const SQLITE_DETERMINISTIC = ffi::SQLITE_DETERMINISTIC; // 3.8.3
+        /// Means that the function may only be invoked from top-level SQL.
+        const SQLITE_DIRECTONLY = 0x0000_0008_0000; // 3.30.0
+        /// Indicates to SQLite that a function may call `sqlite3_value_subtype()` to inspect the sub-types of its arguments.
+        const SQLITE_SUBTYPE = 0x0000_0010_0000; // 3.30.0
+        /// Means that the function is unlikely to cause problems even if misused.
+        const SQLITE_INNOCUOUS = 0x0000_0020_0000; // 3.31.0
+    }
+}
+
+impl Default for FunctionFlags {
+    #[inline]
+    fn default() -> FunctionFlags {
+        FunctionFlags::SQLITE_UTF8
+    }
+}
+
+impl Connection {
+    /// Attach a user-defined scalar function to
+    /// this database connection.
+    ///
+    /// `fn_name` is the name the function will be accessible from SQL.
+    /// `n_arg` is the number of arguments to the function. Use `-1` for a
+    /// variable number. If the function always returns the same value
+    /// given the same input, `flags` should include `SQLITE_DETERMINISTIC`.
+    ///
+    /// The function will remain available until the connection is closed or
+    /// until it is explicitly removed via
+    /// [`remove_function`](Connection::remove_function).
+    ///
+    /// # Example
+    ///
+    /// ```rust
+    /// # use rusqlite::{Connection, Result};
+    /// # use rusqlite::functions::FunctionFlags;
+    /// fn scalar_function_example(db: Connection) -> Result<()> {
+    ///     db.create_scalar_function(
+    ///         "halve",
+    ///         1,
+    ///         FunctionFlags::SQLITE_UTF8 | FunctionFlags::SQLITE_DETERMINISTIC,
+    ///         |ctx| {
+    ///             let value = ctx.get::<f64>(0)?;
+    ///             Ok(value / 2f64)
+    ///         },
+    ///     )?;
+    ///
+    ///     let six_halved: f64 = db.query_row("SELECT halve(6)", [], |r| r.get(0))?;
+    ///     assert_eq!(six_halved, 3f64);
+    ///     Ok(())
+    /// }
+    /// ```
+    ///
+    /// # Failure
+    ///
+    /// Will return Err if the function could not be attached to the connection.
+    #[inline]
+    pub fn create_scalar_function<F, T>(
+        &self,
+        fn_name: &str,
+        n_arg: c_int,
+        flags: FunctionFlags,
+        x_func: F,
+    ) -> Result<()>
+    where
+        F: FnMut(&Context<'_>) -> Result<T> + Send + UnwindSafe + 'static,
+        T: ToSql,
+    {
+        self.db
+            .borrow_mut()
+            .create_scalar_function(fn_name, n_arg, flags, x_func)
+    }
+
+    /// Attach a user-defined aggregate function to this
+    /// database connection.
+    ///
+    /// # Failure
+    ///
+    /// Will return Err if the function could not be attached to the connection.
+    #[inline]
+    pub fn create_aggregate_function<A, D, T>(
+        &self,
+        fn_name: &str,
+        n_arg: c_int,
+        flags: FunctionFlags,
+        aggr: D,
+    ) -> Result<()>
+    where
+        A: RefUnwindSafe + UnwindSafe,
+        D: Aggregate<A, T> + 'static,
+        T: ToSql,
+    {
+        self.db
+            .borrow_mut()
+            .create_aggregate_function(fn_name, n_arg, flags, aggr)
+    }
+
+    /// Attach a user-defined aggregate window function to
+    /// this database connection.
+    ///
+    /// See `https://sqlite.org/windowfunctions.html#udfwinfunc` for more
+    /// information.
+    #[cfg(feature = "window")]
+    #[cfg_attr(docsrs, doc(cfg(feature = "window")))]
+    #[inline]
+    pub fn create_window_function<A, W, T>(
+        &self,
+        fn_name: &str,
+        n_arg: c_int,
+        flags: FunctionFlags,
+        aggr: W,
+    ) -> Result<()>
+    where
+        A: RefUnwindSafe + UnwindSafe,
+        W: WindowAggregate<A, T> + 'static,
+        T: ToSql,
+    {
+        self.db
+            .borrow_mut()
+            .create_window_function(fn_name, n_arg, flags, aggr)
+    }
+
+    /// Removes a user-defined function from this
+    /// database connection.
+    ///
+    /// `fn_name` and `n_arg` should match the name and number of arguments
+    /// given to [`create_scalar_function`](Connection::create_scalar_function)
+    /// or [`create_aggregate_function`](Connection::create_aggregate_function).
+    ///
+    /// # Failure
+    ///
+    /// Will return Err if the function could not be removed.
+    #[inline]
+    pub fn remove_function(&self, fn_name: &str, n_arg: c_int) -> Result<()> {
+        self.db.borrow_mut().remove_function(fn_name, n_arg)
+    }
+}
+
+impl InnerConnection {
+    fn create_scalar_function<F, T>(
+        &mut self,
+        fn_name: &str,
+        n_arg: c_int,
+        flags: FunctionFlags,
+        x_func: F,
+    ) -> Result<()>
+    where
+        F: FnMut(&Context<'_>) -> Result<T> + Send + UnwindSafe + 'static,
+        T: ToSql,
+    {
+        unsafe extern "C" fn call_boxed_closure<F, T>(
+            ctx: *mut sqlite3_context,
+            argc: c_int,
+            argv: *mut *mut sqlite3_value,
+        ) where
+            F: FnMut(&Context<'_>) -> Result<T>,
+            T: ToSql,
+        {
+            let r = catch_unwind(|| {
+                let boxed_f: *mut F = ffi::sqlite3_user_data(ctx).cast::<F>();
+                assert!(!boxed_f.is_null(), "Internal error - null function pointer");
+                let ctx = Context {
+                    ctx,
+                    args: slice::from_raw_parts(argv, argc as usize),
+                };
+                (*boxed_f)(&ctx)
+            });
+            let t = match r {
+                Err(_) => {
+                    report_error(ctx, &Error::UnwindingPanic);
+                    return;
+                }
+                Ok(r) => r,
+            };
+            let t = t.as_ref().map(|t| ToSql::to_sql(t));
+
+            match t {
+                Ok(Ok(ref value)) => set_result(ctx, value),
+                Ok(Err(err)) => report_error(ctx, &err),
+                Err(err) => report_error(ctx, err),
+            }
+        }
+
+        let c_name = str_to_cstring(fn_name)?; // before boxing: an early `?` must not leak
+        let boxed_f: *mut F = Box::into_raw(Box::new(x_func));
+        let r = unsafe {
+            ffi::sqlite3_create_function_v2(
+                self.db(),
+                c_name.as_ptr(),
+                n_arg,
+                flags.bits(),
+                boxed_f.cast::<c_void>(),
+                Some(call_boxed_closure::<F, T>),
+                None,
+                None,
+                Some(free_boxed_value::<F>),
+            )
+        };
+        self.decode_result(r)
+    }
+
+    fn create_aggregate_function<A, D, T>(
+        &mut self,
+        fn_name: &str,
+        n_arg: c_int,
+        flags: FunctionFlags,
+        aggr: D,
+    ) -> Result<()>
+    where
+        A: RefUnwindSafe + UnwindSafe,
+        D: Aggregate<A, T> + 'static,
+        T: ToSql,
+    {
+        let c_name = str_to_cstring(fn_name)?; // before boxing: an early `?` must not leak
+        let boxed_aggr: *mut D = Box::into_raw(Box::new(aggr));
+        let r = unsafe {
+            ffi::sqlite3_create_function_v2(
+                self.db(),
+                c_name.as_ptr(),
+                n_arg,
+                flags.bits(),
+                boxed_aggr.cast::<c_void>(),
+                None,
+                Some(call_boxed_step::<A, D, T>),
+                Some(call_boxed_final::<A, D, T>),
+                Some(free_boxed_value::<D>),
+            )
+        };
+        self.decode_result(r)
+    }
+
+    #[cfg(feature = "window")]
+    fn create_window_function<A, W, T>(
+        &mut self,
+        fn_name: &str,
+        n_arg: c_int,
+        flags: FunctionFlags,
+        aggr: W,
+    ) -> Result<()>
+    where
+        A: RefUnwindSafe + UnwindSafe,
+        W: WindowAggregate<A, T> + 'static,
+        T: ToSql,
+    {
+        let c_name = str_to_cstring(fn_name)?; // before boxing: an early `?` must not leak
+        let boxed_aggr: *mut W = Box::into_raw(Box::new(aggr));
+        let r = unsafe {
+            ffi::sqlite3_create_window_function(
+                self.db(),
+                c_name.as_ptr(),
+                n_arg,
+                flags.bits(),
+                boxed_aggr.cast::<c_void>(),
+                Some(call_boxed_step::<A, W, T>),
+                Some(call_boxed_final::<A, W, T>),
+                Some(call_boxed_value::<A, W, T>),
+                Some(call_boxed_inverse::<A, W, T>),
+                Some(free_boxed_value::<W>),
+            )
+        };
+        self.decode_result(r)
+    }
+
+    fn remove_function(&mut self, fn_name: &str, n_arg: c_int) -> Result<()> {
+        let c_name = str_to_cstring(fn_name)?;
+        let r = unsafe {
+            ffi::sqlite3_create_function_v2(
+                self.db(),
+                c_name.as_ptr(),
+                n_arg,
+                ffi::SQLITE_UTF8,
+                ptr::null_mut(),
+                None,
+                None,
+                None,
+                None,
+            )
+        };
+        self.decode_result(r)
+    }
+}
+
+unsafe fn aggregate_context<A>(ctx: *mut sqlite3_context, bytes: usize) -> Option<*mut *mut A> {
+    let pac = ffi::sqlite3_aggregate_context(ctx, bytes as c_int) as *mut *mut A;
+    if pac.is_null() {
+        return None;
+    }
+    Some(pac)
+}
+
+unsafe extern "C" fn call_boxed_step<A, D, T>(
+    ctx: *mut sqlite3_context,
+    argc: c_int,
+    argv: *mut *mut sqlite3_value,
+) where
+    A: RefUnwindSafe + UnwindSafe,
+    D: Aggregate<A, T>,
+    T: ToSql,
+{
+    let pac = if let Some(pac) = aggregate_context(ctx, std::mem::size_of::<*mut A>()) {
+        pac
+    } else {
+        ffi::sqlite3_result_error_nomem(ctx);
+        return;
+    };
+
+    let r = catch_unwind(|| {
+        let boxed_aggr: *mut D = ffi::sqlite3_user_data(ctx).cast::<D>();
+        assert!(
+            !boxed_aggr.is_null(),
+            "Internal error - null aggregate pointer"
+        );
+        let mut ctx = Context {
+            ctx,
+            args: slice::from_raw_parts(argv, argc as usize),
+        };
+
+        if (*pac as *mut A).is_null() {
+            *pac = Box::into_raw(Box::new((*boxed_aggr).init(&mut ctx)?));
+        }
+
+        (*boxed_aggr).step(&mut ctx, &mut **pac)
+    });
+    let r = match r {
+        Err(_) => {
+            report_error(ctx, &Error::UnwindingPanic);
+            return;
+        }
+        Ok(r) => r,
+    };
+    match r {
+        Ok(_) => {}
+        Err(err) => report_error(ctx, &err),
+    };
+}
+
+#[cfg(feature = "window")]
+unsafe extern "C" fn call_boxed_inverse<A, W, T>(
+    ctx: *mut sqlite3_context,
+    argc: c_int,
+    argv: *mut *mut sqlite3_value,
+) where
+    A: RefUnwindSafe + UnwindSafe,
+    W: WindowAggregate<A, T>,
+    T: ToSql,
+{
+    let pac = if let Some(pac) = aggregate_context(ctx, std::mem::size_of::<*mut A>()) {
+        pac
+    } else {
+        ffi::sqlite3_result_error_nomem(ctx);
+        return;
+    };
+
+    let r = catch_unwind(|| {
+        let boxed_aggr: *mut W = ffi::sqlite3_user_data(ctx).cast::<W>();
+        assert!(
+            !boxed_aggr.is_null(),
+            "Internal error - null aggregate pointer"
+        );
+        let mut ctx = Context {
+            ctx,
+            args: slice::from_raw_parts(argv, argc as usize),
+        };
+        (*boxed_aggr).inverse(&mut ctx, &mut **pac)
+    });
+    let r = match r {
+        Err(_) => {
+            report_error(ctx, &Error::UnwindingPanic);
+            return;
+        }
+        Ok(r) => r,
+    };
+    match r {
+        Ok(_) => {}
+        Err(err) => report_error(ctx, &err),
+    };
+}
+
+unsafe extern "C" fn call_boxed_final<A, D, T>(ctx: *mut sqlite3_context)
+where
+    A: RefUnwindSafe + UnwindSafe,
+    D: Aggregate<A, T>,
+    T: ToSql,
+{
+    // Within the xFinal callback, it is customary to set N=0 in calls to
+    // sqlite3_aggregate_context(C,N) so that no pointless memory allocations occur.
+    let a: Option<A> = match aggregate_context(ctx, 0) {
+        Some(pac) => {
+            if (*pac as *mut A).is_null() {
+                None
+            } else {
+                let a = Box::from_raw(*pac);
+                Some(*a)
+            }
+        }
+        None => None,
+    };
+
+    let r = catch_unwind(|| {
+        let boxed_aggr: *mut D = ffi::sqlite3_user_data(ctx).cast::<D>();
+        assert!(
+            !boxed_aggr.is_null(),
+            "Internal error - null aggregate pointer"
+        );
+        let mut ctx = Context { ctx, args: &mut [] };
+        (*boxed_aggr).finalize(&mut ctx, a)
+    });
+    let t = match r {
+        Err(_) => {
+            report_error(ctx, &Error::UnwindingPanic);
+            return;
+        }
+        Ok(r) => r,
+    };
+    let t = t.as_ref().map(|t| ToSql::to_sql(t));
+    match t {
+        Ok(Ok(ref value)) => set_result(ctx, value),
+        Ok(Err(err)) => report_error(ctx, &err),
+        Err(err) => report_error(ctx, err),
+    }
+}
+
+#[cfg(feature = "window")]
+unsafe extern "C" fn call_boxed_value<A, W, T>(ctx: *mut sqlite3_context)
+where
+    A: RefUnwindSafe + UnwindSafe,
+    W: WindowAggregate<A, T>,
+    T: ToSql,
+{
+    // Within the xValue callback, it is customary to set N=0 in calls to
+    // sqlite3_aggregate_context(C,N) so that no pointless memory allocations occur.
+    let a: Option<&A> = match aggregate_context(ctx, 0) {
+        Some(pac) => {
+            if (*pac as *mut A).is_null() {
+                None
+            } else {
+                let a = &**pac;
+                Some(a)
+            }
+        }
+        None => None,
+    };
+
+    let r = catch_unwind(|| {
+        let boxed_aggr: *mut W = ffi::sqlite3_user_data(ctx).cast::<W>();
+        assert!(
+            !boxed_aggr.is_null(),
+            "Internal error - null aggregate pointer"
+        );
+        (*boxed_aggr).value(a)
+    });
+    let t = match r {
+        Err(_) => {
+            report_error(ctx, &Error::UnwindingPanic);
+            return;
+        }
+        Ok(r) => r,
+    };
+    let t = t.as_ref().map(|t| ToSql::to_sql(t));
+    match t {
+        Ok(Ok(ref value)) => set_result(ctx, value),
+        Ok(Err(err)) => report_error(ctx, &err),
+        Err(err) => report_error(ctx, err),
+    }
+}
+
+#[cfg(test)]
+mod test {
+    use regex::Regex;
+    use std::os::raw::c_double;
+
+    #[cfg(feature = "window")]
+    use crate::functions::WindowAggregate;
+    use crate::functions::{Aggregate, Context, FunctionFlags};
+    use crate::{Connection, Error, Result};
+
+    fn half(ctx: &Context<'_>) -> Result<c_double> {
+        assert_eq!(ctx.len(), 1, "called with unexpected number of arguments");
+        let value = ctx.get::<c_double>(0)?;
+        Ok(value / 2f64)
+    }
+
+    #[test]
+    fn test_function_half() -> Result<()> {
+        let db = Connection::open_in_memory()?;
+        db.create_scalar_function(
+            "half",
+            1,
+            FunctionFlags::SQLITE_UTF8 | FunctionFlags::SQLITE_DETERMINISTIC,
+            half,
+        )?;
+        let result: Result<f64> = db.query_row("SELECT half(6)", [], |r| r.get(0));
+
+        assert!((3f64 - result?).abs() < f64::EPSILON);
+        Ok(())
+    }
+
+    #[test]
+    fn test_remove_function() -> Result<()> {
+        let db = Connection::open_in_memory()?;
+        db.create_scalar_function(
+            "half",
+            1,
+            FunctionFlags::SQLITE_UTF8 | FunctionFlags::SQLITE_DETERMINISTIC,
+            half,
+        )?;
+        let result: Result<f64> = db.query_row("SELECT half(6)", [], |r| r.get(0));
+        assert!((3f64 - result?).abs() < f64::EPSILON);
+
+        db.remove_function("half", 1)?;
+        let result: Result<f64> = db.query_row("SELECT half(6)", [], |r| r.get(0));
+        assert!(result.is_err());
+        Ok(())
+    }
+
+    // This implementation of a regexp scalar function uses SQLite's auxiliary data
+    // (https://www.sqlite.org/c3ref/get_auxdata.html) to avoid recompiling the regular
+    // expression multiple times within one query.
+    fn regexp_with_auxilliary(ctx: &Context<'_>) -> Result<bool> {
+        assert_eq!(ctx.len(), 2, "called with unexpected number of arguments");
+        type BoxError = Box<dyn std::error::Error + Send + Sync + 'static>;
+        let regexp: std::sync::Arc<Regex> = ctx
+            .get_or_create_aux(0, |vr| -> Result<_, BoxError> {
+                Ok(Regex::new(vr.as_str()?)?)
+            })?;
+
+        let is_match = {
+            let text = ctx
+                .get_raw(1)
+                .as_str()
+                .map_err(|e| Error::UserFunctionError(e.into()))?;
+
+            regexp.is_match(text)
+        };
+
+        Ok(is_match)
+    }
+
+    #[test]
+    fn test_function_regexp_with_auxilliary() -> Result<()> {
+        let db = Connection::open_in_memory()?;
+        db.execute_batch(
+            "BEGIN;
+             CREATE TABLE foo (x string);
+             INSERT INTO foo VALUES ('lisa');
+             INSERT INTO foo VALUES ('lXsi');
+             INSERT INTO foo VALUES ('lisX');
+             END;",
+        )?;
+        db.create_scalar_function(
+            "regexp",
+            2,
+            FunctionFlags::SQLITE_UTF8 | FunctionFlags::SQLITE_DETERMINISTIC,
+            regexp_with_auxilliary,
+        )?;
+
+        let result: Result<bool> =
+            db.query_row("SELECT regexp('l.s[aeiouy]', 'lisa')", [], |r| r.get(0));
+
+        assert!(result?);
+
+        let result: Result<i64> = db.query_row(
+            "SELECT COUNT(*) FROM foo WHERE regexp('l.s[aeiouy]', x) == 1",
+            [],
+            |r| r.get(0),
+        );
+
+        assert_eq!(2, result?);
+        Ok(())
+    }
+
+    #[test]
+    fn test_varargs_function() -> Result<()> {
+        let db = Connection::open_in_memory()?;
+        db.create_scalar_function(
+            "my_concat",
+            -1,
+            FunctionFlags::SQLITE_UTF8 | FunctionFlags::SQLITE_DETERMINISTIC,
+            |ctx| {
+                let mut ret = String::new();
+
+                for idx in 0..ctx.len() {
+                    let s = ctx.get::<String>(idx)?;
+                    ret.push_str(&s);
+                }
+
+                Ok(ret)
+            },
+        )?;
+
+        for &(expected, query) in &[
+            ("", "SELECT my_concat()"),
+            ("onetwo", "SELECT my_concat('one', 'two')"),
+            ("abc", "SELECT my_concat('a', 'b', 'c')"),
+        ] {
+            let result: String = db.query_row(query, [], |r| r.get(0))?;
+            assert_eq!(expected, result);
+        }
+        Ok(())
+    }
+
+    #[test]
+    fn test_get_aux_type_checking() -> Result<()> {
+        let db = Connection::open_in_memory()?;
+        db.create_scalar_function("example", 2, FunctionFlags::default(), |ctx| {
+            if !ctx.get::<bool>(1)? {
+                ctx.set_aux::<i64>(0, 100)?;
+            } else {
+                assert_eq!(ctx.get_aux::<String>(0), Err(Error::GetAuxWrongType));
+                assert_eq!(*ctx.get_aux::<i64>(0)?.unwrap(), 100);
+            }
+            Ok(true)
+        })?;
+
+        let res: bool = db.query_row(
+            "SELECT example(0, i) FROM (SELECT 0 as i UNION SELECT 1)",
+            [],
+            |r| r.get(0),
+        )?;
+        // Doesn't actually matter, we'll assert in the function if there's a problem.
+        assert!(res);
+        Ok(())
+    }
+
+    struct Sum;
+    struct Count;
+
+    impl Aggregate<i64, Option<i64>> for Sum {
+        fn init(&self, _: &mut Context<'_>) -> Result<i64> {
+            Ok(0)
+        }
+
+        fn step(&self, ctx: &mut Context<'_>, sum: &mut i64) -> Result<()> {
+            *sum += ctx.get::<i64>(0)?;
+            Ok(())
+        }
+
+        fn finalize(&self, _: &mut Context<'_>, sum: Option<i64>) -> Result<Option<i64>> {
+            Ok(sum)
+        }
+    }
+
+    impl Aggregate<i64, i64> for Count {
+        fn init(&self, _: &mut Context<'_>) -> Result<i64> {
+            Ok(0)
+        }
+
+        fn step(&self, _ctx: &mut Context<'_>, sum: &mut i64) -> Result<()> {
+            *sum += 1;
+            Ok(())
+        }
+
+        fn finalize(&self, _: &mut Context<'_>, sum: Option<i64>) -> Result<i64> {
+            Ok(sum.unwrap_or(0))
+        }
+    }
+
+    #[test]
+    fn test_sum() -> Result<()> {
+        let db = Connection::open_in_memory()?;
+        db.create_aggregate_function(
+            "my_sum",
+            1,
+            FunctionFlags::SQLITE_UTF8 | FunctionFlags::SQLITE_DETERMINISTIC,
+            Sum,
+        )?;
+
+        // sum should return NULL when given no columns (contrast with count below)
+        let no_result = "SELECT my_sum(i) FROM (SELECT 2 AS i WHERE 1 <> 1)";
+        let result: Option<i64> = db.query_row(no_result, [], |r| r.get(0))?;
+        assert!(result.is_none());
+
+        let single_sum = "SELECT my_sum(i) FROM (SELECT 2 AS i UNION ALL SELECT 2)";
+        let result: i64 = db.query_row(single_sum, [], |r| r.get(0))?;
+        assert_eq!(4, result);
+
+        let dual_sum = "SELECT my_sum(i), my_sum(j) FROM (SELECT 2 AS i, 1 AS j UNION ALL SELECT \
+                        2, 1)";
+        let result: (i64, i64) = db.query_row(dual_sum, [], |r| Ok((r.get(0)?, r.get(1)?)))?;
+        assert_eq!((4, 2), result);
+        Ok(())
+    }
+
+    #[test]
+    fn test_count() -> Result<()> {
+        let db = Connection::open_in_memory()?;
+        db.create_aggregate_function(
+            "my_count",
+            -1,
+            FunctionFlags::SQLITE_UTF8 | FunctionFlags::SQLITE_DETERMINISTIC,
+            Count,
+        )?;
+
+        // count should return 0 when given no columns (contrast with sum above)
+        let no_result = "SELECT my_count(i) FROM (SELECT 2 AS i WHERE 1 <> 1)";
+        let result: i64 = db.query_row(no_result, [], |r| r.get(0))?;
+        assert_eq!(result, 0);
+
+        let single_sum = "SELECT my_count(i) FROM (SELECT 2 AS i UNION ALL SELECT 2)";
+        let result: i64 = db.query_row(single_sum, [], |r| r.get(0))?;
+        assert_eq!(2, result);
+        Ok(())
+    }
+
+    #[cfg(feature = "window")]
+    impl WindowAggregate<i64, Option<i64>> for Sum {
+        fn inverse(&self, ctx: &mut Context<'_>, sum: &mut i64) -> Result<()> {
+            *sum -= ctx.get::<i64>(0)?;
+            Ok(())
+        }
+
+        fn value(&self, sum: Option<&i64>) -> Result<Option<i64>> {
+            Ok(sum.copied())
+        }
+    }
+
+    #[test]
+    #[cfg(feature = "window")]
+    fn test_window() -> Result<()> {
+        use fallible_iterator::FallibleIterator;
+
+        let db = Connection::open_in_memory()?;
+        db.create_window_function(
+            "sumint",
+            1,
+            FunctionFlags::SQLITE_UTF8 | FunctionFlags::SQLITE_DETERMINISTIC,
+            Sum,
+        )?;
+        db.execute_batch(
+            "CREATE TABLE t3(x, y);
+             INSERT INTO t3 VALUES('a', 4),
+                     ('b', 5),
+                     ('c', 3),
+                     ('d', 8),
+                     ('e', 1);",
+        )?;
+
+        let mut stmt = db.prepare(
+            "SELECT x, sumint(y) OVER (
+                   ORDER BY x ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING
+                 ) AS sum_y
+                 FROM t3 ORDER BY x;",
+        )?;
+
+        let results: Vec<(String, i64)> = stmt
+            .query([])?
+            .map(|row| Ok((row.get("x")?, row.get("sum_y")?)))
+            .collect()?;
+        let expected = vec![
+            ("a".to_owned(), 9),
+            ("b".to_owned(), 12),
+            ("c".to_owned(), 16),
+            ("d".to_owned(), 12),
+            ("e".to_owned(), 9),
+        ];
+        assert_eq!(expected, results);
+        Ok(())
+    }
+}
diff --git a/src/hooks.rs b/src/hooks.rs
new file mode 100644
index 0000000..5058a0c
--- /dev/null
+++ b/src/hooks.rs
@@ -0,0 +1,815 @@
+//! Commit, Data Change and Rollback Notification Callbacks
+#![allow(non_camel_case_types)]
+
+use std::os::raw::{c_char, c_int, c_void};
+use std::panic::{catch_unwind, RefUnwindSafe};
+use std::ptr;
+
+use crate::ffi;
+
+use crate::{Connection, InnerConnection};
+
+/// Action Codes
+#[derive(Clone, Copy, Debug, Eq, PartialEq)]
+#[repr(i32)]
+#[non_exhaustive]
+#[allow(clippy::upper_case_acronyms)]
+pub enum Action {
+    /// Unsupported / unexpected action
+    UNKNOWN = -1,
+    /// DELETE command
+    SQLITE_DELETE = ffi::SQLITE_DELETE,
+    /// INSERT command
+    SQLITE_INSERT = ffi::SQLITE_INSERT,
+    /// UPDATE command
+    SQLITE_UPDATE = ffi::SQLITE_UPDATE,
+}
+
+impl From<i32> for Action {
+    #[inline]
+    fn from(code: i32) -> Action {
+        match code {
+            ffi::SQLITE_DELETE => Action::SQLITE_DELETE,
+            ffi::SQLITE_INSERT => Action::SQLITE_INSERT,
+            ffi::SQLITE_UPDATE => Action::SQLITE_UPDATE,
+            _ => Action::UNKNOWN,
+        }
+    }
+}
+
+/// The context received by an authorizer hook.
+///
+/// See <https://sqlite.org/c3ref/set_authorizer.html> for more info.
+#[derive(Clone, Copy, Debug, Eq, PartialEq)]
+pub struct AuthContext<'c> {
+    /// The action to be authorized.
+    pub action: AuthAction<'c>,
+
+    /// The database name, if applicable.
+    pub database_name: Option<&'c str>,
+
+    /// The inner-most trigger or view responsible for the access attempt.
+    /// `None` if the access attempt was made by top-level SQL code.
+    pub accessor: Option<&'c str>,
+}
+
+/// Actions and arguments found within a statement during
+/// preparation.
+///
+/// See <https://sqlite.org/c3ref/c_alter_table.html> for more info.
+#[derive(Clone, Copy, Debug, Eq, PartialEq)]
+#[non_exhaustive]
+#[allow(missing_docs)]
+pub enum AuthAction<'c> {
+    /// This variant is not normally produced by SQLite. You may encounter it
+    /// if you're using a different version than what's supported by this library.
+    Unknown {
+        /// The unknown authorization action code.
+        code: i32,
+        /// The third arg to the authorizer callback.
+        arg1: Option<&'c str>,
+        /// The fourth arg to the authorizer callback.
+        arg2: Option<&'c str>,
+    },
+    CreateIndex {
+        index_name: &'c str,
+        table_name: &'c str,
+    },
+    CreateTable {
+        table_name: &'c str,
+    },
+    CreateTempIndex {
+        index_name: &'c str,
+        table_name: &'c str,
+    },
+    CreateTempTable {
+        table_name: &'c str,
+    },
+    CreateTempTrigger {
+        trigger_name: &'c str,
+        table_name: &'c str,
+    },
+    CreateTempView {
+        view_name: &'c str,
+    },
+    CreateTrigger {
+        trigger_name: &'c str,
+        table_name: &'c str,
+    },
+    CreateView {
+        view_name: &'c str,
+    },
+    Delete {
+        table_name: &'c str,
+    },
+    DropIndex {
+        index_name: &'c str,
+        table_name: &'c str,
+    },
+    DropTable {
+        table_name: &'c str,
+    },
+    DropTempIndex {
+        index_name: &'c str,
+        table_name: &'c str,
+    },
+    DropTempTable {
+        table_name: &'c str,
+    },
+    DropTempTrigger {
+        trigger_name: &'c str,
+        table_name: &'c str,
+    },
+    DropTempView {
+        view_name: &'c str,
+    },
+    DropTrigger {
+        trigger_name: &'c str,
+        table_name: &'c str,
+    },
+    DropView {
+        view_name: &'c str,
+    },
+    Insert {
+        table_name: &'c str,
+    },
+    Pragma {
+        pragma_name: &'c str,
+        /// The pragma value, if present (e.g., `PRAGMA name = value;`).
+        pragma_value: Option<&'c str>,
+    },
+    Read {
+        table_name: &'c str,
+        column_name: &'c str,
+    },
+    Select,
+    Transaction {
+        operation: TransactionOperation,
+    },
+    Update {
+        table_name: &'c str,
+        column_name: &'c str,
+    },
+    Attach {
+        filename: &'c str,
+    },
+    Detach {
+        database_name: &'c str,
+    },
+    AlterTable {
+        database_name: &'c str,
+        table_name: &'c str,
+    },
+    Reindex {
+        index_name: &'c str,
+    },
+    Analyze {
+        table_name: &'c str,
+    },
+    CreateVtable {
+        table_name: &'c str,
+        module_name: &'c str,
+    },
+    DropVtable {
+        table_name: &'c str,
+        module_name: &'c str,
+    },
+    Function {
+        function_name: &'c str,
+    },
+    Savepoint {
+        operation: TransactionOperation,
+        savepoint_name: &'c str,
+    },
+    #[cfg(feature = "modern_sqlite")]
+    Recursive,
+}
+
+impl<'c> AuthAction<'c> {
+    fn from_raw(code: i32, arg1: Option<&'c str>, arg2: Option<&'c str>) -> Self {
+        match (code, arg1, arg2) {
+            (ffi::SQLITE_CREATE_INDEX, Some(index_name), Some(table_name)) => Self::CreateIndex {
+                index_name,
+                table_name,
+            },
+            (ffi::SQLITE_CREATE_TABLE, Some(table_name), _) => Self::CreateTable { table_name },
+            (ffi::SQLITE_CREATE_TEMP_INDEX, Some(index_name), Some(table_name)) => {
+                Self::CreateTempIndex {
+                    index_name,
+                    table_name,
+                }
+            }
+            (ffi::SQLITE_CREATE_TEMP_TABLE, Some(table_name), _) => {
+                Self::CreateTempTable { table_name }
+            }
+            (ffi::SQLITE_CREATE_TEMP_TRIGGER, Some(trigger_name), Some(table_name)) => {
+                Self::CreateTempTrigger {
+                    trigger_name,
+                    table_name,
+                }
+            }
+            (ffi::SQLITE_CREATE_TEMP_VIEW, Some(view_name), _) => {
+                Self::CreateTempView { view_name }
+            }
+            (ffi::SQLITE_CREATE_TRIGGER, Some(trigger_name), Some(table_name)) => {
+                Self::CreateTrigger {
+                    trigger_name,
+                    table_name,
+                }
+            }
+            (ffi::SQLITE_CREATE_VIEW, Some(view_name), _) => Self::CreateView { view_name },
+            (ffi::SQLITE_DELETE, Some(table_name), None) => Self::Delete { table_name },
+            (ffi::SQLITE_DROP_INDEX, Some(index_name), Some(table_name)) => Self::DropIndex {
+                index_name,
+                table_name,
+            },
+            (ffi::SQLITE_DROP_TABLE, Some(table_name), _) => Self::DropTable { table_name },
+            (ffi::SQLITE_DROP_TEMP_INDEX, Some(index_name), Some(table_name)) => {
+                Self::DropTempIndex {
+                    index_name,
+                    table_name,
+                }
+            }
+            (ffi::SQLITE_DROP_TEMP_TABLE, Some(table_name), _) => {
+                Self::DropTempTable { table_name }
+            }
+            (ffi::SQLITE_DROP_TEMP_TRIGGER, Some(trigger_name), Some(table_name)) => {
+                Self::DropTempTrigger {
+                    trigger_name,
+                    table_name,
+                }
+            }
+            (ffi::SQLITE_DROP_TEMP_VIEW, Some(view_name), _) => Self::DropTempView { view_name },
+            (ffi::SQLITE_DROP_TRIGGER, Some(trigger_name), Some(table_name)) => Self::DropTrigger {
+                trigger_name,
+                table_name,
+            },
+            (ffi::SQLITE_DROP_VIEW, Some(view_name), _) => Self::DropView { view_name },
+            (ffi::SQLITE_INSERT, Some(table_name), _) => Self::Insert { table_name },
+            (ffi::SQLITE_PRAGMA, Some(pragma_name), pragma_value) => Self::Pragma {
+                pragma_name,
+                pragma_value,
+            },
+            (ffi::SQLITE_READ, Some(table_name), Some(column_name)) => Self::Read {
+                table_name,
+                column_name,
+            },
+            (ffi::SQLITE_SELECT, ..) => Self::Select,
+            (ffi::SQLITE_TRANSACTION, Some(operation_str), _) => Self::Transaction {
+                operation: TransactionOperation::from_str(operation_str),
+            },
+            (ffi::SQLITE_UPDATE, Some(table_name), Some(column_name)) => Self::Update {
+                table_name,
+                column_name,
+            },
+            (ffi::SQLITE_ATTACH, Some(filename), _) => Self::Attach { filename },
+            (ffi::SQLITE_DETACH, Some(database_name), _) => Self::Detach { database_name },
+            (ffi::SQLITE_ALTER_TABLE, Some(database_name), Some(table_name)) => Self::AlterTable {
+                database_name,
+                table_name,
+            },
+            (ffi::SQLITE_REINDEX, Some(index_name), _) => Self::Reindex { index_name },
+            (ffi::SQLITE_ANALYZE, Some(table_name), _) => Self::Analyze { table_name },
+            (ffi::SQLITE_CREATE_VTABLE, Some(table_name), Some(module_name)) => {
+                Self::CreateVtable {
+                    table_name,
+                    module_name,
+                }
+            }
+            (ffi::SQLITE_DROP_VTABLE, Some(table_name), Some(module_name)) => Self::DropVtable {
+                table_name,
+                module_name,
+            },
+            (ffi::SQLITE_FUNCTION, _, Some(function_name)) => Self::Function { function_name },
+            (ffi::SQLITE_SAVEPOINT, Some(operation_str), Some(savepoint_name)) => Self::Savepoint {
+                operation: TransactionOperation::from_str(operation_str),
+                savepoint_name,
+            },
+            #[cfg(feature = "modern_sqlite")] // 3.8.3
+            (ffi::SQLITE_RECURSIVE, ..) => Self::Recursive,
+            (code, arg1, arg2) => Self::Unknown { code, arg1, arg2 },
+        }
+    }
+}
+
+pub(crate) type BoxedAuthorizer =
+    Box<dyn for<'c> FnMut(AuthContext<'c>) -> Authorization + Send + 'static>;
+
+/// A transaction operation.
+#[derive(Clone, Copy, Debug, Eq, PartialEq)] +#[non_exhaustive] +#[allow(missing_docs)] +pub enum TransactionOperation { + Unknown, + Begin, + Release, + Rollback, +} + +impl TransactionOperation { + fn from_str(op_str: &str) -> Self { + match op_str { + "BEGIN" => Self::Begin, + "RELEASE" => Self::Release, + "ROLLBACK" => Self::Rollback, + _ => Self::Unknown, + } + } +} + +/// [`authorizer`](Connection::authorizer) return code +#[derive(Clone, Copy, Debug, Eq, PartialEq)] +#[non_exhaustive] +pub enum Authorization { + /// Authorize the action. + Allow, + /// Don't allow access, but don't trigger an error either. + Ignore, + /// Trigger an error. + Deny, +} + +impl Authorization { + fn into_raw(self) -> c_int { + match self { + Self::Allow => ffi::SQLITE_OK, + Self::Ignore => ffi::SQLITE_IGNORE, + Self::Deny => ffi::SQLITE_DENY, + } + } +} + +impl Connection { + /// Register a callback function to be invoked whenever + /// a transaction is committed. + /// + /// The callback returns `true` to rollback. + #[inline] + pub fn commit_hook(&self, hook: Option) + where + F: FnMut() -> bool + Send + 'static, + { + self.db.borrow_mut().commit_hook(hook); + } + + /// Register a callback function to be invoked whenever + /// a transaction is committed. + #[inline] + pub fn rollback_hook(&self, hook: Option) + where + F: FnMut() + Send + 'static, + { + self.db.borrow_mut().rollback_hook(hook); + } + + /// Register a callback function to be invoked whenever + /// a row is updated, inserted or deleted in a rowid table. + /// + /// The callback parameters are: + /// + /// - the type of database update (`SQLITE_INSERT`, `SQLITE_UPDATE` or + /// `SQLITE_DELETE`), + /// - the name of the database ("main", "temp", ...), + /// - the name of the table that is updated, + /// - the ROWID of the row that is updated. 
+ #[inline] + pub fn update_hook(&self, hook: Option) + where + F: FnMut(Action, &str, &str, i64) + Send + 'static, + { + self.db.borrow_mut().update_hook(hook); + } + + /// Register a query progress callback. + /// + /// The parameter `num_ops` is the approximate number of virtual machine + /// instructions that are evaluated between successive invocations of the + /// `handler`. If `num_ops` is less than one then the progress handler + /// is disabled. + /// + /// If the progress callback returns `true`, the operation is interrupted. + pub fn progress_handler(&self, num_ops: c_int, handler: Option) + where + F: FnMut() -> bool + Send + RefUnwindSafe + 'static, + { + self.db.borrow_mut().progress_handler(num_ops, handler); + } + + /// Register an authorizer callback that's invoked + /// as a statement is being prepared. + #[inline] + pub fn authorizer<'c, F>(&self, hook: Option) + where + F: for<'r> FnMut(AuthContext<'r>) -> Authorization + Send + RefUnwindSafe + 'static, + { + self.db.borrow_mut().authorizer(hook); + } +} + +impl InnerConnection { + #[inline] + pub fn remove_hooks(&mut self) { + self.update_hook(None::); + self.commit_hook(None:: bool>); + self.rollback_hook(None::); + self.progress_handler(0, None:: bool>); + self.authorizer(None::) -> Authorization>); + } + + fn commit_hook(&mut self, hook: Option) + where + F: FnMut() -> bool + Send + 'static, + { + unsafe extern "C" fn call_boxed_closure(p_arg: *mut c_void) -> c_int + where + F: FnMut() -> bool, + { + let r = catch_unwind(|| { + let boxed_hook: *mut F = p_arg.cast::(); + (*boxed_hook)() + }); + if let Ok(true) = r { + 1 + } else { + 0 + } + } + + // unlike `sqlite3_create_function_v2`, we cannot specify a `xDestroy` with + // `sqlite3_commit_hook`. so we keep the `xDestroy` function in + // `InnerConnection.free_boxed_hook`. 
+ let free_commit_hook = if hook.is_some() { + Some(free_boxed_hook:: as unsafe fn(*mut c_void)) + } else { + None + }; + + let previous_hook = match hook { + Some(hook) => { + let boxed_hook: *mut F = Box::into_raw(Box::new(hook)); + unsafe { + ffi::sqlite3_commit_hook( + self.db(), + Some(call_boxed_closure::), + boxed_hook.cast(), + ) + } + } + _ => unsafe { ffi::sqlite3_commit_hook(self.db(), None, ptr::null_mut()) }, + }; + if !previous_hook.is_null() { + if let Some(free_boxed_hook) = self.free_commit_hook { + unsafe { free_boxed_hook(previous_hook) }; + } + } + self.free_commit_hook = free_commit_hook; + } + + fn rollback_hook(&mut self, hook: Option) + where + F: FnMut() + Send + 'static, + { + unsafe extern "C" fn call_boxed_closure(p_arg: *mut c_void) + where + F: FnMut(), + { + drop(catch_unwind(|| { + let boxed_hook: *mut F = p_arg.cast::(); + (*boxed_hook)(); + })); + } + + let free_rollback_hook = if hook.is_some() { + Some(free_boxed_hook:: as unsafe fn(*mut c_void)) + } else { + None + }; + + let previous_hook = match hook { + Some(hook) => { + let boxed_hook: *mut F = Box::into_raw(Box::new(hook)); + unsafe { + ffi::sqlite3_rollback_hook( + self.db(), + Some(call_boxed_closure::), + boxed_hook.cast(), + ) + } + } + _ => unsafe { ffi::sqlite3_rollback_hook(self.db(), None, ptr::null_mut()) }, + }; + if !previous_hook.is_null() { + if let Some(free_boxed_hook) = self.free_rollback_hook { + unsafe { free_boxed_hook(previous_hook) }; + } + } + self.free_rollback_hook = free_rollback_hook; + } + + fn update_hook(&mut self, hook: Option) + where + F: FnMut(Action, &str, &str, i64) + Send + 'static, + { + unsafe extern "C" fn call_boxed_closure( + p_arg: *mut c_void, + action_code: c_int, + p_db_name: *const c_char, + p_table_name: *const c_char, + row_id: i64, + ) where + F: FnMut(Action, &str, &str, i64), + { + let action = Action::from(action_code); + drop(catch_unwind(|| { + let boxed_hook: *mut F = p_arg.cast::(); + (*boxed_hook)( + action, + 
expect_utf8(p_db_name, "database name"), + expect_utf8(p_table_name, "table name"), + row_id, + ); + })); + } + + let free_update_hook = if hook.is_some() { + Some(free_boxed_hook:: as unsafe fn(*mut c_void)) + } else { + None + }; + + let previous_hook = match hook { + Some(hook) => { + let boxed_hook: *mut F = Box::into_raw(Box::new(hook)); + unsafe { + ffi::sqlite3_update_hook( + self.db(), + Some(call_boxed_closure::), + boxed_hook.cast(), + ) + } + } + _ => unsafe { ffi::sqlite3_update_hook(self.db(), None, ptr::null_mut()) }, + }; + if !previous_hook.is_null() { + if let Some(free_boxed_hook) = self.free_update_hook { + unsafe { free_boxed_hook(previous_hook) }; + } + } + self.free_update_hook = free_update_hook; + } + + fn progress_handler(&mut self, num_ops: c_int, handler: Option) + where + F: FnMut() -> bool + Send + RefUnwindSafe + 'static, + { + unsafe extern "C" fn call_boxed_closure(p_arg: *mut c_void) -> c_int + where + F: FnMut() -> bool, + { + let r = catch_unwind(|| { + let boxed_handler: *mut F = p_arg.cast::(); + (*boxed_handler)() + }); + if let Ok(true) = r { + 1 + } else { + 0 + } + } + + if let Some(handler) = handler { + let boxed_handler = Box::new(handler); + unsafe { + ffi::sqlite3_progress_handler( + self.db(), + num_ops, + Some(call_boxed_closure::), + &*boxed_handler as *const F as *mut _, + ); + } + self.progress_handler = Some(boxed_handler); + } else { + unsafe { ffi::sqlite3_progress_handler(self.db(), num_ops, None, ptr::null_mut()) } + self.progress_handler = None; + }; + } + + fn authorizer<'c, F>(&'c mut self, authorizer: Option) + where + F: for<'r> FnMut(AuthContext<'r>) -> Authorization + Send + RefUnwindSafe + 'static, + { + unsafe extern "C" fn call_boxed_closure<'c, F>( + p_arg: *mut c_void, + action_code: c_int, + param1: *const c_char, + param2: *const c_char, + db_name: *const c_char, + trigger_or_view_name: *const c_char, + ) -> c_int + where + F: FnMut(AuthContext<'c>) -> Authorization + Send + 'static, + { + 
catch_unwind(|| { + let action = AuthAction::from_raw( + action_code, + expect_optional_utf8(param1, "authorizer param 1"), + expect_optional_utf8(param2, "authorizer param 2"), + ); + let auth_ctx = AuthContext { + action, + database_name: expect_optional_utf8(db_name, "database name"), + accessor: expect_optional_utf8( + trigger_or_view_name, + "accessor (inner-most trigger or view)", + ), + }; + let boxed_hook: *mut F = p_arg.cast::(); + (*boxed_hook)(auth_ctx) + }) + .map_or_else(|_| ffi::SQLITE_ERROR, Authorization::into_raw) + } + + let callback_fn = authorizer + .as_ref() + .map(|_| call_boxed_closure::<'c, F> as unsafe extern "C" fn(_, _, _, _, _, _) -> _); + let boxed_authorizer = authorizer.map(Box::new); + + match unsafe { + ffi::sqlite3_set_authorizer( + self.db(), + callback_fn, + boxed_authorizer + .as_ref() + .map_or_else(ptr::null_mut, |f| &**f as *const F as *mut _), + ) + } { + ffi::SQLITE_OK => { + self.authorizer = boxed_authorizer.map(|ba| ba as _); + } + err_code => { + // The only error that `sqlite3_set_authorizer` returns is `SQLITE_MISUSE` + // when compiled with `ENABLE_API_ARMOR` and the db pointer is invalid. + // This library does not allow constructing a null db ptr, so if this branch + // is hit, something very bad has happened. Panicking instead of returning + // `Result` keeps this hook's API consistent with the others. 
+ panic!("unexpectedly failed to set_authorizer: {}", unsafe { + crate::error::error_from_handle(self.db(), err_code) + }); + } + } + } +} + +unsafe fn free_boxed_hook(p: *mut c_void) { + drop(Box::from_raw(p.cast::())); +} + +unsafe fn expect_utf8<'a>(p_str: *const c_char, description: &'static str) -> &'a str { + expect_optional_utf8(p_str, description) + .unwrap_or_else(|| panic!("received empty {}", description)) +} + +unsafe fn expect_optional_utf8<'a>( + p_str: *const c_char, + description: &'static str, +) -> Option<&'a str> { + if p_str.is_null() { + return None; + } + std::str::from_utf8(std::ffi::CStr::from_ptr(p_str).to_bytes()) + .unwrap_or_else(|_| panic!("received non-utf8 string as {}", description)) + .into() +} + +#[cfg(test)] +mod test { + use super::Action; + use crate::{Connection, Result}; + use std::sync::atomic::{AtomicBool, Ordering}; + + #[test] + fn test_commit_hook() -> Result<()> { + let db = Connection::open_in_memory()?; + + static CALLED: AtomicBool = AtomicBool::new(false); + db.commit_hook(Some(|| { + CALLED.store(true, Ordering::Relaxed); + false + })); + db.execute_batch("BEGIN; CREATE TABLE foo (t TEXT); COMMIT;")?; + assert!(CALLED.load(Ordering::Relaxed)); + Ok(()) + } + + #[test] + fn test_fn_commit_hook() -> Result<()> { + let db = Connection::open_in_memory()?; + + fn hook() -> bool { + true + } + + db.commit_hook(Some(hook)); + db.execute_batch("BEGIN; CREATE TABLE foo (t TEXT); COMMIT;") + .unwrap_err(); + Ok(()) + } + + #[test] + fn test_rollback_hook() -> Result<()> { + let db = Connection::open_in_memory()?; + + static CALLED: AtomicBool = AtomicBool::new(false); + db.rollback_hook(Some(|| { + CALLED.store(true, Ordering::Relaxed); + })); + db.execute_batch("BEGIN; CREATE TABLE foo (t TEXT); ROLLBACK;")?; + assert!(CALLED.load(Ordering::Relaxed)); + Ok(()) + } + + #[test] + fn test_update_hook() -> Result<()> { + let db = Connection::open_in_memory()?; + + static CALLED: AtomicBool = AtomicBool::new(false); + 
db.update_hook(Some(|action, db: &str, tbl: &str, row_id| { + assert_eq!(Action::SQLITE_INSERT, action); + assert_eq!("main", db); + assert_eq!("foo", tbl); + assert_eq!(1, row_id); + CALLED.store(true, Ordering::Relaxed); + })); + db.execute_batch("CREATE TABLE foo (t TEXT)")?; + db.execute_batch("INSERT INTO foo VALUES ('lisa')")?; + assert!(CALLED.load(Ordering::Relaxed)); + Ok(()) + } + + #[test] + fn test_progress_handler() -> Result<()> { + let db = Connection::open_in_memory()?; + + static CALLED: AtomicBool = AtomicBool::new(false); + db.progress_handler( + 1, + Some(|| { + CALLED.store(true, Ordering::Relaxed); + false + }), + ); + db.execute_batch("BEGIN; CREATE TABLE foo (t TEXT); COMMIT;")?; + assert!(CALLED.load(Ordering::Relaxed)); + Ok(()) + } + + #[test] + fn test_progress_handler_interrupt() -> Result<()> { + let db = Connection::open_in_memory()?; + + fn handler() -> bool { + true + } + + db.progress_handler(1, Some(handler)); + db.execute_batch("BEGIN; CREATE TABLE foo (t TEXT); COMMIT;") + .unwrap_err(); + Ok(()) + } + + #[test] + fn test_authorizer() -> Result<()> { + use super::{AuthAction, AuthContext, Authorization}; + + let db = Connection::open_in_memory()?; + db.execute_batch("CREATE TABLE foo (public TEXT, private TEXT)") + .unwrap(); + + let authorizer = move |ctx: AuthContext<'_>| match ctx.action { + AuthAction::Read { column_name, .. } if column_name == "private" => { + Authorization::Ignore + } + AuthAction::DropTable { .. } => Authorization::Deny, + AuthAction::Pragma { .. 
} => panic!("shouldn't be called"), + _ => Authorization::Allow, + }; + + db.authorizer(Some(authorizer)); + db.execute_batch( + "BEGIN TRANSACTION; INSERT INTO foo VALUES ('pub txt', 'priv txt'); COMMIT;", + ) + .unwrap(); + db.query_row_and_then("SELECT * FROM foo", [], |row| -> Result<()> { + assert_eq!(row.get::<_, String>("public")?, "pub txt"); + assert!(row.get::<_, Option>("private")?.is_none()); + Ok(()) + }) + .unwrap(); + db.execute_batch("DROP TABLE foo").unwrap_err(); + + db.authorizer(None::) -> Authorization>); + db.execute_batch("PRAGMA user_version=1").unwrap(); // Disallowed by first authorizer, but it's now removed. + + Ok(()) + } +} diff --git a/src/inner_connection.rs b/src/inner_connection.rs new file mode 100644 index 0000000..e5bc3f1 --- /dev/null +++ b/src/inner_connection.rs @@ -0,0 +1,456 @@ +use std::ffi::CStr; +use std::os::raw::{c_char, c_int}; +#[cfg(feature = "load_extension")] +use std::path::Path; +use std::ptr; +use std::str; +use std::sync::atomic::{AtomicBool, Ordering}; +use std::sync::{Arc, Mutex}; + +use super::ffi; +use super::str_for_sqlite; +use super::{Connection, InterruptHandle, OpenFlags, Result}; +use crate::error::{error_from_handle, error_from_sqlite_code, error_with_offset, Error}; +use crate::raw_statement::RawStatement; +use crate::statement::Statement; +use crate::version::version_number; + +pub struct InnerConnection { + pub db: *mut ffi::sqlite3, + // It's unsafe to call `sqlite3_close` while another thread is performing + // a `sqlite3_interrupt`, and vice versa, so we take this mutex during + // those functions. This protects a copy of the `db` pointer (which is + // cleared on closing), however the main copy, `db`, is unprotected. + // Otherwise, a long running query would prevent calling interrupt, as + // interrupt would only acquire the lock after the query's completion. 
+ interrupt_lock: Arc>, + #[cfg(feature = "hooks")] + pub free_commit_hook: Option, + #[cfg(feature = "hooks")] + pub free_rollback_hook: Option, + #[cfg(feature = "hooks")] + pub free_update_hook: Option, + #[cfg(feature = "hooks")] + pub progress_handler: Option bool + Send>>, + #[cfg(feature = "hooks")] + pub authorizer: Option, + owned: bool, +} + +unsafe impl Send for InnerConnection {} + +impl InnerConnection { + #[allow(clippy::mutex_atomic)] + #[inline] + pub unsafe fn new(db: *mut ffi::sqlite3, owned: bool) -> InnerConnection { + InnerConnection { + db, + interrupt_lock: Arc::new(Mutex::new(db)), + #[cfg(feature = "hooks")] + free_commit_hook: None, + #[cfg(feature = "hooks")] + free_rollback_hook: None, + #[cfg(feature = "hooks")] + free_update_hook: None, + #[cfg(feature = "hooks")] + progress_handler: None, + #[cfg(feature = "hooks")] + authorizer: None, + owned, + } + } + + pub fn open_with_flags( + c_path: &CStr, + flags: OpenFlags, + vfs: Option<&CStr>, + ) -> Result { + ensure_safe_sqlite_threading_mode()?; + + // Replicate the check for sane open flags from SQLite, because the check in + // SQLite itself wasn't added until version 3.7.3. 
+ debug_assert_eq!(1 << OpenFlags::SQLITE_OPEN_READ_ONLY.bits, 0x02); + debug_assert_eq!(1 << OpenFlags::SQLITE_OPEN_READ_WRITE.bits, 0x04); + debug_assert_eq!( + 1 << (OpenFlags::SQLITE_OPEN_READ_WRITE | OpenFlags::SQLITE_OPEN_CREATE).bits, + 0x40 + ); + if (1 << (flags.bits & 0x7)) & 0x46 == 0 { + return Err(Error::SqliteFailure( + ffi::Error::new(ffi::SQLITE_MISUSE), + None, + )); + } + + let z_vfs = match vfs { + Some(c_vfs) => c_vfs.as_ptr(), + None => ptr::null(), + }; + + unsafe { + let mut db: *mut ffi::sqlite3 = ptr::null_mut(); + let r = ffi::sqlite3_open_v2(c_path.as_ptr(), &mut db, flags.bits(), z_vfs); + if r != ffi::SQLITE_OK { + let e = if db.is_null() { + error_from_sqlite_code(r, Some(c_path.to_string_lossy().to_string())) + } else { + let mut e = error_from_handle(db, r); + if let Error::SqliteFailure( + ffi::Error { + code: ffi::ErrorCode::CannotOpen, + .. + }, + Some(msg), + ) = e + { + e = Error::SqliteFailure( + ffi::Error::new(r), + Some(format!("{}: {}", msg, c_path.to_string_lossy())), + ); + } + ffi::sqlite3_close(db); + e + }; + + return Err(e); + } + + // attempt to turn on extended results code; don't fail if we can't. 
+ ffi::sqlite3_extended_result_codes(db, 1); + + let r = ffi::sqlite3_busy_timeout(db, 5000); + if r != ffi::SQLITE_OK { + let e = error_from_handle(db, r); + ffi::sqlite3_close(db); + return Err(e); + } + + Ok(InnerConnection::new(db, true)) + } + } + + #[inline] + pub fn db(&self) -> *mut ffi::sqlite3 { + self.db + } + + #[inline] + pub fn decode_result(&self, code: c_int) -> Result<()> { + unsafe { InnerConnection::decode_result_raw(self.db(), code) } + } + + #[inline] + unsafe fn decode_result_raw(db: *mut ffi::sqlite3, code: c_int) -> Result<()> { + if code == ffi::SQLITE_OK { + Ok(()) + } else { + Err(error_from_handle(db, code)) + } + } + + #[allow(clippy::mutex_atomic)] + pub fn close(&mut self) -> Result<()> { + if self.db.is_null() { + return Ok(()); + } + self.remove_hooks(); + let mut shared_handle = self.interrupt_lock.lock().unwrap(); + assert!( + !shared_handle.is_null(), + "Bug: Somehow interrupt_lock was cleared before the DB was closed" + ); + if !self.owned { + self.db = ptr::null_mut(); + return Ok(()); + } + unsafe { + let r = ffi::sqlite3_close(self.db); + // Need to use _raw because _guard has a reference out, and + // decode_result takes &mut self. 
+ let r = InnerConnection::decode_result_raw(self.db, r); + if r.is_ok() { + *shared_handle = ptr::null_mut(); + self.db = ptr::null_mut(); + } + r + } + } + + #[inline] + pub fn get_interrupt_handle(&self) -> InterruptHandle { + InterruptHandle { + db_lock: Arc::clone(&self.interrupt_lock), + } + } + + #[inline] + #[cfg(feature = "load_extension")] + pub unsafe fn enable_load_extension(&mut self, onoff: c_int) -> Result<()> { + let r = ffi::sqlite3_enable_load_extension(self.db, onoff); + self.decode_result(r) + } + + #[cfg(feature = "load_extension")] + pub unsafe fn load_extension( + &self, + dylib_path: &Path, + entry_point: Option<&str>, + ) -> Result<()> { + let dylib_str = super::path_to_cstring(dylib_path)?; + let mut errmsg: *mut c_char = ptr::null_mut(); + let r = if let Some(entry_point) = entry_point { + let c_entry = crate::str_to_cstring(entry_point)?; + ffi::sqlite3_load_extension(self.db, dylib_str.as_ptr(), c_entry.as_ptr(), &mut errmsg) + } else { + ffi::sqlite3_load_extension(self.db, dylib_str.as_ptr(), ptr::null(), &mut errmsg) + }; + if r == ffi::SQLITE_OK { + Ok(()) + } else { + let message = super::errmsg_to_string(errmsg); + ffi::sqlite3_free(errmsg.cast::()); + Err(error_from_sqlite_code(r, Some(message))) + } + } + + #[inline] + pub fn last_insert_rowid(&self) -> i64 { + unsafe { ffi::sqlite3_last_insert_rowid(self.db()) } + } + + pub fn prepare<'a>(&mut self, conn: &'a Connection, sql: &str) -> Result> { + let mut c_stmt = ptr::null_mut(); + let (c_sql, len, _) = str_for_sqlite(sql.as_bytes())?; + let mut c_tail = ptr::null(); + // TODO sqlite3_prepare_v3 (https://sqlite.org/c3ref/c_prepare_normalize.html) // 3.20.0, #728 + #[cfg(not(feature = "unlock_notify"))] + let r = unsafe { + ffi::sqlite3_prepare_v2( + self.db(), + c_sql, + len, + &mut c_stmt as *mut *mut ffi::sqlite3_stmt, + &mut c_tail as *mut *const c_char, + ) + }; + #[cfg(feature = "unlock_notify")] + let r = unsafe { + use crate::unlock_notify; + let mut rc; + loop { + rc = 
ffi::sqlite3_prepare_v2( + self.db(), + c_sql, + len, + &mut c_stmt as *mut *mut ffi::sqlite3_stmt, + &mut c_tail as *mut *const c_char, + ); + if !unlock_notify::is_locked(self.db, rc) { + break; + } + rc = unlock_notify::wait_for_unlock_notify(self.db); + if rc != ffi::SQLITE_OK { + break; + } + } + rc + }; + // If there is an error, *ppStmt is set to NULL. + if r != ffi::SQLITE_OK { + return Err(unsafe { error_with_offset(self.db, r, sql) }); + } + // If the input text contains no SQL (if the input is an empty string or a + // comment) then *ppStmt is set to NULL. + let c_stmt: *mut ffi::sqlite3_stmt = c_stmt; + let c_tail: *const c_char = c_tail; + let tail = if c_tail.is_null() { + 0 + } else { + let n = (c_tail as isize) - (c_sql as isize); + if n <= 0 || n >= len as isize { + 0 + } else { + n as usize + } + }; + Ok(Statement::new(conn, unsafe { + RawStatement::new(c_stmt, tail) + })) + } + + #[inline] + pub fn changes(&self) -> u64 { + #[cfg(not(feature = "modern_sqlite"))] + unsafe { + ffi::sqlite3_changes(self.db()) as u64 + } + #[cfg(feature = "modern_sqlite")] // 3.37.0 + unsafe { + ffi::sqlite3_changes64(self.db()) as u64 + } + } + + #[inline] + pub fn is_autocommit(&self) -> bool { + unsafe { ffi::sqlite3_get_autocommit(self.db()) != 0 } + } + + #[cfg(feature = "modern_sqlite")] // 3.8.6 + pub fn is_busy(&self) -> bool { + let db = self.db(); + unsafe { + let mut stmt = ffi::sqlite3_next_stmt(db, ptr::null_mut()); + while !stmt.is_null() { + if ffi::sqlite3_stmt_busy(stmt) != 0 { + return true; + } + stmt = ffi::sqlite3_next_stmt(db, stmt); + } + } + false + } + + #[cfg(feature = "modern_sqlite")] // 3.10.0 + pub fn cache_flush(&mut self) -> Result<()> { + crate::error::check(unsafe { ffi::sqlite3_db_cacheflush(self.db()) }) + } + + #[cfg(not(feature = "hooks"))] + #[inline] + fn remove_hooks(&mut self) {} + + #[cfg(feature = "modern_sqlite")] // 3.7.11 + pub fn db_readonly(&self, db_name: super::DatabaseName<'_>) -> Result { + let name = 
db_name.as_cstring()?; + let r = unsafe { ffi::sqlite3_db_readonly(self.db, name.as_ptr()) }; + match r { + 0 => Ok(false), + 1 => Ok(true), + -1 => Err(Error::SqliteFailure( + ffi::Error::new(ffi::SQLITE_MISUSE), + Some(format!("{:?} is not the name of a database", db_name)), + )), + _ => Err(error_from_sqlite_code( + r, + Some("Unexpected result".to_owned()), + )), + } + } + + #[cfg(feature = "modern_sqlite")] // 3.37.0 + pub fn txn_state( + &self, + db_name: Option>, + ) -> Result { + let r = if let Some(ref name) = db_name { + let name = name.as_cstring()?; + unsafe { ffi::sqlite3_txn_state(self.db, name.as_ptr()) } + } else { + unsafe { ffi::sqlite3_txn_state(self.db, ptr::null()) } + }; + match r { + 0 => Ok(super::transaction::TransactionState::None), + 1 => Ok(super::transaction::TransactionState::Read), + 2 => Ok(super::transaction::TransactionState::Write), + -1 => Err(Error::SqliteFailure( + ffi::Error::new(ffi::SQLITE_MISUSE), + Some(format!("{:?} is not the name of a valid schema", db_name)), + )), + _ => Err(error_from_sqlite_code( + r, + Some("Unexpected result".to_owned()), + )), + } + } + + #[inline] + #[cfg(feature = "release_memory")] + pub fn release_memory(&self) -> Result<()> { + self.decode_result(unsafe { ffi::sqlite3_db_release_memory(self.db) }) + } +} + +impl Drop for InnerConnection { + #[allow(unused_must_use)] + #[inline] + fn drop(&mut self) { + use std::thread::panicking; + + if let Err(e) = self.close() { + if panicking() { + eprintln!("Error while closing SQLite connection: {:?}", e); + } else { + panic!("Error while closing SQLite connection: {:?}", e); + } + } + } +} + +#[cfg(not(any(target_arch = "wasm32")))] +static SQLITE_INIT: std::sync::Once = std::sync::Once::new(); + +pub static BYPASS_SQLITE_INIT: AtomicBool = AtomicBool::new(false); + +// threading mode checks are not necessary (and do not work) on target +// platforms that do not have threading (such as webassembly) +#[cfg(any(target_arch = "wasm32"))] +fn 
ensure_safe_sqlite_threading_mode() -> Result<()> { + Ok(()) +} + +#[cfg(not(any(target_arch = "wasm32")))] +fn ensure_safe_sqlite_threading_mode() -> Result<()> { + // Ensure SQLite was compiled in threadsafe mode. + if unsafe { ffi::sqlite3_threadsafe() == 0 } { + return Err(Error::SqliteSingleThreadedMode); + } + + // Now we know SQLite is _capable_ of being in Multi-thread of Serialized mode, + // but it's possible someone configured it to be in Single-thread mode + // before calling into us. That would mean we're exposing an unsafe API via + // a safe one (in Rust terminology), which is no good. We have two options + // to protect against this, depending on the version of SQLite we're linked + // with: + // + // 1. If we're on 3.7.0 or later, we can ask SQLite for a mutex and check for + // the magic value 8. This isn't documented, but it's what SQLite + // returns for its mutex allocation function in Single-thread mode. + // 2. If we're prior to SQLite 3.7.0, AFAIK there's no way to check the + // threading mode. The check we perform for >= 3.7.0 will segfault. + // Instead, we insist on being able to call sqlite3_config and + // sqlite3_initialize ourself, ensuring we know the threading + // mode. This will fail if someone else has already initialized SQLite + // even if they initialized it safely. That's not ideal either, which is + // why we expose bypass_sqlite_initialization above. 
+ if version_number() >= 3_007_000 { + const SQLITE_SINGLETHREADED_MUTEX_MAGIC: usize = 8; + let is_singlethreaded = unsafe { + let mutex_ptr = ffi::sqlite3_mutex_alloc(0); + let is_singlethreaded = mutex_ptr as usize == SQLITE_SINGLETHREADED_MUTEX_MAGIC; + ffi::sqlite3_mutex_free(mutex_ptr); + is_singlethreaded + }; + if is_singlethreaded { + Err(Error::SqliteSingleThreadedMode) + } else { + Ok(()) + } + } else { + SQLITE_INIT.call_once(|| { + if BYPASS_SQLITE_INIT.load(Ordering::Relaxed) { + return; + } + + unsafe { + assert!(ffi::sqlite3_config(ffi::SQLITE_CONFIG_MULTITHREAD) == ffi::SQLITE_OK && ffi::sqlite3_initialize() == ffi::SQLITE_OK, + "Could not ensure safe initialization of SQLite.\n\ + To fix this, either:\n\ + * Upgrade SQLite to at least version 3.7.0\n\ + * Ensure that SQLite has been initialized in Multi-thread or Serialized mode and call\n\ + rusqlite::bypass_sqlite_initialization() prior to your first connection attempt." + ); + } + }); + Ok(()) + } +} diff --git a/src/lib.rs b/src/lib.rs new file mode 100644 index 0000000..89f133e --- /dev/null +++ b/src/lib.rs @@ -0,0 +1,2127 @@ +//! Rusqlite is an ergonomic wrapper for using SQLite from Rust. +//! +//! Historically, the API was based on the one from +//! [`rust-postgres`](https://github.com/sfackler/rust-postgres). However, the +//! two have diverged in many ways, and no compatibility between the two is +//! intended. +//! +//! ```rust +//! use rusqlite::{params, Connection, Result}; +//! +//! #[derive(Debug)] +//! struct Person { +//! id: i32, +//! name: String, +//! data: Option>, +//! } +//! +//! fn main() -> Result<()> { +//! let conn = Connection::open_in_memory()?; +//! +//! conn.execute( +//! "CREATE TABLE person ( +//! id INTEGER PRIMARY KEY, +//! name TEXT NOT NULL, +//! data BLOB +//! )", +//! (), // empty list of parameters. +//! )?; +//! let me = Person { +//! id: 0, +//! name: "Steven".to_string(), +//! data: None, +//! }; +//! conn.execute( +//! 
"INSERT INTO person (name, data) VALUES (?1, ?2)", +//! (&me.name, &me.data), +//! )?; +//! +//! let mut stmt = conn.prepare("SELECT id, name, data FROM person")?; +//! let person_iter = stmt.query_map([], |row| { +//! Ok(Person { +//! id: row.get(0)?, +//! name: row.get(1)?, +//! data: row.get(2)?, +//! }) +//! })?; +//! +//! for person in person_iter { +//! println!("Found person {:?}", person.unwrap()); +//! } +//! Ok(()) +//! } +//! ``` +#![warn(missing_docs)] +#![cfg_attr(docsrs, feature(doc_cfg))] + +pub use libsqlite3_sys as ffi; + +use std::cell::RefCell; +use std::default::Default; +use std::ffi::{CStr, CString}; +use std::fmt; +use std::os::raw::{c_char, c_int}; + +use std::path::{Path, PathBuf}; +use std::result; +use std::str; +use std::sync::atomic::Ordering; +use std::sync::{Arc, Mutex}; + +use crate::cache::StatementCache; +use crate::inner_connection::{InnerConnection, BYPASS_SQLITE_INIT}; +use crate::raw_statement::RawStatement; +use crate::types::ValueRef; + +pub use crate::cache::CachedStatement; +pub use crate::column::Column; +pub use crate::error::Error; +pub use crate::ffi::ErrorCode; +#[cfg(feature = "load_extension")] +pub use crate::load_extension_guard::LoadExtensionGuard; +pub use crate::params::{params_from_iter, Params, ParamsFromIter}; +pub use crate::row::{AndThenRows, Map, MappedRows, Row, RowIndex, Rows}; +pub use crate::statement::{Statement, StatementStatus}; +pub use crate::transaction::{DropBehavior, Savepoint, Transaction, TransactionBehavior}; +pub use crate::types::ToSql; +pub use crate::version::*; + +mod error; + +#[cfg(feature = "backup")] +#[cfg_attr(docsrs, doc(cfg(feature = "backup")))] +pub mod backup; +#[cfg(feature = "blob")] +#[cfg_attr(docsrs, doc(cfg(feature = "blob")))] +pub mod blob; +mod busy; +mod cache; +#[cfg(feature = "collation")] +#[cfg_attr(docsrs, doc(cfg(feature = "collation")))] +mod collation; +mod column; +pub mod config; +#[cfg(any(feature = "functions", feature = "vtab"))] +mod context; 
+#[cfg(feature = "functions")] +#[cfg_attr(docsrs, doc(cfg(feature = "functions")))] +pub mod functions; +#[cfg(feature = "hooks")] +#[cfg_attr(docsrs, doc(cfg(feature = "hooks")))] +pub mod hooks; +mod inner_connection; +#[cfg(feature = "limits")] +#[cfg_attr(docsrs, doc(cfg(feature = "limits")))] +pub mod limits; +#[cfg(feature = "load_extension")] +mod load_extension_guard; +mod params; +mod pragma; +mod raw_statement; +mod row; +#[cfg(feature = "session")] +#[cfg_attr(docsrs, doc(cfg(feature = "session")))] +pub mod session; +mod statement; +#[cfg(feature = "trace")] +#[cfg_attr(docsrs, doc(cfg(feature = "trace")))] +pub mod trace; +mod transaction; +pub mod types; +#[cfg(feature = "unlock_notify")] +mod unlock_notify; +mod version; +#[cfg(feature = "vtab")] +#[cfg_attr(docsrs, doc(cfg(feature = "vtab")))] +pub mod vtab; + +pub(crate) mod util; +pub(crate) use util::SmallCString; + +// Number of cached prepared statements we'll hold on to. +const STATEMENT_CACHE_DEFAULT_CAPACITY: usize = 16; +/// To be used when your statement has no [parameter][sqlite-varparam]. +/// +/// [sqlite-varparam]: https://sqlite.org/lang_expr.html#varparam +/// +/// This is deprecated in favor of using an empty array literal. +#[deprecated = "Use an empty array instead; `stmt.execute(NO_PARAMS)` => `stmt.execute([])`"] +pub const NO_PARAMS: &[&dyn ToSql] = &[]; + +/// A macro making it more convenient to longer lists of +/// parameters as a `&[&dyn ToSql]`. +/// +/// # Example +/// +/// ```rust,no_run +/// # use rusqlite::{Result, Connection, params}; +/// +/// struct Person { +/// name: String, +/// age_in_years: u8, +/// data: Option>, +/// } +/// +/// fn add_person(conn: &Connection, person: &Person) -> Result<()> { +/// conn.execute( +/// "INSERT INTO person(name, age_in_years, data) VALUES (?1, ?2, ?3)", +/// params![person.name, person.age_in_years, person.data], +/// )?; +/// Ok(()) +/// } +/// ``` +#[macro_export] +macro_rules! 
params { + () => { + &[] as &[&dyn $crate::ToSql] + }; + ($($param:expr),+ $(,)?) => { + &[$(&$param as &dyn $crate::ToSql),+] as &[&dyn $crate::ToSql] + }; +} + +/// A macro making it more convenient to pass lists of named parameters +/// as a `&[(&str, &dyn ToSql)]`. +/// +/// # Example +/// +/// ```rust,no_run +/// # use rusqlite::{Result, Connection, named_params}; +/// +/// struct Person { +/// name: String, +/// age_in_years: u8, +/// data: Option>, +/// } +/// +/// fn add_person(conn: &Connection, person: &Person) -> Result<()> { +/// conn.execute( +/// "INSERT INTO person (name, age_in_years, data) +/// VALUES (:name, :age, :data)", +/// named_params! { +/// ":name": person.name, +/// ":age": person.age_in_years, +/// ":data": person.data, +/// }, +/// )?; +/// Ok(()) +/// } +/// ``` +#[macro_export] +macro_rules! named_params { + () => { + &[] as &[(&str, &dyn $crate::ToSql)] + }; + // Note: It's a lot more work to support this as part of the same macro as + // `params!`, unfortunately. + ($($param_name:literal: $param_val:expr),+ $(,)?) => { + &[$(($param_name, &$param_val as &dyn $crate::ToSql)),+] as &[(&str, &dyn $crate::ToSql)] + }; +} + +/// A typedef of the result returned by many methods. +pub type Result = result::Result; + +/// See the [method documentation](#tymethod.optional). +pub trait OptionalExtension { + /// Converts a `Result` into a `Result>`. + /// + /// By default, Rusqlite treats 0 rows being returned from a query that is + /// expected to return 1 row as an error. This method will + /// handle that error, and give you back an `Option` instead. 
+ fn optional(self) -> Result>; +} + +impl OptionalExtension for Result { + fn optional(self) -> Result> { + match self { + Ok(value) => Ok(Some(value)), + Err(Error::QueryReturnedNoRows) => Ok(None), + Err(e) => Err(e), + } + } +} + +unsafe fn errmsg_to_string(errmsg: *const c_char) -> String { + let c_slice = CStr::from_ptr(errmsg).to_bytes(); + String::from_utf8_lossy(c_slice).into_owned() +} + +fn str_to_cstring(s: &str) -> Result { + Ok(SmallCString::new(s)?) +} + +/// Returns `Ok((string ptr, len as c_int, SQLITE_STATIC | SQLITE_TRANSIENT))` +/// normally. +/// Returns error if the string is too large for sqlite. +/// The `sqlite3_destructor_type` item is always `SQLITE_TRANSIENT` unless +/// the string was empty (in which case it's `SQLITE_STATIC`, and the ptr is +/// static). +fn str_for_sqlite(s: &[u8]) -> Result<(*const c_char, c_int, ffi::sqlite3_destructor_type)> { + let len = len_as_c_int(s.len())?; + let (ptr, dtor_info) = if len != 0 { + (s.as_ptr().cast::(), ffi::SQLITE_TRANSIENT()) + } else { + // Return a pointer guaranteed to live forever + ("".as_ptr().cast::(), ffi::SQLITE_STATIC()) + }; + Ok((ptr, len, dtor_info)) +} + +// Helper to cast to c_int safely, returning the correct error type if the cast +// failed. +fn len_as_c_int(len: usize) -> Result { + if len >= (c_int::MAX as usize) { + Err(Error::SqliteFailure( + ffi::Error::new(ffi::SQLITE_TOOBIG), + None, + )) + } else { + Ok(len as c_int) + } +} + +#[cfg(unix)] +fn path_to_cstring(p: &Path) -> Result { + use std::os::unix::ffi::OsStrExt; + Ok(CString::new(p.as_os_str().as_bytes())?) +} + +#[cfg(not(unix))] +fn path_to_cstring(p: &Path) -> Result { + let s = p.to_str().ok_or_else(|| Error::InvalidPath(p.to_owned()))?; + Ok(CString::new(s)?) +} + +/// Name for a database within a SQLite connection. +#[derive(Copy, Clone, Debug)] +pub enum DatabaseName<'a> { + /// The main database. + Main, + + /// The temporary database (e.g., any "CREATE TEMPORARY TABLE" tables). 
+ Temp, + + /// A database that has been attached via "ATTACH DATABASE ...". + Attached(&'a str), +} + +/// Shorthand for [`DatabaseName::Main`]. +pub const MAIN_DB: DatabaseName<'static> = DatabaseName::Main; + +/// Shorthand for [`DatabaseName::Temp`]. +pub const TEMP_DB: DatabaseName<'static> = DatabaseName::Temp; + +// Currently DatabaseName is only used by the backup and blob mods, so hide +// this (private) impl to avoid dead code warnings. +#[cfg(any( + feature = "backup", + feature = "blob", + feature = "session", + feature = "modern_sqlite" +))] +impl DatabaseName<'_> { + #[inline] + fn as_cstring(&self) -> Result { + use self::DatabaseName::{Attached, Main, Temp}; + match *self { + Main => str_to_cstring("main"), + Temp => str_to_cstring("temp"), + Attached(s) => str_to_cstring(s), + } + } +} + +/// A connection to a SQLite database. +pub struct Connection { + db: RefCell, + cache: StatementCache, + path: Option, +} + +unsafe impl Send for Connection {} + +impl Drop for Connection { + #[inline] + fn drop(&mut self) { + self.flush_prepared_statement_cache(); + } +} + +impl Connection { + /// Open a new connection to a SQLite database. If a database does not exist + /// at the path, one is created. + /// + /// ```rust,no_run + /// # use rusqlite::{Connection, Result}; + /// fn open_my_db() -> Result<()> { + /// let path = "./my_db.db3"; + /// let db = Connection::open(path)?; + /// // Use the database somehow... + /// println!("{}", db.is_autocommit()); + /// Ok(()) + /// } + /// ``` + /// + /// # Flags + /// + /// `Connection::open(path)` is equivalent to using + /// [`Connection::open_with_flags`] with the default [`OpenFlags`]. 
That is, + /// it's equivalent to: + /// + /// ```ignore + /// Connection::open_with_flags( + /// path, + /// OpenFlags::SQLITE_OPEN_READ_WRITE + /// | OpenFlags::SQLITE_OPEN_CREATE + /// | OpenFlags::SQLITE_OPEN_URI + /// | OpenFlags::SQLITE_OPEN_NO_MUTEX, + /// ) + /// ``` + /// + /// These flags have the following effects: + /// + /// - Open the database for both reading or writing. + /// - Create the database if one does not exist at the path. + /// - Allow the filename to be interpreted as a URI (see + /// for details). + /// - Disables the use of a per-connection mutex. + /// + /// Rusqlite enforces thread-safety at compile time, so additional + /// locking is not needed and provides no benefit. (See the + /// documentation on [`OpenFlags::SQLITE_OPEN_FULL_MUTEX`] for some + /// additional discussion about this). + /// + /// Most of these are also the default settings for the C API, although + /// technically the default locking behavior is controlled by the flags used + /// when compiling SQLite -- rather than let it vary, we choose `NO_MUTEX` + /// because it's a fairly clearly the best choice for users of this library. + /// + /// # Failure + /// + /// Will return `Err` if `path` cannot be converted to a C-compatible string + /// or if the underlying SQLite open call fails. + #[inline] + pub fn open>(path: P) -> Result { + let flags = OpenFlags::default(); + Connection::open_with_flags(path, flags) + } + + /// Open a new connection to an in-memory SQLite database. + /// + /// # Failure + /// + /// Will return `Err` if the underlying SQLite open call fails. + #[inline] + pub fn open_in_memory() -> Result { + let flags = OpenFlags::default(); + Connection::open_in_memory_with_flags(flags) + } + + /// Open a new connection to a SQLite database. + /// + /// [Database Connection](http://www.sqlite.org/c3ref/open.html) for a description of valid + /// flag combinations. 
+ /// + /// # Failure + /// + /// Will return `Err` if `path` cannot be converted to a C-compatible + /// string or if the underlying SQLite open call fails. + #[inline] + pub fn open_with_flags>(path: P, flags: OpenFlags) -> Result { + let c_path = path_to_cstring(path.as_ref())?; + InnerConnection::open_with_flags(&c_path, flags, None).map(|db| Connection { + db: RefCell::new(db), + cache: StatementCache::with_capacity(STATEMENT_CACHE_DEFAULT_CAPACITY), + path: Some(path.as_ref().to_path_buf()), + }) + } + + /// Open a new connection to a SQLite database using the specific flags and + /// vfs name. + /// + /// [Database Connection](http://www.sqlite.org/c3ref/open.html) for a description of valid + /// flag combinations. + /// + /// # Failure + /// + /// Will return `Err` if either `path` or `vfs` cannot be converted to a + /// C-compatible string or if the underlying SQLite open call fails. + #[inline] + pub fn open_with_flags_and_vfs>( + path: P, + flags: OpenFlags, + vfs: &str, + ) -> Result { + let c_path = path_to_cstring(path.as_ref())?; + let c_vfs = str_to_cstring(vfs)?; + InnerConnection::open_with_flags(&c_path, flags, Some(&c_vfs)).map(|db| Connection { + db: RefCell::new(db), + cache: StatementCache::with_capacity(STATEMENT_CACHE_DEFAULT_CAPACITY), + path: Some(path.as_ref().to_path_buf()), + }) + } + + /// Open a new connection to an in-memory SQLite database. + /// + /// [Database Connection](http://www.sqlite.org/c3ref/open.html) for a description of valid + /// flag combinations. + /// + /// # Failure + /// + /// Will return `Err` if the underlying SQLite open call fails. + #[inline] + pub fn open_in_memory_with_flags(flags: OpenFlags) -> Result { + Connection::open_with_flags(":memory:", flags) + } + + /// Open a new connection to an in-memory SQLite database using the specific + /// flags and vfs name. + /// + /// [Database Connection](http://www.sqlite.org/c3ref/open.html) for a description of valid + /// flag combinations. 
+ /// + /// # Failure + /// + /// Will return `Err` if `vfs` cannot be converted to a C-compatible + /// string or if the underlying SQLite open call fails. + #[inline] + pub fn open_in_memory_with_flags_and_vfs(flags: OpenFlags, vfs: &str) -> Result { + Connection::open_with_flags_and_vfs(":memory:", flags, vfs) + } + + /// Convenience method to run multiple SQL statements (that cannot take any + /// parameters). + /// + /// ## Example + /// + /// ```rust,no_run + /// # use rusqlite::{Connection, Result}; + /// fn create_tables(conn: &Connection) -> Result<()> { + /// conn.execute_batch( + /// "BEGIN; + /// CREATE TABLE foo(x INTEGER); + /// CREATE TABLE bar(y TEXT); + /// COMMIT;", + /// ) + /// } + /// ``` + /// + /// # Failure + /// + /// Will return `Err` if `sql` cannot be converted to a C-compatible string + /// or if the underlying SQLite call fails. + pub fn execute_batch(&self, sql: &str) -> Result<()> { + let mut sql = sql; + while !sql.is_empty() { + let stmt = self.prepare(sql)?; + if !stmt.stmt.is_null() && stmt.step()? && cfg!(feature = "extra_check") { + // Some PRAGMA may return rows + return Err(Error::ExecuteReturnedResults); + } + let tail = stmt.stmt.tail(); + if tail == 0 || tail >= sql.len() { + break; + } + sql = &sql[tail..]; + } + Ok(()) + } + + /// Convenience method to prepare and execute a single SQL statement. + /// + /// On success, returns the number of rows that were changed or inserted or + /// deleted (via `sqlite3_changes`). 
+ /// + /// ## Example + /// + /// ### With positional params + /// + /// ```rust,no_run + /// # use rusqlite::{Connection}; + /// fn update_rows(conn: &Connection) { + /// match conn.execute("UPDATE foo SET bar = 'baz' WHERE qux = ?", [1i32]) { + /// Ok(updated) => println!("{} rows were updated", updated), + /// Err(err) => println!("update failed: {}", err), + /// } + /// } + /// ``` + /// + /// ### With positional params of varying types + /// + /// ```rust,no_run + /// # use rusqlite::{params, Connection}; + /// fn update_rows(conn: &Connection) { + /// match conn.execute( + /// "UPDATE foo SET bar = 'baz' WHERE qux = ?1 AND quux = ?2", + /// params![1i32, 1.5f64], + /// ) { + /// Ok(updated) => println!("{} rows were updated", updated), + /// Err(err) => println!("update failed: {}", err), + /// } + /// } + /// ``` + /// + /// ### With named params + /// + /// ```rust,no_run + /// # use rusqlite::{Connection, Result}; + /// fn insert(conn: &Connection) -> Result { + /// conn.execute( + /// "INSERT INTO test (name) VALUES (:name)", + /// &[(":name", "one")], + /// ) + /// } + /// ``` + /// + /// # Failure + /// + /// Will return `Err` if `sql` cannot be converted to a C-compatible string + /// or if the underlying SQLite call fails. + #[inline] + pub fn execute(&self, sql: &str, params: P) -> Result { + self.prepare(sql) + .and_then(|mut stmt| stmt.check_no_tail().and_then(|_| stmt.execute(params))) + } + + /// Returns the path to the database file, if one exists and is known. + /// + /// Note that in some cases [PRAGMA + /// database_list](https://sqlite.org/pragma.html#pragma_database_list) is + /// likely to be more robust. + #[inline] + pub fn path(&self) -> Option<&Path> { + self.path.as_deref() + } + + /// Attempts to free as much heap memory as possible from the database + /// connection. + /// + /// This calls [`sqlite3_db_release_memory`](https://www.sqlite.org/c3ref/db_release_memory.html). 
+ #[inline] + #[cfg(feature = "release_memory")] + pub fn release_memory(&self) -> Result<()> { + self.db.borrow_mut().release_memory() + } + + /// Convenience method to prepare and execute a single SQL statement with + /// named parameter(s). + /// + /// On success, returns the number of rows that were changed or inserted or + /// deleted (via `sqlite3_changes`). + /// + /// # Failure + /// + /// Will return `Err` if `sql` cannot be converted to a C-compatible string + /// or if the underlying SQLite call fails. + #[deprecated = "You can use `execute` with named params now."] + pub fn execute_named(&self, sql: &str, params: &[(&str, &dyn ToSql)]) -> Result { + // This function itself is deprecated, so it's fine + #![allow(deprecated)] + self.prepare(sql).and_then(|mut stmt| { + stmt.check_no_tail() + .and_then(|_| stmt.execute_named(params)) + }) + } + + /// Get the SQLite rowid of the most recent successful INSERT. + /// + /// Uses [sqlite3_last_insert_rowid](https://www.sqlite.org/c3ref/last_insert_rowid.html) under + /// the hood. + #[inline] + pub fn last_insert_rowid(&self) -> i64 { + self.db.borrow_mut().last_insert_rowid() + } + + /// Convenience method to execute a query that is expected to return a + /// single row. + /// + /// ## Example + /// + /// ```rust,no_run + /// # use rusqlite::{Result, Connection}; + /// fn preferred_locale(conn: &Connection) -> Result { + /// conn.query_row( + /// "SELECT value FROM preferences WHERE name='locale'", + /// [], + /// |row| row.get(0), + /// ) + /// } + /// ``` + /// + /// If the query returns more than one row, all rows except the first are + /// ignored. + /// + /// Returns `Err(QueryReturnedNoRows)` if no results are returned. If the + /// query truly is optional, you can call `.optional()` on the result of + /// this to get a `Result>`. + /// + /// # Failure + /// + /// Will return `Err` if `sql` cannot be converted to a C-compatible string + /// or if the underlying SQLite call fails. 
+ #[inline] + pub fn query_row(&self, sql: &str, params: P, f: F) -> Result + where + P: Params, + F: FnOnce(&Row<'_>) -> Result, + { + let mut stmt = self.prepare(sql)?; + stmt.check_no_tail()?; + stmt.query_row(params, f) + } + + /// Convenience method to execute a query with named parameter(s) that is + /// expected to return a single row. + /// + /// If the query returns more than one row, all rows except the first are + /// ignored. + /// + /// Returns `Err(QueryReturnedNoRows)` if no results are returned. If the + /// query truly is optional, you can call `.optional()` on the result of + /// this to get a `Result>`. + /// + /// # Failure + /// + /// Will return `Err` if `sql` cannot be converted to a C-compatible string + /// or if the underlying SQLite call fails. + #[deprecated = "You can use `query_row` with named params now."] + pub fn query_row_named(&self, sql: &str, params: &[(&str, &dyn ToSql)], f: F) -> Result + where + F: FnOnce(&Row<'_>) -> Result, + { + self.query_row(sql, params, f) + } + + /// Convenience method to execute a query that is expected to return a + /// single row, and execute a mapping via `f` on that returned row with + /// the possibility of failure. The `Result` type of `f` must implement + /// `std::convert::From`. + /// + /// ## Example + /// + /// ```rust,no_run + /// # use rusqlite::{Result, Connection}; + /// fn preferred_locale(conn: &Connection) -> Result { + /// conn.query_row_and_then( + /// "SELECT value FROM preferences WHERE name='locale'", + /// [], + /// |row| row.get(0), + /// ) + /// } + /// ``` + /// + /// If the query returns more than one row, all rows except the first are + /// ignored. + /// + /// # Failure + /// + /// Will return `Err` if `sql` cannot be converted to a C-compatible string + /// or if the underlying SQLite call fails. 
+ #[inline] + pub fn query_row_and_then(&self, sql: &str, params: P, f: F) -> Result + where + P: Params, + F: FnOnce(&Row<'_>) -> Result, + E: From, + { + let mut stmt = self.prepare(sql)?; + stmt.check_no_tail()?; + let mut rows = stmt.query(params)?; + + rows.get_expected_row().map_err(E::from).and_then(f) + } + + /// Prepare a SQL statement for execution. + /// + /// ## Example + /// + /// ```rust,no_run + /// # use rusqlite::{Connection, Result}; + /// fn insert_new_people(conn: &Connection) -> Result<()> { + /// let mut stmt = conn.prepare("INSERT INTO People (name) VALUES (?)")?; + /// stmt.execute(["Joe Smith"])?; + /// stmt.execute(["Bob Jones"])?; + /// Ok(()) + /// } + /// ``` + /// + /// # Failure + /// + /// Will return `Err` if `sql` cannot be converted to a C-compatible string + /// or if the underlying SQLite call fails. + #[inline] + pub fn prepare(&self, sql: &str) -> Result> { + self.db.borrow_mut().prepare(self, sql) + } + + /// Close the SQLite connection. + /// + /// This is functionally equivalent to the `Drop` implementation for + /// `Connection` except that on failure, it returns an error and the + /// connection itself (presumably so closing can be attempted again). + /// + /// # Failure + /// + /// Will return `Err` if the underlying SQLite call fails. + #[inline] + pub fn close(self) -> Result<(), (Connection, Error)> { + self.flush_prepared_statement_cache(); + let r = self.db.borrow_mut().close(); + r.map_err(move |err| (self, err)) + } + + /// Enable loading of SQLite extensions from both SQL queries and Rust. + /// + /// You must call [`Connection::load_extension_disable`] when you're + /// finished loading extensions (failure to call it can lead to bad things, + /// see "Safety"), so you should strongly consider using + /// [`LoadExtensionGuard`] instead of this function, automatically disables + /// extension loading when it goes out of scope. 
+ /// + /// # Example + /// + /// ```rust,no_run + /// # use rusqlite::{Connection, Result}; + /// fn load_my_extension(conn: &Connection) -> Result<()> { + /// // Safety: We fully trust the loaded extension and execute no untrusted SQL + /// // while extension loading is enabled. + /// unsafe { + /// conn.load_extension_enable()?; + /// let r = conn.load_extension("my/trusted/extension", None); + /// conn.load_extension_disable()?; + /// r + /// } + /// } + /// ``` + /// + /// # Failure + /// + /// Will return `Err` if the underlying SQLite call fails. + /// + /// # Safety + /// + /// TLDR: Don't execute any untrusted queries between this call and + /// [`Connection::load_extension_disable`]. + /// + /// Perhaps surprisingly, this function does not only allow the use of + /// [`Connection::load_extension`] from Rust, but it also allows SQL queries + /// to perform [the same operation][loadext]. For example, in the period + /// between `load_extension_enable` and `load_extension_disable`, the + /// following operation will load and call some function in some dynamic + /// library: + /// + /// ```sql + /// SELECT load_extension('why_is_this_possible.dll', 'dubious_func'); + /// ``` + /// + /// This means that while this is enabled a carefully crafted SQL query can + /// be used to escalate a SQL injection attack into code execution. + /// + /// Safely using this function requires that you trust all SQL queries run + /// between when it is called, and when loading is disabled (by + /// [`Connection::load_extension_disable`]). + /// + /// [loadext]: https://www.sqlite.org/lang_corefunc.html#load_extension + #[cfg(feature = "load_extension")] + #[cfg_attr(docsrs, doc(cfg(feature = "load_extension")))] + #[inline] + pub unsafe fn load_extension_enable(&self) -> Result<()> { + self.db.borrow_mut().enable_load_extension(1) + } + + /// Disable loading of SQLite extensions. + /// + /// See [`Connection::load_extension_enable`] for an example. 
+ /// + /// # Failure + /// + /// Will return `Err` if the underlying SQLite call fails. + #[cfg(feature = "load_extension")] + #[cfg_attr(docsrs, doc(cfg(feature = "load_extension")))] + #[inline] + pub fn load_extension_disable(&self) -> Result<()> { + // It's always safe to turn off extension loading. + unsafe { self.db.borrow_mut().enable_load_extension(0) } + } + + /// Load the SQLite extension at `dylib_path`. `dylib_path` is passed + /// through to `sqlite3_load_extension`, which may attempt OS-specific + /// modifications if the file cannot be loaded directly (for example + /// converting `"some/ext"` to `"some/ext.so"`, `"some\\ext.dll"`, ...). + /// + /// If `entry_point` is `None`, SQLite will attempt to find the entry point. + /// If it is not `None`, the entry point will be passed through to + /// `sqlite3_load_extension`. + /// + /// ## Example + /// + /// ```rust,no_run + /// # use rusqlite::{Connection, Result, LoadExtensionGuard}; + /// fn load_my_extension(conn: &Connection) -> Result<()> { + /// // Safety: we don't execute any SQL statements while + /// // extension loading is enabled. + /// let _guard = unsafe { LoadExtensionGuard::new(conn)? }; + /// // Safety: `my_sqlite_extension` is highly trustworthy. + /// unsafe { conn.load_extension("my_sqlite_extension", None) } + /// } + /// ``` + /// + /// # Failure + /// + /// Will return `Err` if the underlying SQLite call fails. + /// + /// # Safety + /// + /// This is equivalent to performing a `dlopen`/`LoadLibrary` on a shared + /// library, and calling a function inside, and thus requires that you trust + /// the library that you're loading. + /// + /// That is to say: to safely use this, the code in the extension must be + /// sound, trusted, correctly use the SQLite APIs, and not contain any + /// memory or thread safety errors. 
+ #[cfg(feature = "load_extension")] + #[cfg_attr(docsrs, doc(cfg(feature = "load_extension")))] + #[inline] + pub unsafe fn load_extension>( + &self, + dylib_path: P, + entry_point: Option<&str>, + ) -> Result<()> { + self.db + .borrow_mut() + .load_extension(dylib_path.as_ref(), entry_point) + } + + /// Get access to the underlying SQLite database connection handle. + /// + /// # Warning + /// + /// You should not need to use this function. If you do need to, please + /// [open an issue on the rusqlite repository](https://github.com/rusqlite/rusqlite/issues) and describe + /// your use case. + /// + /// # Safety + /// + /// This function is unsafe because it gives you raw access + /// to the SQLite connection, and what you do with it could impact the + /// safety of this `Connection`. + #[inline] + pub unsafe fn handle(&self) -> *mut ffi::sqlite3 { + self.db.borrow().db() + } + + /// Create a `Connection` from a raw handle. + /// + /// The underlying SQLite database connection handle will not be closed when + /// the returned connection is dropped/closed. + /// + /// # Safety + /// + /// This function is unsafe because improper use may impact the Connection. + #[inline] + pub unsafe fn from_handle(db: *mut ffi::sqlite3) -> Result { + let db_path = db_filename(db); + let db = InnerConnection::new(db, false); + Ok(Connection { + db: RefCell::new(db), + cache: StatementCache::with_capacity(STATEMENT_CACHE_DEFAULT_CAPACITY), + path: db_path, + }) + } + + /// Get access to a handle that can be used to interrupt long running + /// queries from another thread. + #[inline] + pub fn get_interrupt_handle(&self) -> InterruptHandle { + self.db.borrow().get_interrupt_handle() + } + + #[inline] + fn decode_result(&self, code: c_int) -> Result<()> { + self.db.borrow().decode_result(code) + } + + /// Return the number of rows modified, inserted or deleted by the most + /// recently completed INSERT, UPDATE or DELETE statement on the database + /// connection. 
+ /// + /// See + #[inline] + pub fn changes(&self) -> u64 { + self.db.borrow().changes() + } + + /// Test for auto-commit mode. + /// Autocommit mode is on by default. + #[inline] + pub fn is_autocommit(&self) -> bool { + self.db.borrow().is_autocommit() + } + + /// Determine if all associated prepared statements have been reset. + #[inline] + #[cfg(feature = "modern_sqlite")] // 3.8.6 + #[cfg_attr(docsrs, doc(cfg(feature = "modern_sqlite")))] + pub fn is_busy(&self) -> bool { + self.db.borrow().is_busy() + } + + /// Flush caches to disk mid-transaction + #[cfg(feature = "modern_sqlite")] // 3.10.0 + #[cfg_attr(docsrs, doc(cfg(feature = "modern_sqlite")))] + pub fn cache_flush(&self) -> Result<()> { + self.db.borrow_mut().cache_flush() + } + + /// Determine if a database is read-only + #[cfg(feature = "modern_sqlite")] // 3.7.11 + #[cfg_attr(docsrs, doc(cfg(feature = "modern_sqlite")))] + pub fn is_readonly(&self, db_name: DatabaseName<'_>) -> Result { + self.db.borrow().db_readonly(db_name) + } +} + +impl fmt::Debug for Connection { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("Connection") + .field("path", &self.path) + .finish() + } +} + +/// Batch iterator +/// ```rust +/// use rusqlite::{Batch, Connection, Result}; +/// +/// fn main() -> Result<()> { +/// let conn = Connection::open_in_memory()?; +/// let sql = r" +/// CREATE TABLE tbl1 (col); +/// CREATE TABLE tbl2 (col); +/// "; +/// let mut batch = Batch::new(&conn, sql); +/// while let Some(mut stmt) = batch.next()? { +/// stmt.execute([])?; +/// } +/// Ok(()) +/// } +/// ``` +#[derive(Debug)] +pub struct Batch<'conn, 'sql> { + conn: &'conn Connection, + sql: &'sql str, + tail: usize, +} + +impl<'conn, 'sql> Batch<'conn, 'sql> { + /// Constructor + pub fn new(conn: &'conn Connection, sql: &'sql str) -> Batch<'conn, 'sql> { + Batch { conn, sql, tail: 0 } + } + + /// Iterates on each batch statements. + /// + /// Returns `Ok(None)` when batch is completed. 
+ #[allow(clippy::should_implement_trait)] // fallible iterator + pub fn next(&mut self) -> Result>> { + while self.tail < self.sql.len() { + let sql = &self.sql[self.tail..]; + let next = self.conn.prepare(sql)?; + let tail = next.stmt.tail(); + if tail == 0 { + self.tail = self.sql.len(); + } else { + self.tail += tail; + } + if next.stmt.is_null() { + continue; + } + return Ok(Some(next)); + } + Ok(None) + } +} + +impl<'conn> Iterator for Batch<'conn, '_> { + type Item = Result>; + + fn next(&mut self) -> Option>> { + self.next().transpose() + } +} + +bitflags::bitflags! { + /// Flags for opening SQLite database connections. See + /// [sqlite3_open_v2](http://www.sqlite.org/c3ref/open.html) for details. + /// + /// The default open flags are `SQLITE_OPEN_READ_WRITE | SQLITE_OPEN_CREATE + /// | SQLITE_OPEN_URI | SQLITE_OPEN_NO_MUTEX`. See [`Connection::open`] for + /// some discussion about these flags. + #[repr(C)] + pub struct OpenFlags: ::std::os::raw::c_int { + /// The database is opened in read-only mode. + /// If the database does not already exist, an error is returned. + const SQLITE_OPEN_READ_ONLY = ffi::SQLITE_OPEN_READONLY; + /// The database is opened for reading and writing if possible, + /// or reading only if the file is write protected by the operating system. + /// In either case the database must already exist, otherwise an error is returned. + const SQLITE_OPEN_READ_WRITE = ffi::SQLITE_OPEN_READWRITE; + /// The database is created if it does not already exist + const SQLITE_OPEN_CREATE = ffi::SQLITE_OPEN_CREATE; + /// The filename can be interpreted as a URI if this flag is set. + const SQLITE_OPEN_URI = 0x0000_0040; + /// The database will be opened as an in-memory database. + const SQLITE_OPEN_MEMORY = 0x0000_0080; + /// The new database connection will not use a per-connection mutex (the + /// connection will use the "multi-thread" threading mode, in SQLite + /// parlance). 
+ /// + /// This is used by default, as proper `Send`/`Sync` usage (in + /// particular, the fact that [`Connection`] does not implement `Sync`) + /// ensures thread-safety without the need to perform locking around all + /// calls. + const SQLITE_OPEN_NO_MUTEX = ffi::SQLITE_OPEN_NOMUTEX; + /// The new database connection will use a per-connection mutex -- the + /// "serialized" threading mode, in SQLite parlance. + /// + /// # Caveats + /// + /// This flag should probably never be used with `rusqlite`, as we + /// ensure thread-safety statically (we implement [`Send`] and not + /// [`Sync`]). That said + /// + /// Critically, even if this flag is used, the [`Connection`] is not + /// safe to use across multiple threads simultaneously. To access a + /// database from multiple threads, you should either create multiple + /// connections, one for each thread (if you have very many threads, + /// wrapping the `rusqlite::Connection` in a mutex is also reasonable). + /// + /// This is both because of the additional per-connection state stored + /// by `rusqlite` (for example, the prepared statement cache), and + /// because not all of SQLites functions are fully thread safe, even in + /// serialized/`SQLITE_OPEN_FULLMUTEX` mode. + /// + /// All that said, it's fairly harmless to enable this flag with + /// `rusqlite`, it will just slow things down while providing no + /// benefit. + const SQLITE_OPEN_FULL_MUTEX = ffi::SQLITE_OPEN_FULLMUTEX; + /// The database is opened with shared cache enabled. + /// + /// This is frequently useful for in-memory connections, but note that + /// broadly speaking it's discouraged by SQLite itself, which states + /// "Any use of shared cache is discouraged" in the official + /// [documentation](https://www.sqlite.org/c3ref/enable_shared_cache.html). + const SQLITE_OPEN_SHARED_CACHE = 0x0002_0000; + /// The database is opened shared cache disabled. 
+ const SQLITE_OPEN_PRIVATE_CACHE = 0x0004_0000; + /// The database filename is not allowed to be a symbolic link. (3.31.0) + const SQLITE_OPEN_NOFOLLOW = 0x0100_0000; + /// Extended result codes. (3.37.0) + const SQLITE_OPEN_EXRESCODE = 0x0200_0000; + } +} + +impl Default for OpenFlags { + #[inline] + fn default() -> OpenFlags { + // Note: update the `Connection::open` and top-level `OpenFlags` docs if + // you change these. + OpenFlags::SQLITE_OPEN_READ_WRITE + | OpenFlags::SQLITE_OPEN_CREATE + | OpenFlags::SQLITE_OPEN_NO_MUTEX + | OpenFlags::SQLITE_OPEN_URI + } +} + +/// rusqlite's check for a safe SQLite threading mode requires SQLite 3.7.0 or +/// later. If you are running against a SQLite older than that, rusqlite +/// attempts to ensure safety by performing configuration and initialization of +/// SQLite itself the first time you +/// attempt to open a connection. By default, rusqlite panics if that +/// initialization fails, since that could mean SQLite has been initialized in +/// single-thread mode. +/// +/// If you are encountering that panic _and_ can ensure that SQLite has been +/// initialized in either multi-thread or serialized mode, call this function +/// prior to attempting to open a connection and rusqlite's initialization +/// process will by skipped. +/// +/// # Safety +/// +/// This function is unsafe because if you call it and SQLite has actually been +/// configured to run in single-thread mode, +/// you may encounter memory errors or data corruption or any number of terrible +/// things that should not be possible when you're using Rust. +pub unsafe fn bypass_sqlite_initialization() { + BYPASS_SQLITE_INIT.store(true, Ordering::Relaxed); +} + +/// Allows interrupting a long-running computation. +pub struct InterruptHandle { + db_lock: Arc>, +} + +unsafe impl Send for InterruptHandle {} +unsafe impl Sync for InterruptHandle {} + +impl InterruptHandle { + /// Interrupt the query currently executing on another thread. 
This will + /// cause that query to fail with a `SQLITE3_INTERRUPT` error. + pub fn interrupt(&self) { + let db_handle = self.db_lock.lock().unwrap(); + if !db_handle.is_null() { + unsafe { ffi::sqlite3_interrupt(*db_handle) } + } + } +} + +#[cfg(feature = "modern_sqlite")] // 3.7.10 +unsafe fn db_filename(db: *mut ffi::sqlite3) -> Option { + let db_name = DatabaseName::Main.as_cstring().unwrap(); + let db_filename = ffi::sqlite3_db_filename(db, db_name.as_ptr()); + if db_filename.is_null() { + None + } else { + CStr::from_ptr(db_filename).to_str().ok().map(PathBuf::from) + } +} +#[cfg(not(feature = "modern_sqlite"))] +unsafe fn db_filename(_: *mut ffi::sqlite3) -> Option { + None +} + +#[cfg(doctest)] +doc_comment::doctest!("../README.md"); + +#[cfg(test)] +mod test { + use super::*; + use crate::ffi; + use fallible_iterator::FallibleIterator; + use std::error::Error as StdError; + use std::fmt; + + // this function is never called, but is still type checked; in + // particular, calls with specific instantiations will require + // that those types are `Send`. 
+ #[allow(dead_code, unconditional_recursion)] + fn ensure_send() { + ensure_send::(); + ensure_send::(); + } + + #[allow(dead_code, unconditional_recursion)] + fn ensure_sync() { + ensure_sync::(); + } + + fn checked_memory_handle() -> Connection { + Connection::open_in_memory().unwrap() + } + + #[test] + fn test_concurrent_transactions_busy_commit() -> Result<()> { + use std::time::Duration; + let tmp = tempfile::tempdir().unwrap(); + let path = tmp.path().join("transactions.db3"); + + Connection::open(&path)?.execute_batch( + " + BEGIN; CREATE TABLE foo(x INTEGER); + INSERT INTO foo VALUES(42); END;", + )?; + + let mut db1 = Connection::open_with_flags(&path, OpenFlags::SQLITE_OPEN_READ_WRITE)?; + let mut db2 = Connection::open_with_flags(&path, OpenFlags::SQLITE_OPEN_READ_ONLY)?; + + db1.busy_timeout(Duration::from_millis(0))?; + db2.busy_timeout(Duration::from_millis(0))?; + + { + let tx1 = db1.transaction()?; + let tx2 = db2.transaction()?; + + // SELECT first makes sqlite lock with a shared lock + tx1.query_row("SELECT x FROM foo LIMIT 1", [], |_| Ok(()))?; + tx2.query_row("SELECT x FROM foo LIMIT 1", [], |_| Ok(()))?; + + tx1.execute("INSERT INTO foo VALUES(?1)", [1])?; + let _ = tx2.execute("INSERT INTO foo VALUES(?1)", [2]); + + let _ = tx1.commit(); + let _ = tx2.commit(); + } + + let _ = db1 + .transaction() + .expect("commit should have closed transaction"); + let _ = db2 + .transaction() + .expect("commit should have closed transaction"); + Ok(()) + } + + #[test] + fn test_persistence() -> Result<()> { + let temp_dir = tempfile::tempdir().unwrap(); + let path = temp_dir.path().join("test.db3"); + + { + let db = Connection::open(&path)?; + let sql = "BEGIN; + CREATE TABLE foo(x INTEGER); + INSERT INTO foo VALUES(42); + END;"; + db.execute_batch(sql)?; + } + + let path_string = path.to_str().unwrap(); + let db = Connection::open(&path_string)?; + let the_answer: Result = db.query_row("SELECT x FROM foo", [], |r| r.get(0)); + + assert_eq!(42i64, 
the_answer?); + Ok(()) + } + + #[test] + fn test_open() { + assert!(Connection::open_in_memory().is_ok()); + + let db = checked_memory_handle(); + assert!(db.close().is_ok()); + } + + #[test] + fn test_open_failure() { + let filename = "no_such_file.db"; + let result = Connection::open_with_flags(filename, OpenFlags::SQLITE_OPEN_READ_ONLY); + assert!(result.is_err()); + let err = result.unwrap_err(); + if let Error::SqliteFailure(e, Some(msg)) = err { + assert_eq!(ErrorCode::CannotOpen, e.code); + assert_eq!(ffi::SQLITE_CANTOPEN, e.extended_code); + assert!( + msg.contains(filename), + "error message '{}' does not contain '{}'", + msg, + filename + ); + } else { + panic!("SqliteFailure expected"); + } + } + + #[cfg(unix)] + #[test] + fn test_invalid_unicode_file_names() -> Result<()> { + use std::ffi::OsStr; + use std::fs::File; + use std::os::unix::ffi::OsStrExt; + let temp_dir = tempfile::tempdir().unwrap(); + + let path = temp_dir.path(); + if File::create(path.join(OsStr::from_bytes(&[0xFE]))).is_err() { + // Skip test, filesystem doesn't support invalid Unicode + return Ok(()); + } + let db_path = path.join(OsStr::from_bytes(&[0xFF])); + { + let db = Connection::open(&db_path)?; + let sql = "BEGIN; + CREATE TABLE foo(x INTEGER); + INSERT INTO foo VALUES(42); + END;"; + db.execute_batch(sql)?; + } + + let db = Connection::open(&db_path)?; + let the_answer: Result = db.query_row("SELECT x FROM foo", [], |r| r.get(0)); + + assert_eq!(42i64, the_answer?); + Ok(()) + } + + #[test] + fn test_close_retry() -> Result<()> { + let db = Connection::open_in_memory()?; + + // force the DB to be busy by preparing a statement; this must be done at the + // FFI level to allow us to call .close() without dropping the prepared + // statement first. 
+ let raw_stmt = { + use super::str_to_cstring; + use std::os::raw::c_int; + use std::ptr; + + let raw_db = db.db.borrow_mut().db; + let sql = "SELECT 1"; + let mut raw_stmt: *mut ffi::sqlite3_stmt = ptr::null_mut(); + let cstring = str_to_cstring(sql)?; + let rc = unsafe { + ffi::sqlite3_prepare_v2( + raw_db, + cstring.as_ptr(), + (sql.len() + 1) as c_int, + &mut raw_stmt, + ptr::null_mut(), + ) + }; + assert_eq!(rc, ffi::SQLITE_OK); + raw_stmt + }; + + // now that we have an open statement, trying (and retrying) to close should + // fail. + let (db, _) = db.close().unwrap_err(); + let (db, _) = db.close().unwrap_err(); + let (db, _) = db.close().unwrap_err(); + + // finalize the open statement so a final close will succeed + assert_eq!(ffi::SQLITE_OK, unsafe { ffi::sqlite3_finalize(raw_stmt) }); + + db.close().unwrap(); + Ok(()) + } + + #[test] + fn test_open_with_flags() { + for bad_flags in &[ + OpenFlags::empty(), + OpenFlags::SQLITE_OPEN_READ_ONLY | OpenFlags::SQLITE_OPEN_READ_WRITE, + OpenFlags::SQLITE_OPEN_READ_ONLY | OpenFlags::SQLITE_OPEN_CREATE, + ] { + assert!(Connection::open_in_memory_with_flags(*bad_flags).is_err()); + } + } + + #[test] + fn test_execute_batch() -> Result<()> { + let db = Connection::open_in_memory()?; + let sql = "BEGIN; + CREATE TABLE foo(x INTEGER); + INSERT INTO foo VALUES(1); + INSERT INTO foo VALUES(2); + INSERT INTO foo VALUES(3); + INSERT INTO foo VALUES(4); + END;"; + db.execute_batch(sql)?; + + db.execute_batch("UPDATE foo SET x = 3 WHERE x < 3")?; + + assert!(db.execute_batch("INVALID SQL").is_err()); + Ok(()) + } + + #[test] + fn test_execute() -> Result<()> { + let db = Connection::open_in_memory()?; + db.execute_batch("CREATE TABLE foo(x INTEGER)")?; + + assert_eq!(1, db.execute("INSERT INTO foo(x) VALUES (?)", [1i32])?); + assert_eq!(1, db.execute("INSERT INTO foo(x) VALUES (?)", [2i32])?); + + assert_eq!( + 3i32, + db.query_row::("SELECT SUM(x) FROM foo", [], |r| r.get(0))? 
+ ); + Ok(()) + } + + #[test] + #[cfg(feature = "extra_check")] + fn test_execute_select() { + let db = checked_memory_handle(); + let err = db.execute("SELECT 1 WHERE 1 < ?", [1i32]).unwrap_err(); + assert_eq!( + err, + Error::ExecuteReturnedResults, + "Unexpected error: {}", + err + ); + } + + #[test] + #[cfg(feature = "extra_check")] + fn test_execute_multiple() { + let db = checked_memory_handle(); + let err = db + .execute( + "CREATE TABLE foo(x INTEGER); CREATE TABLE foo(x INTEGER)", + [], + ) + .unwrap_err(); + match err { + Error::MultipleStatement => (), + _ => panic!("Unexpected error: {}", err), + } + } + + #[test] + fn test_prepare_column_names() -> Result<()> { + let db = Connection::open_in_memory()?; + db.execute_batch("CREATE TABLE foo(x INTEGER);")?; + + let stmt = db.prepare("SELECT * FROM foo")?; + assert_eq!(stmt.column_count(), 1); + assert_eq!(stmt.column_names(), vec!["x"]); + + let stmt = db.prepare("SELECT x AS a, x AS b FROM foo")?; + assert_eq!(stmt.column_count(), 2); + assert_eq!(stmt.column_names(), vec!["a", "b"]); + Ok(()) + } + + #[test] + fn test_prepare_execute() -> Result<()> { + let db = Connection::open_in_memory()?; + db.execute_batch("CREATE TABLE foo(x INTEGER);")?; + + let mut insert_stmt = db.prepare("INSERT INTO foo(x) VALUES(?)")?; + assert_eq!(insert_stmt.execute([1i32])?, 1); + assert_eq!(insert_stmt.execute([2i32])?, 1); + assert_eq!(insert_stmt.execute([3i32])?, 1); + + assert_eq!(insert_stmt.execute(["hello"])?, 1); + assert_eq!(insert_stmt.execute(["goodbye"])?, 1); + assert_eq!(insert_stmt.execute([types::Null])?, 1); + + let mut update_stmt = db.prepare("UPDATE foo SET x=? 
WHERE x Result<()> { + let db = Connection::open_in_memory()?; + db.execute_batch("CREATE TABLE foo(x INTEGER);")?; + + let mut insert_stmt = db.prepare("INSERT INTO foo(x) VALUES(?)")?; + assert_eq!(insert_stmt.execute([1i32])?, 1); + assert_eq!(insert_stmt.execute([2i32])?, 1); + assert_eq!(insert_stmt.execute([3i32])?, 1); + + let mut query = db.prepare("SELECT x FROM foo WHERE x < ? ORDER BY x DESC")?; + { + let mut rows = query.query([4i32])?; + let mut v = Vec::::new(); + + while let Some(row) = rows.next()? { + v.push(row.get(0)?); + } + + assert_eq!(v, [3i32, 2, 1]); + } + + { + let mut rows = query.query([3i32])?; + let mut v = Vec::::new(); + + while let Some(row) = rows.next()? { + v.push(row.get(0)?); + } + + assert_eq!(v, [2i32, 1]); + } + Ok(()) + } + + #[test] + fn test_query_map() -> Result<()> { + let db = Connection::open_in_memory()?; + let sql = "BEGIN; + CREATE TABLE foo(x INTEGER, y TEXT); + INSERT INTO foo VALUES(4, \"hello\"); + INSERT INTO foo VALUES(3, \", \"); + INSERT INTO foo VALUES(2, \"world\"); + INSERT INTO foo VALUES(1, \"!\"); + END;"; + db.execute_batch(sql)?; + + let mut query = db.prepare("SELECT x, y FROM foo ORDER BY x DESC")?; + let results: Result> = query.query([])?.map(|row| row.get(1)).collect(); + + assert_eq!(results?.concat(), "hello, world!"); + Ok(()) + } + + #[test] + fn test_query_row() -> Result<()> { + let db = Connection::open_in_memory()?; + let sql = "BEGIN; + CREATE TABLE foo(x INTEGER); + INSERT INTO foo VALUES(1); + INSERT INTO foo VALUES(2); + INSERT INTO foo VALUES(3); + INSERT INTO foo VALUES(4); + END;"; + db.execute_batch(sql)?; + + assert_eq!( + 10i64, + db.query_row::("SELECT SUM(x) FROM foo", [], |r| r.get(0))? 
+ ); + + let result: Result = db.query_row("SELECT x FROM foo WHERE x > 5", [], |r| r.get(0)); + match result.unwrap_err() { + Error::QueryReturnedNoRows => (), + err => panic!("Unexpected error {}", err), + } + + let bad_query_result = db.query_row("NOT A PROPER QUERY; test123", [], |_| Ok(())); + + assert!(bad_query_result.is_err()); + Ok(()) + } + + #[test] + fn test_optional() -> Result<()> { + let db = Connection::open_in_memory()?; + + let result: Result = db.query_row("SELECT 1 WHERE 0 <> 0", [], |r| r.get(0)); + let result = result.optional(); + match result? { + None => (), + _ => panic!("Unexpected result"), + } + + let result: Result = db.query_row("SELECT 1 WHERE 0 == 0", [], |r| r.get(0)); + let result = result.optional(); + match result? { + Some(1) => (), + _ => panic!("Unexpected result"), + } + + let bad_query_result: Result = db.query_row("NOT A PROPER QUERY", [], |r| r.get(0)); + let bad_query_result = bad_query_result.optional(); + assert!(bad_query_result.is_err()); + Ok(()) + } + + #[test] + fn test_pragma_query_row() -> Result<()> { + let db = Connection::open_in_memory()?; + assert_eq!( + "memory", + db.query_row::("PRAGMA journal_mode", [], |r| r.get(0))? + ); + let mode = db.query_row::("PRAGMA journal_mode=off", [], |r| r.get(0))?; + if cfg!(features = "bundled") { + assert_eq!(mode, "off"); + } else { + // Note: system SQLite on macOS defaults to "off" rather than + // "memory" for the journal mode (which cannot be changed for + // in-memory connections). This seems like it's *probably* legal + // according to the docs below, so we relax this test when not + // bundling: + // + // From https://www.sqlite.org/pragma.html#pragma_journal_mode + // > Note that the journal_mode for an in-memory database is either + // > MEMORY or OFF and can not be changed to a different value. An + // > attempt to change the journal_mode of an in-memory database to + // > any setting other than MEMORY or OFF is ignored. 
+ assert!(mode == "memory" || mode == "off", "Got mode {:?}", mode); + } + + Ok(()) + } + + #[test] + fn test_prepare_failures() -> Result<()> { + let db = Connection::open_in_memory()?; + db.execute_batch("CREATE TABLE foo(x INTEGER);")?; + + let err = db.prepare("SELECT * FROM does_not_exist").unwrap_err(); + assert!(format!("{}", err).contains("does_not_exist")); + Ok(()) + } + + #[test] + fn test_last_insert_rowid() -> Result<()> { + let db = Connection::open_in_memory()?; + db.execute_batch("CREATE TABLE foo(x INTEGER PRIMARY KEY)")?; + db.execute_batch("INSERT INTO foo DEFAULT VALUES")?; + + assert_eq!(db.last_insert_rowid(), 1); + + let mut stmt = db.prepare("INSERT INTO foo DEFAULT VALUES")?; + for _ in 0i32..9 { + stmt.execute([])?; + } + assert_eq!(db.last_insert_rowid(), 10); + Ok(()) + } + + #[test] + fn test_is_autocommit() -> Result<()> { + let db = Connection::open_in_memory()?; + assert!( + db.is_autocommit(), + "autocommit expected to be active by default" + ); + Ok(()) + } + + #[test] + #[cfg(feature = "modern_sqlite")] + fn test_is_busy() -> Result<()> { + let db = Connection::open_in_memory()?; + assert!(!db.is_busy()); + let mut stmt = db.prepare("PRAGMA schema_version")?; + assert!(!db.is_busy()); + { + let mut rows = stmt.query([])?; + assert!(!db.is_busy()); + let row = rows.next()?; + assert!(db.is_busy()); + assert!(row.is_some()); + } + assert!(!db.is_busy()); + Ok(()) + } + + #[test] + fn test_statement_debugging() -> Result<()> { + let db = Connection::open_in_memory()?; + let query = "SELECT 12345"; + let stmt = db.prepare(query)?; + + assert!(format!("{:?}", stmt).contains(query)); + Ok(()) + } + + #[test] + fn test_notnull_constraint_error() -> Result<()> { + // extended error codes for constraints were added in SQLite 3.7.16; if we're + // running on our bundled version, we know the extended error code exists. 
+ #[cfg(feature = "modern_sqlite")] + fn check_extended_code(extended_code: c_int) { + assert_eq!(extended_code, ffi::SQLITE_CONSTRAINT_NOTNULL); + } + #[cfg(not(feature = "modern_sqlite"))] + fn check_extended_code(_extended_code: c_int) {} + + let db = Connection::open_in_memory()?; + db.execute_batch("CREATE TABLE foo(x NOT NULL)")?; + + let result = db.execute("INSERT INTO foo (x) VALUES (NULL)", []); + assert!(result.is_err()); + + match result.unwrap_err() { + Error::SqliteFailure(err, _) => { + assert_eq!(err.code, ErrorCode::ConstraintViolation); + check_extended_code(err.extended_code); + } + err => panic!("Unexpected error {}", err), + } + Ok(()) + } + + #[test] + fn test_version_string() { + let n = version_number(); + let major = n / 1_000_000; + let minor = (n % 1_000_000) / 1_000; + let patch = n % 1_000; + + assert!(version().contains(&format!("{}.{}.{}", major, minor, patch))); + } + + #[test] + #[cfg(feature = "functions")] + fn test_interrupt() -> Result<()> { + let db = Connection::open_in_memory()?; + + let interrupt_handle = db.get_interrupt_handle(); + + db.create_scalar_function( + "interrupt", + 0, + functions::FunctionFlags::default(), + move |_| { + interrupt_handle.interrupt(); + Ok(0) + }, + )?; + + let mut stmt = + db.prepare("SELECT interrupt() FROM (SELECT 1 UNION SELECT 2 UNION SELECT 3)")?; + + let result: Result> = stmt.query([])?.map(|r| r.get(0)).collect(); + + assert_eq!( + result.unwrap_err().sqlite_error_code(), + Some(ErrorCode::OperationInterrupted) + ); + Ok(()) + } + + #[test] + fn test_interrupt_close() { + let db = checked_memory_handle(); + let handle = db.get_interrupt_handle(); + handle.interrupt(); + db.close().unwrap(); + handle.interrupt(); + + // Look at it's internals to see if we cleared it out properly. 
+ let db_guard = handle.db_lock.lock().unwrap(); + assert!(db_guard.is_null()); + // It would be nice to test that we properly handle close/interrupt + // running at the same time, but it seems impossible to do with any + // degree of reliability. + } + + #[test] + fn test_get_raw() -> Result<()> { + let db = Connection::open_in_memory()?; + db.execute_batch("CREATE TABLE foo(i, x);")?; + let vals = ["foobar", "1234", "qwerty"]; + let mut insert_stmt = db.prepare("INSERT INTO foo(i, x) VALUES(?, ?)")?; + for (i, v) in vals.iter().enumerate() { + let i_to_insert = i as i64; + assert_eq!(insert_stmt.execute(params![i_to_insert, v])?, 1); + } + + let mut query = db.prepare("SELECT i, x FROM foo")?; + let mut rows = query.query([])?; + + while let Some(row) = rows.next()? { + let i = row.get_ref(0)?.as_i64()?; + let expect = vals[i as usize]; + let x = row.get_ref("x")?.as_str()?; + assert_eq!(x, expect); + } + + let mut query = db.prepare("SELECT x FROM foo")?; + let rows = query.query_map([], |row| { + let x = row.get_ref(0)?.as_str()?; // check From for Error + Ok(x[..].to_owned()) + })?; + + for (i, row) in rows.enumerate() { + assert_eq!(row?, vals[i]); + } + Ok(()) + } + + #[test] + fn test_from_handle() -> Result<()> { + let db = Connection::open_in_memory()?; + let handle = unsafe { db.handle() }; + { + let db = unsafe { Connection::from_handle(handle) }?; + db.execute_batch("PRAGMA VACUUM")?; + } + db.close().unwrap(); + Ok(()) + } + + mod query_and_then_tests { + + use super::*; + + #[derive(Debug)] + enum CustomError { + SomeError, + Sqlite(Error), + } + + impl fmt::Display for CustomError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> Result<(), fmt::Error> { + match *self { + CustomError::SomeError => write!(f, "my custom error"), + CustomError::Sqlite(ref se) => write!(f, "my custom error: {}", se), + } + } + } + + impl StdError for CustomError { + fn description(&self) -> &str { + "my custom error" + } + + fn cause(&self) -> Option<&dyn StdError> { + 
match *self { + CustomError::SomeError => None, + CustomError::Sqlite(ref se) => Some(se), + } + } + } + + impl From for CustomError { + fn from(se: Error) -> CustomError { + CustomError::Sqlite(se) + } + } + + type CustomResult = Result; + + #[test] + fn test_query_and_then() -> Result<()> { + let db = Connection::open_in_memory()?; + let sql = "BEGIN; + CREATE TABLE foo(x INTEGER, y TEXT); + INSERT INTO foo VALUES(4, \"hello\"); + INSERT INTO foo VALUES(3, \", \"); + INSERT INTO foo VALUES(2, \"world\"); + INSERT INTO foo VALUES(1, \"!\"); + END;"; + db.execute_batch(sql)?; + + let mut query = db.prepare("SELECT x, y FROM foo ORDER BY x DESC")?; + let results: Result> = + query.query_and_then([], |row| row.get(1))?.collect(); + + assert_eq!(results?.concat(), "hello, world!"); + Ok(()) + } + + #[test] + fn test_query_and_then_fails() -> Result<()> { + let db = Connection::open_in_memory()?; + let sql = "BEGIN; + CREATE TABLE foo(x INTEGER, y TEXT); + INSERT INTO foo VALUES(4, \"hello\"); + INSERT INTO foo VALUES(3, \", \"); + INSERT INTO foo VALUES(2, \"world\"); + INSERT INTO foo VALUES(1, \"!\"); + END;"; + db.execute_batch(sql)?; + + let mut query = db.prepare("SELECT x, y FROM foo ORDER BY x DESC")?; + let bad_type: Result> = query.query_and_then([], |row| row.get(1))?.collect(); + + match bad_type.unwrap_err() { + Error::InvalidColumnType(..) 
=> (), + err => panic!("Unexpected error {}", err), + } + + let bad_idx: Result> = + query.query_and_then([], |row| row.get(3))?.collect(); + + match bad_idx.unwrap_err() { + Error::InvalidColumnIndex(_) => (), + err => panic!("Unexpected error {}", err), + } + Ok(()) + } + + #[test] + fn test_query_and_then_custom_error() -> CustomResult<()> { + let db = Connection::open_in_memory()?; + let sql = "BEGIN; + CREATE TABLE foo(x INTEGER, y TEXT); + INSERT INTO foo VALUES(4, \"hello\"); + INSERT INTO foo VALUES(3, \", \"); + INSERT INTO foo VALUES(2, \"world\"); + INSERT INTO foo VALUES(1, \"!\"); + END;"; + db.execute_batch(sql)?; + + let mut query = db.prepare("SELECT x, y FROM foo ORDER BY x DESC")?; + let results: CustomResult> = query + .query_and_then([], |row| row.get(1).map_err(CustomError::Sqlite))? + .collect(); + + assert_eq!(results?.concat(), "hello, world!"); + Ok(()) + } + + #[test] + fn test_query_and_then_custom_error_fails() -> Result<()> { + let db = Connection::open_in_memory()?; + let sql = "BEGIN; + CREATE TABLE foo(x INTEGER, y TEXT); + INSERT INTO foo VALUES(4, \"hello\"); + INSERT INTO foo VALUES(3, \", \"); + INSERT INTO foo VALUES(2, \"world\"); + INSERT INTO foo VALUES(1, \"!\"); + END;"; + db.execute_batch(sql)?; + + let mut query = db.prepare("SELECT x, y FROM foo ORDER BY x DESC")?; + let bad_type: CustomResult> = query + .query_and_then([], |row| row.get(1).map_err(CustomError::Sqlite))? + .collect(); + + match bad_type.unwrap_err() { + CustomError::Sqlite(Error::InvalidColumnType(..)) => (), + err => panic!("Unexpected error {}", err), + } + + let bad_idx: CustomResult> = query + .query_and_then([], |row| row.get(3).map_err(CustomError::Sqlite))? + .collect(); + + match bad_idx.unwrap_err() { + CustomError::Sqlite(Error::InvalidColumnIndex(_)) => (), + err => panic!("Unexpected error {}", err), + } + + let non_sqlite_err: CustomResult> = query + .query_and_then([], |_| Err(CustomError::SomeError))? 
+ .collect(); + + match non_sqlite_err.unwrap_err() { + CustomError::SomeError => (), + err => panic!("Unexpected error {}", err), + } + Ok(()) + } + + #[test] + fn test_query_row_and_then_custom_error() -> CustomResult<()> { + let db = Connection::open_in_memory()?; + let sql = "BEGIN; + CREATE TABLE foo(x INTEGER, y TEXT); + INSERT INTO foo VALUES(4, \"hello\"); + END;"; + db.execute_batch(sql)?; + + let query = "SELECT x, y FROM foo ORDER BY x DESC"; + let results: CustomResult = + db.query_row_and_then(query, [], |row| row.get(1).map_err(CustomError::Sqlite)); + + assert_eq!(results?, "hello"); + Ok(()) + } + + #[test] + fn test_query_row_and_then_custom_error_fails() -> Result<()> { + let db = Connection::open_in_memory()?; + let sql = "BEGIN; + CREATE TABLE foo(x INTEGER, y TEXT); + INSERT INTO foo VALUES(4, \"hello\"); + END;"; + db.execute_batch(sql)?; + + let query = "SELECT x, y FROM foo ORDER BY x DESC"; + let bad_type: CustomResult = + db.query_row_and_then(query, [], |row| row.get(1).map_err(CustomError::Sqlite)); + + match bad_type.unwrap_err() { + CustomError::Sqlite(Error::InvalidColumnType(..)) => (), + err => panic!("Unexpected error {}", err), + } + + let bad_idx: CustomResult = + db.query_row_and_then(query, [], |row| row.get(3).map_err(CustomError::Sqlite)); + + match bad_idx.unwrap_err() { + CustomError::Sqlite(Error::InvalidColumnIndex(_)) => (), + err => panic!("Unexpected error {}", err), + } + + let non_sqlite_err: CustomResult = + db.query_row_and_then(query, [], |_| Err(CustomError::SomeError)); + + match non_sqlite_err.unwrap_err() { + CustomError::SomeError => (), + err => panic!("Unexpected error {}", err), + } + Ok(()) + } + } + + #[test] + fn test_dynamic() -> Result<()> { + let db = Connection::open_in_memory()?; + let sql = "BEGIN; + CREATE TABLE foo(x INTEGER, y TEXT); + INSERT INTO foo VALUES(4, \"hello\"); + END;"; + db.execute_batch(sql)?; + + db.query_row("SELECT * FROM foo", [], |r| { + assert_eq!(2, 
r.as_ref().column_count()); + Ok(()) + }) + } + #[test] + fn test_dyn_box() -> Result<()> { + let db = Connection::open_in_memory()?; + db.execute_batch("CREATE TABLE foo(x INTEGER);")?; + let b: Box = Box::new(5); + db.execute("INSERT INTO foo VALUES(?)", [b])?; + db.query_row("SELECT x FROM foo", [], |r| { + assert_eq!(5, r.get_unwrap::<_, i32>(0)); + Ok(()) + }) + } + + #[test] + fn test_params() -> Result<()> { + let db = Connection::open_in_memory()?; + db.query_row( + "SELECT + ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, + ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, + ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, + ?, ?, ?, ?;", + params![ + 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, + 1, 1, 1, 1, 1, 1, + ], + |r| { + assert_eq!(1, r.get_unwrap::<_, i32>(0)); + Ok(()) + }, + ) + } + + #[test] + #[cfg(not(feature = "extra_check"))] + fn test_alter_table() -> Result<()> { + let db = Connection::open_in_memory()?; + db.execute_batch("CREATE TABLE x(t);")?; + // `execute_batch` should be used but `execute` should also work + db.execute("ALTER TABLE x RENAME TO y;", [])?; + Ok(()) + } + + #[test] + fn test_batch() -> Result<()> { + let db = Connection::open_in_memory()?; + let sql = r" + CREATE TABLE tbl1 (col); + CREATE TABLE tbl2 (col); + "; + let batch = Batch::new(&db, sql); + for stmt in batch { + let mut stmt = stmt?; + stmt.execute([])?; + } + Ok(()) + } + + #[test] + #[cfg(feature = "modern_sqlite")] + fn test_returning() -> Result<()> { + let db = Connection::open_in_memory()?; + db.execute_batch("CREATE TABLE foo(x INTEGER PRIMARY KEY)")?; + let row_id = + db.query_row::("INSERT INTO foo DEFAULT VALUES RETURNING ROWID", [], |r| { + r.get(0) + })?; + assert_eq!(row_id, 1); + Ok(()) + } + + #[test] + #[cfg(feature = "modern_sqlite")] + fn test_cache_flush() -> Result<()> { + let db = Connection::open_in_memory()?; + db.cache_flush() + } + + #[test] + #[cfg(feature = "modern_sqlite")] + pub fn db_readonly() -> Result<()> { + let db = 
Connection::open_in_memory()?; + assert!(!db.is_readonly(MAIN_DB)?); + Ok(()) + } +} diff --git a/src/limits.rs b/src/limits.rs new file mode 100644 index 0000000..93e0bb0 --- /dev/null +++ b/src/limits.rs @@ -0,0 +1,169 @@ +//! Run-Time Limits + +use crate::{ffi, Connection}; +use std::os::raw::c_int; + +/// Run-Time limit categories, for use with [`Connection::limit`] and +/// [`Connection::set_limit`]. +/// +/// See the official documentation for more information: +/// - +/// - +#[repr(i32)] +#[non_exhaustive] +#[allow(clippy::upper_case_acronyms, non_camel_case_types)] +#[cfg_attr(docsrs, doc(cfg(feature = "limits")))] +pub enum Limit { + /// The maximum size of any string or BLOB or table row, in bytes. + SQLITE_LIMIT_LENGTH = ffi::SQLITE_LIMIT_LENGTH, + /// The maximum length of an SQL statement, in bytes. + SQLITE_LIMIT_SQL_LENGTH = ffi::SQLITE_LIMIT_SQL_LENGTH, + /// The maximum number of columns in a table definition or in the result set + /// of a SELECT or the maximum number of columns in an index or in an + /// ORDER BY or GROUP BY clause. + SQLITE_LIMIT_COLUMN = ffi::SQLITE_LIMIT_COLUMN, + /// The maximum depth of the parse tree on any expression. + SQLITE_LIMIT_EXPR_DEPTH = ffi::SQLITE_LIMIT_EXPR_DEPTH, + /// The maximum number of terms in a compound SELECT statement. + SQLITE_LIMIT_COMPOUND_SELECT = ffi::SQLITE_LIMIT_COMPOUND_SELECT, + /// The maximum number of instructions in a virtual machine program used to + /// implement an SQL statement. + SQLITE_LIMIT_VDBE_OP = ffi::SQLITE_LIMIT_VDBE_OP, + /// The maximum number of arguments on a function. + SQLITE_LIMIT_FUNCTION_ARG = ffi::SQLITE_LIMIT_FUNCTION_ARG, + /// The maximum number of attached databases. + SQLITE_LIMIT_ATTACHED = ffi::SQLITE_LIMIT_ATTACHED, + /// The maximum length of the pattern argument to the LIKE or GLOB + /// operators. + SQLITE_LIMIT_LIKE_PATTERN_LENGTH = ffi::SQLITE_LIMIT_LIKE_PATTERN_LENGTH, + /// The maximum index number of any parameter in an SQL statement. 
+ SQLITE_LIMIT_VARIABLE_NUMBER = ffi::SQLITE_LIMIT_VARIABLE_NUMBER, + /// The maximum depth of recursion for triggers. + SQLITE_LIMIT_TRIGGER_DEPTH = 10, + /// The maximum number of auxiliary worker threads that a single prepared + /// statement may start. + SQLITE_LIMIT_WORKER_THREADS = 11, +} + +impl Connection { + /// Returns the current value of a [`Limit`]. + #[inline] + #[cfg_attr(docsrs, doc(cfg(feature = "limits")))] + pub fn limit(&self, limit: Limit) -> i32 { + let c = self.db.borrow(); + unsafe { ffi::sqlite3_limit(c.db(), limit as c_int, -1) } + } + + /// Changes the [`Limit`] to `new_val`, returning the prior + /// value of the limit. + #[inline] + #[cfg_attr(docsrs, doc(cfg(feature = "limits")))] + pub fn set_limit(&self, limit: Limit, new_val: i32) -> i32 { + let c = self.db.borrow_mut(); + unsafe { ffi::sqlite3_limit(c.db(), limit as c_int, new_val) } + } +} + +#[cfg(test)] +mod test { + use super::*; + use crate::{Connection, Result}; + + #[test] + fn test_limit_values() { + assert_eq!( + Limit::SQLITE_LIMIT_LENGTH as i32, + ffi::SQLITE_LIMIT_LENGTH as i32, + ); + assert_eq!( + Limit::SQLITE_LIMIT_SQL_LENGTH as i32, + ffi::SQLITE_LIMIT_SQL_LENGTH as i32, + ); + assert_eq!( + Limit::SQLITE_LIMIT_COLUMN as i32, + ffi::SQLITE_LIMIT_COLUMN as i32, + ); + assert_eq!( + Limit::SQLITE_LIMIT_EXPR_DEPTH as i32, + ffi::SQLITE_LIMIT_EXPR_DEPTH as i32, + ); + assert_eq!( + Limit::SQLITE_LIMIT_COMPOUND_SELECT as i32, + ffi::SQLITE_LIMIT_COMPOUND_SELECT as i32, + ); + assert_eq!( + Limit::SQLITE_LIMIT_VDBE_OP as i32, + ffi::SQLITE_LIMIT_VDBE_OP as i32, + ); + assert_eq!( + Limit::SQLITE_LIMIT_FUNCTION_ARG as i32, + ffi::SQLITE_LIMIT_FUNCTION_ARG as i32, + ); + assert_eq!( + Limit::SQLITE_LIMIT_ATTACHED as i32, + ffi::SQLITE_LIMIT_ATTACHED as i32, + ); + assert_eq!( + Limit::SQLITE_LIMIT_LIKE_PATTERN_LENGTH as i32, + ffi::SQLITE_LIMIT_LIKE_PATTERN_LENGTH as i32, + ); + assert_eq!( + Limit::SQLITE_LIMIT_VARIABLE_NUMBER as i32, + ffi::SQLITE_LIMIT_VARIABLE_NUMBER 
as i32, + ); + #[cfg(feature = "bundled")] + assert_eq!( + Limit::SQLITE_LIMIT_TRIGGER_DEPTH as i32, + ffi::SQLITE_LIMIT_TRIGGER_DEPTH as i32, + ); + #[cfg(feature = "bundled")] + assert_eq!( + Limit::SQLITE_LIMIT_WORKER_THREADS as i32, + ffi::SQLITE_LIMIT_WORKER_THREADS as i32, + ); + } + + #[test] + fn test_limit() -> Result<()> { + let db = Connection::open_in_memory()?; + db.set_limit(Limit::SQLITE_LIMIT_LENGTH, 1024); + assert_eq!(1024, db.limit(Limit::SQLITE_LIMIT_LENGTH)); + + db.set_limit(Limit::SQLITE_LIMIT_SQL_LENGTH, 1024); + assert_eq!(1024, db.limit(Limit::SQLITE_LIMIT_SQL_LENGTH)); + + db.set_limit(Limit::SQLITE_LIMIT_COLUMN, 64); + assert_eq!(64, db.limit(Limit::SQLITE_LIMIT_COLUMN)); + + db.set_limit(Limit::SQLITE_LIMIT_EXPR_DEPTH, 256); + assert_eq!(256, db.limit(Limit::SQLITE_LIMIT_EXPR_DEPTH)); + + db.set_limit(Limit::SQLITE_LIMIT_COMPOUND_SELECT, 32); + assert_eq!(32, db.limit(Limit::SQLITE_LIMIT_COMPOUND_SELECT)); + + db.set_limit(Limit::SQLITE_LIMIT_FUNCTION_ARG, 32); + assert_eq!(32, db.limit(Limit::SQLITE_LIMIT_FUNCTION_ARG)); + + db.set_limit(Limit::SQLITE_LIMIT_ATTACHED, 2); + assert_eq!(2, db.limit(Limit::SQLITE_LIMIT_ATTACHED)); + + db.set_limit(Limit::SQLITE_LIMIT_LIKE_PATTERN_LENGTH, 128); + assert_eq!(128, db.limit(Limit::SQLITE_LIMIT_LIKE_PATTERN_LENGTH)); + + db.set_limit(Limit::SQLITE_LIMIT_VARIABLE_NUMBER, 99); + assert_eq!(99, db.limit(Limit::SQLITE_LIMIT_VARIABLE_NUMBER)); + + // SQLITE_LIMIT_TRIGGER_DEPTH was added in SQLite 3.6.18. + if crate::version_number() >= 3_006_018 { + db.set_limit(Limit::SQLITE_LIMIT_TRIGGER_DEPTH, 32); + assert_eq!(32, db.limit(Limit::SQLITE_LIMIT_TRIGGER_DEPTH)); + } + + // SQLITE_LIMIT_WORKER_THREADS was added in SQLite 3.8.7. 
+ if crate::version_number() >= 3_008_007 { + db.set_limit(Limit::SQLITE_LIMIT_WORKER_THREADS, 2); + assert_eq!(2, db.limit(Limit::SQLITE_LIMIT_WORKER_THREADS)); + } + Ok(()) + } +} diff --git a/src/load_extension_guard.rs b/src/load_extension_guard.rs new file mode 100644 index 0000000..deed3b4 --- /dev/null +++ b/src/load_extension_guard.rs @@ -0,0 +1,46 @@ +use crate::{Connection, Result}; + +/// RAII guard temporarily enabling SQLite extensions to be loaded. +/// +/// ## Example +/// +/// ```rust,no_run +/// # use rusqlite::{Connection, Result, LoadExtensionGuard}; +/// # use std::path::{Path}; +/// fn load_my_extension(conn: &Connection) -> Result<()> { +/// unsafe { +/// let _guard = LoadExtensionGuard::new(conn)?; +/// conn.load_extension("trusted/sqlite/extension", None) +/// } +/// } +/// ``` +#[cfg_attr(docsrs, doc(cfg(feature = "load_extension")))] +pub struct LoadExtensionGuard<'conn> { + conn: &'conn Connection, +} + +impl LoadExtensionGuard<'_> { + /// Attempt to enable loading extensions. Loading extensions will be + /// disabled when this guard goes out of scope. Cannot be meaningfully + /// nested. + /// + /// # Safety + /// + /// You must not run untrusted queries while extension loading is enabled. + /// + /// See the safety comment on [`Connection::load_extension_enable`] for more + /// details. + #[inline] + pub unsafe fn new(conn: &Connection) -> Result> { + conn.load_extension_enable() + .map(|_| LoadExtensionGuard { conn }) + } +} + +#[allow(unused_must_use)] +impl Drop for LoadExtensionGuard<'_> { + #[inline] + fn drop(&mut self) { + self.conn.load_extension_disable(); + } +} diff --git a/src/params.rs b/src/params.rs new file mode 100644 index 0000000..6ab6b5f --- /dev/null +++ b/src/params.rs @@ -0,0 +1,458 @@ +use crate::{Result, Statement, ToSql}; + +mod sealed { + /// This trait exists just to ensure that the only impls of `trait Params` + /// that are allowed are ones in this crate. 
+ pub trait Sealed {} +} +use sealed::Sealed; + +/// Trait used for [sets of parameter][params] passed into SQL +/// statements/queries. +/// +/// [params]: https://www.sqlite.org/c3ref/bind_blob.html +/// +/// Note: Currently, this trait can only be implemented inside this crate. +/// Additionally, it's methods (which are `doc(hidden)`) should currently not be +/// considered part of the stable API, although it's possible they will +/// stabilize in the future. +/// +/// # Passing parameters to SQLite +/// +/// Many functions in this library let you pass parameters to SQLite. Doing this +/// lets you avoid any risk of SQL injection, and is simpler than escaping +/// things manually. Aside from deprecated functions and a few helpers, this is +/// indicated by the function taking a generic argument that implements `Params` +/// (this trait). +/// +/// ## Positional parameters +/// +/// For cases where you want to pass a list of parameters where the number of +/// parameters is known at compile time, this can be done in one of the +/// following ways: +/// +/// - For small lists of parameters up to 16 items, they may alternatively be +/// passed as a tuple, as in `thing.query((1, "foo"))`. +/// +/// This is somewhat inconvenient for a single item, since you need a +/// weird-looking trailing comma: `thing.query(("example",))`. That case is +/// perhaps more cleanly expressed as `thing.query(["example"])`. +/// +/// - Using the [`rusqlite::params!`](crate::params!) macro, e.g. +/// `thing.query(rusqlite::params![1, "foo", bar])`. This is mostly useful for +/// heterogeneous lists where the number of parameters greater than 16, or +/// homogenous lists of parameters where the number of parameters exceeds 32. +/// +/// - For small homogeneous lists of parameters, they can either be passed as: +/// +/// - an array, as in `thing.query([1i32, 2, 3, 4])` or `thing.query(["foo", +/// "bar", "baz"])`. 
+/// +/// - a reference to an array of references, as in `thing.query(&["foo", +/// "bar", "baz"])` or `thing.query(&[&1i32, &2, &3])`. +/// +/// (Note: in this case we don't implement this for slices for coherence +/// reasons, so it really is only for the "reference to array" types — +/// hence why the number of parameters must be <= 32 or you need to +/// reach for `rusqlite::params!`) +/// +/// Unfortunately, in the current design it's not possible to allow this for +/// references to arrays of non-references (e.g. `&[1i32, 2, 3]`). Code like +/// this should instead either use `params!`, an array literal, a `&[&dyn +/// ToSql]` or if none of those work, [`ParamsFromIter`]. +/// +/// - As a slice of `ToSql` trait object references, e.g. `&[&dyn ToSql]`. This +/// is mostly useful for passing parameter lists around as arguments without +/// having every function take a generic `P: Params`. +/// +/// ### Example (positional) +/// +/// ```rust,no_run +/// # use rusqlite::{Connection, Result, params}; +/// fn update_rows(conn: &Connection) -> Result<()> { +/// let mut stmt = conn.prepare("INSERT INTO test (a, b) VALUES (?, ?)")?; +/// +/// // Using a tuple: +/// stmt.execute((0, "foobar"))?; +/// +/// // Using `rusqlite::params!`: +/// stmt.execute(params![1i32, "blah"])?; +/// +/// // array literal — non-references +/// stmt.execute([2i32, 3i32])?; +/// +/// // array literal — references +/// stmt.execute(["foo", "bar"])?; +/// +/// // Slice literal, references: +/// stmt.execute(&[&2i32, &3i32])?; +/// +/// // Note: The types behind the references don't have to be `Sized` +/// stmt.execute(&["foo", "bar"])?; +/// +/// // However, this doesn't work (see above): +/// // stmt.execute(&[1i32, 2i32])?; +/// Ok(()) +/// } +/// ``` +/// +/// ## Named parameters +/// +/// SQLite lets you name parameters using a number of conventions (":foo", +/// "@foo", "$foo"). 
You can pass named parameters in to SQLite using rusqlite +/// in a few ways: +/// +/// - Using the [`rusqlite::named_params!`](crate::named_params!) macro, as in +/// `stmt.execute(named_params!{ ":name": "foo", ":age": 99 })`. Similar to +/// the `params` macro, this is most useful for heterogeneous lists of +/// parameters, or lists where the number of parameters exceeds 32. +/// +/// - As a slice of `&[(&str, &dyn ToSql)]`. This is what essentially all of +/// these boil down to in the end, conceptually at least. In theory you can +/// pass this as `stmt`. +/// +/// - As array references, similar to the positional params. This looks like +/// `thing.query(&[(":foo", &1i32), (":bar", &2i32)])` or +/// `thing.query(&[(":foo", "abc"), (":bar", "def")])`. +/// +/// Note: Unbound named parameters will be left to the value they previously +/// were bound with, falling back to `NULL` for parameters which have never been +/// bound. +/// +/// ### Example (named) +/// +/// ```rust,no_run +/// # use rusqlite::{Connection, Result, named_params}; +/// fn insert(conn: &Connection) -> Result<()> { +/// let mut stmt = conn.prepare("INSERT INTO test (key, value) VALUES (:key, :value)")?; +/// // Using `rusqlite::params!`: +/// stmt.execute(named_params! { ":key": "one", ":val": 2 })?; +/// // Alternatively: +/// stmt.execute(&[(":key", "three"), (":val", "four")])?; +/// // Or: +/// stmt.execute(&[(":key", &100), (":val", &200)])?; +/// Ok(()) +/// } +/// ``` +/// +/// ## No parameters +/// +/// You can just use an empty tuple or the empty array literal to run a query +/// that accepts no parameters. (The `rusqlite::NO_PARAMS` constant which was +/// common in previous versions of this library is no longer needed, and is now +/// deprecated). +/// +/// ### Example (no parameters) +/// +/// The empty tuple: +/// +/// ```rust,no_run +/// # use rusqlite::{Connection, Result, params}; +/// fn delete_all_users(conn: &Connection) -> Result<()> { +/// // You may also use `()`. 
+/// conn.execute("DELETE FROM users", ())?; +/// Ok(()) +/// } +/// ``` +/// +/// The empty array: +/// +/// ```rust,no_run +/// # use rusqlite::{Connection, Result, params}; +/// fn delete_all_users(conn: &Connection) -> Result<()> { +/// // Just use an empty array (e.g. `[]`) for no params. +/// conn.execute("DELETE FROM users", [])?; +/// Ok(()) +/// } +/// ``` +/// +/// ## Dynamic parameter list +/// +/// If you have a number of parameters which is unknown at compile time (for +/// example, building a dynamic query at runtime), you have two choices: +/// +/// - Use a `&[&dyn ToSql]`. This is often annoying to construct if you don't +/// already have this type on-hand. +/// - Use the [`ParamsFromIter`] type. This essentially lets you wrap an +/// iterator some `T: ToSql` with something that implements `Params`. The +/// usage of this looks like `rusqlite::params_from_iter(something)`. +/// +/// A lot of the considerations here are similar either way, so you should see +/// the [`ParamsFromIter`] documentation for more info / examples. +pub trait Params: Sealed { + // XXX not public api, might not need to expose. + // + // Binds the parameters to the statement. It is unlikely calling this + // explicitly will do what you want. Please use `Statement::query` or + // similar directly. + // + // For now, just hide the function in the docs... + #[doc(hidden)] + fn __bind_in(self, stmt: &mut Statement<'_>) -> Result<()>; +} + +// Explicitly impl for empty array. Critically, for `conn.execute([])` to be +// unambiguous, this must be the *only* implementation for an empty array. This +// avoids `NO_PARAMS` being a necessary part of the API. +// +// This sadly prevents `impl Params for [T; N]`, which +// forces people to use `params![...]` or `rusqlite::params_from_iter` for long +// homogenous lists of parameters. 
This is not that big of a deal, but is +// unfortunate, especially because I mostly did it because I wanted a simple +// syntax for no-params that didnt require importing -- the empty tuple fits +// that nicely, but I didn't think of it until much later. +// +// Admittedly, if we did have the generic impl, then we *wouldn't* support the +// empty array literal as a parameter, since the `T` there would fail to be +// inferred. The error message here would probably be quite bad, and so on +// further thought, probably would end up causing *more* surprises, not less. +impl Sealed for [&(dyn ToSql + Send + Sync); 0] {} +impl Params for [&(dyn ToSql + Send + Sync); 0] { + #[inline] + fn __bind_in(self, stmt: &mut Statement<'_>) -> Result<()> { + stmt.ensure_parameter_count(0) + } +} + +impl Sealed for &[&dyn ToSql] {} +impl Params for &[&dyn ToSql] { + #[inline] + fn __bind_in(self, stmt: &mut Statement<'_>) -> Result<()> { + stmt.bind_parameters(self) + } +} + +impl Sealed for &[(&str, &dyn ToSql)] {} +impl Params for &[(&str, &dyn ToSql)] { + #[inline] + fn __bind_in(self, stmt: &mut Statement<'_>) -> Result<()> { + stmt.bind_parameters_named(self) + } +} + +// Manual impls for the empty and singleton tuple, although the rest are covered +// by macros. +impl Sealed for () {} +impl Params for () { + #[inline] + fn __bind_in(self, stmt: &mut Statement<'_>) -> Result<()> { + stmt.ensure_parameter_count(0) + } +} + +// I'm pretty sure you could tweak the `single_tuple_impl` to accept this. +impl Sealed for (T,) {} +impl Params for (T,) { + #[inline] + fn __bind_in(self, stmt: &mut Statement<'_>) -> Result<()> { + stmt.ensure_parameter_count(1)?; + stmt.raw_bind_parameter(1, self.0)?; + Ok(()) + } +} + +macro_rules! single_tuple_impl { + ($count:literal : $(($field:tt $ftype:ident)),* $(,)?) 
=> { + impl<$($ftype,)*> Sealed for ($($ftype,)*) where $($ftype: ToSql,)* {} + impl<$($ftype,)*> Params for ($($ftype,)*) where $($ftype: ToSql,)* { + fn __bind_in(self, stmt: &mut Statement<'_>) -> Result<()> { + stmt.ensure_parameter_count($count)?; + $({ + debug_assert!($field < $count); + stmt.raw_bind_parameter($field + 1, self.$field)?; + })+ + Ok(()) + } + } + } +} + +// We use a the macro for the rest, but don't bother with trying to implement it +// in a single invocation (it's possible to do, but my attempts were almost the +// same amount of code as just writing it out this way, and much more dense -- +// it is a more complicated case than the TryFrom macro we have for row->tuple). +// +// Note that going up to 16 (rather than the 12 that the impls in the stdlib +// usually support) is just because we did the same in the `TryFrom` impl. +// I didn't catch that then, but there's no reason to remove it, and it seems +// nice to be consistent here; this way putting data in the database and getting +// data out of the database are more symmetric in a (mostly superficial) sense. 
+single_tuple_impl!(2: (0 A), (1 B)); +single_tuple_impl!(3: (0 A), (1 B), (2 C)); +single_tuple_impl!(4: (0 A), (1 B), (2 C), (3 D)); +single_tuple_impl!(5: (0 A), (1 B), (2 C), (3 D), (4 E)); +single_tuple_impl!(6: (0 A), (1 B), (2 C), (3 D), (4 E), (5 F)); +single_tuple_impl!(7: (0 A), (1 B), (2 C), (3 D), (4 E), (5 F), (6 G)); +single_tuple_impl!(8: (0 A), (1 B), (2 C), (3 D), (4 E), (5 F), (6 G), (7 H)); +single_tuple_impl!(9: (0 A), (1 B), (2 C), (3 D), (4 E), (5 F), (6 G), (7 H), (8 I)); +single_tuple_impl!(10: (0 A), (1 B), (2 C), (3 D), (4 E), (5 F), (6 G), (7 H), (8 I), (9 J)); +single_tuple_impl!(11: (0 A), (1 B), (2 C), (3 D), (4 E), (5 F), (6 G), (7 H), (8 I), (9 J), (10 K)); +single_tuple_impl!(12: (0 A), (1 B), (2 C), (3 D), (4 E), (5 F), (6 G), (7 H), (8 I), (9 J), (10 K), (11 L)); +single_tuple_impl!(13: (0 A), (1 B), (2 C), (3 D), (4 E), (5 F), (6 G), (7 H), (8 I), (9 J), (10 K), (11 L), (12 M)); +single_tuple_impl!(14: (0 A), (1 B), (2 C), (3 D), (4 E), (5 F), (6 G), (7 H), (8 I), (9 J), (10 K), (11 L), (12 M), (13 N)); +single_tuple_impl!(15: (0 A), (1 B), (2 C), (3 D), (4 E), (5 F), (6 G), (7 H), (8 I), (9 J), (10 K), (11 L), (12 M), (13 N), (14 O)); +single_tuple_impl!(16: (0 A), (1 B), (2 C), (3 D), (4 E), (5 F), (6 G), (7 H), (8 I), (9 J), (10 K), (11 L), (12 M), (13 N), (14 O), (15 P)); + +macro_rules! impl_for_array_ref { + ($($N:literal)+) => {$( + // These are already generic, and there's a shedload of them, so lets + // avoid the compile time hit from making them all inline for now. 
+ impl Sealed for &[&T; $N] {} + impl Params for &[&T; $N] { + fn __bind_in(self, stmt: &mut Statement<'_>) -> Result<()> { + stmt.bind_parameters(self) + } + } + impl Sealed for &[(&str, &T); $N] {} + impl Params for &[(&str, &T); $N] { + fn __bind_in(self, stmt: &mut Statement<'_>) -> Result<()> { + stmt.bind_parameters_named(self) + } + } + impl Sealed for [T; $N] {} + impl Params for [T; $N] { + #[inline] + fn __bind_in(self, stmt: &mut Statement<'_>) -> Result<()> { + stmt.bind_parameters(&self) + } + } + )+}; +} + +// Following libstd/libcore's (old) lead, implement this for arrays up to `[_; +// 32]`. Note `[_; 0]` is intentionally omitted for coherence reasons, see the +// note above the impl of `[&dyn ToSql; 0]` for more information. +// +// Note that this unfortunately means we can't use const generics here, but I +// don't really think it matters -- users who hit that can use `params!` anyway. +impl_for_array_ref!( + 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 + 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 +); + +/// Adapter type which allows any iterator over [`ToSql`] values to implement +/// [`Params`]. +/// +/// This struct is created by the [`params_from_iter`] function. +/// +/// This can be useful if you have something like an `&[String]` (of unknown +/// length), and you want to use them with an API that wants something +/// implementing `Params`. This way, you can avoid having to allocate storage +/// for something like a `&[&dyn ToSql]`. +/// +/// This essentially is only ever actually needed when dynamically generating +/// SQL — static SQL (by definition) has the number of parameters known +/// statically. As dynamically generating SQL is itself pretty advanced, this +/// API is itself for advanced use cases (See "Realistic use case" in the +/// examples). 
+/// +/// # Example +/// +/// ## Basic usage +/// +/// ```rust,no_run +/// use rusqlite::{params_from_iter, Connection, Result}; +/// use std::collections::BTreeSet; +/// +/// fn query(conn: &Connection, ids: &BTreeSet) -> Result<()> { +/// assert_eq!(ids.len(), 3, "Unrealistic sample code"); +/// +/// let mut stmt = conn.prepare("SELECT * FROM users WHERE id IN (?, ?, ?)")?; +/// let _rows = stmt.query(params_from_iter(ids.iter()))?; +/// +/// // use _rows... +/// Ok(()) +/// } +/// ``` +/// +/// ## Realistic use case +/// +/// Here's how you'd use `ParamsFromIter` to call [`Statement::exists`] with a +/// dynamic number of parameters. +/// +/// ```rust,no_run +/// use rusqlite::{Connection, Result}; +/// +/// pub fn any_active_users(conn: &Connection, usernames: &[String]) -> Result { +/// if usernames.is_empty() { +/// return Ok(false); +/// } +/// +/// // Note: `repeat_vars` never returns anything attacker-controlled, so +/// // it's fine to use it in a dynamically-built SQL string. +/// let vars = repeat_vars(usernames.len()); +/// +/// let sql = format!( +/// // In practice this would probably be better as an `EXISTS` query. +/// "SELECT 1 FROM user WHERE is_active AND name IN ({}) LIMIT 1", +/// vars, +/// ); +/// let mut stmt = conn.prepare(&sql)?; +/// stmt.exists(rusqlite::params_from_iter(usernames)) +/// } +/// +/// // Helper function to return a comma-separated sequence of `?`. +/// // - `repeat_vars(0) => panic!(...)` +/// // - `repeat_vars(1) => "?"` +/// // - `repeat_vars(2) => "?,?"` +/// // - `repeat_vars(3) => "?,?,?"` +/// // - ... 
+/// fn repeat_vars(count: usize) -> String { +/// assert_ne!(count, 0); +/// let mut s = "?,".repeat(count); +/// // Remove trailing comma +/// s.pop(); +/// s +/// } +/// ``` +/// +/// That is fairly complex, and even so would need even more work to be fully +/// production-ready: +/// +/// - production code should ensure `usernames` isn't so large that it will +/// surpass [`conn.limit(Limit::SQLITE_LIMIT_VARIABLE_NUMBER)`][limits]), +/// chunking if too large. (Note that the limits api requires rusqlite to have +/// the "limits" feature). +/// +/// - `repeat_vars` can be implemented in a way that avoids needing to allocate +/// a String. +/// +/// - Etc... +/// +/// [limits]: crate::Connection::limit +/// +/// This complexity reflects the fact that `ParamsFromIter` is mainly intended +/// for advanced use cases — most of the time you should know how many +/// parameters you have statically (and if you don't, you're either doing +/// something tricky, or should take a moment to think about the design). +#[derive(Clone, Debug)] +pub struct ParamsFromIter(I); + +/// Constructor function for a [`ParamsFromIter`]. See its documentation for +/// more. +#[inline] +pub fn params_from_iter(iter: I) -> ParamsFromIter +where + I: IntoIterator, + I::Item: ToSql, +{ + ParamsFromIter(iter) +} + +impl Sealed for ParamsFromIter +where + I: IntoIterator, + I::Item: ToSql, +{ +} + +impl Params for ParamsFromIter +where + I: IntoIterator, + I::Item: ToSql, +{ + #[inline] + fn __bind_in(self, stmt: &mut Statement<'_>) -> Result<()> { + stmt.bind_parameters(self.0) + } +} diff --git a/src/pragma.rs b/src/pragma.rs new file mode 100644 index 0000000..673478a --- /dev/null +++ b/src/pragma.rs @@ -0,0 +1,459 @@ +//! 
Pragma helpers + +use std::ops::Deref; + +use crate::error::Error; +use crate::ffi; +use crate::types::{ToSql, ToSqlOutput, ValueRef}; +use crate::{Connection, DatabaseName, Result, Row}; + +pub struct Sql { + buf: String, +} + +impl Sql { + pub fn new() -> Sql { + Sql { buf: String::new() } + } + + pub fn push_pragma( + &mut self, + schema_name: Option>, + pragma_name: &str, + ) -> Result<()> { + self.push_keyword("PRAGMA")?; + self.push_space(); + if let Some(schema_name) = schema_name { + self.push_schema_name(schema_name); + self.push_dot(); + } + self.push_keyword(pragma_name) + } + + pub fn push_keyword(&mut self, keyword: &str) -> Result<()> { + if !keyword.is_empty() && is_identifier(keyword) { + self.buf.push_str(keyword); + Ok(()) + } else { + Err(Error::SqliteFailure( + ffi::Error::new(ffi::SQLITE_MISUSE), + Some(format!("Invalid keyword \"{}\"", keyword)), + )) + } + } + + pub fn push_schema_name(&mut self, schema_name: DatabaseName<'_>) { + match schema_name { + DatabaseName::Main => self.buf.push_str("main"), + DatabaseName::Temp => self.buf.push_str("temp"), + DatabaseName::Attached(s) => self.push_identifier(s), + }; + } + + pub fn push_identifier(&mut self, s: &str) { + if is_identifier(s) { + self.buf.push_str(s); + } else { + self.wrap_and_escape(s, '"'); + } + } + + pub fn push_value(&mut self, value: &dyn ToSql) -> Result<()> { + let value = value.to_sql()?; + let value = match value { + ToSqlOutput::Borrowed(v) => v, + ToSqlOutput::Owned(ref v) => ValueRef::from(v), + #[cfg(feature = "blob")] + ToSqlOutput::ZeroBlob(_) => { + return Err(Error::SqliteFailure( + ffi::Error::new(ffi::SQLITE_MISUSE), + Some(format!("Unsupported value \"{:?}\"", value)), + )); + } + #[cfg(feature = "array")] + ToSqlOutput::Array(_) => { + return Err(Error::SqliteFailure( + ffi::Error::new(ffi::SQLITE_MISUSE), + Some(format!("Unsupported value \"{:?}\"", value)), + )); + } + }; + match value { + ValueRef::Integer(i) => { + self.push_int(i); + } + ValueRef::Real(r) 
=> { + self.push_real(r); + } + ValueRef::Text(s) => { + let s = std::str::from_utf8(s)?; + self.push_string_literal(s); + } + _ => { + return Err(Error::SqliteFailure( + ffi::Error::new(ffi::SQLITE_MISUSE), + Some(format!("Unsupported value \"{:?}\"", value)), + )); + } + }; + Ok(()) + } + + pub fn push_string_literal(&mut self, s: &str) { + self.wrap_and_escape(s, '\''); + } + + pub fn push_int(&mut self, i: i64) { + self.buf.push_str(&i.to_string()); + } + + pub fn push_real(&mut self, f: f64) { + self.buf.push_str(&f.to_string()); + } + + pub fn push_space(&mut self) { + self.buf.push(' '); + } + + pub fn push_dot(&mut self) { + self.buf.push('.'); + } + + pub fn push_equal_sign(&mut self) { + self.buf.push('='); + } + + pub fn open_brace(&mut self) { + self.buf.push('('); + } + + pub fn close_brace(&mut self) { + self.buf.push(')'); + } + + pub fn as_str(&self) -> &str { + &self.buf + } + + fn wrap_and_escape(&mut self, s: &str, quote: char) { + self.buf.push(quote); + let chars = s.chars(); + for ch in chars { + // escape `quote` by doubling it + if ch == quote { + self.buf.push(ch); + } + self.buf.push(ch); + } + self.buf.push(quote); + } +} + +impl Deref for Sql { + type Target = str; + + fn deref(&self) -> &str { + self.as_str() + } +} + +impl Connection { + /// Query the current value of `pragma_name`. + /// + /// Some pragmas will return multiple rows/values which cannot be retrieved + /// with this method. + /// + /// Prefer [PRAGMA function](https://sqlite.org/pragma.html#pragfunc) introduced in SQLite 3.20: + /// `SELECT user_version FROM pragma_user_version;` + pub fn pragma_query_value( + &self, + schema_name: Option>, + pragma_name: &str, + f: F, + ) -> Result + where + F: FnOnce(&Row<'_>) -> Result, + { + let mut query = Sql::new(); + query.push_pragma(schema_name, pragma_name)?; + self.query_row(&query, [], f) + } + + /// Query the current rows/values of `pragma_name`. 
+ /// + /// Prefer [PRAGMA function](https://sqlite.org/pragma.html#pragfunc) introduced in SQLite 3.20: + /// `SELECT * FROM pragma_collation_list;` + pub fn pragma_query( + &self, + schema_name: Option>, + pragma_name: &str, + mut f: F, + ) -> Result<()> + where + F: FnMut(&Row<'_>) -> Result<()>, + { + let mut query = Sql::new(); + query.push_pragma(schema_name, pragma_name)?; + let mut stmt = self.prepare(&query)?; + let mut rows = stmt.query([])?; + while let Some(result_row) = rows.next()? { + let row = result_row; + f(row)?; + } + Ok(()) + } + + /// Query the current value(s) of `pragma_name` associated to + /// `pragma_value`. + /// + /// This method can be used with query-only pragmas which need an argument + /// (e.g. `table_info('one_tbl')`) or pragmas which returns value(s) + /// (e.g. `integrity_check`). + /// + /// Prefer [PRAGMA function](https://sqlite.org/pragma.html#pragfunc) introduced in SQLite 3.20: + /// `SELECT * FROM pragma_table_info(?);` + pub fn pragma( + &self, + schema_name: Option>, + pragma_name: &str, + pragma_value: V, + mut f: F, + ) -> Result<()> + where + F: FnMut(&Row<'_>) -> Result<()>, + V: ToSql, + { + let mut sql = Sql::new(); + sql.push_pragma(schema_name, pragma_name)?; + // The argument may be either in parentheses + // or it may be separated from the pragma name by an equal sign. + // The two syntaxes yield identical results. + sql.open_brace(); + sql.push_value(&pragma_value)?; + sql.close_brace(); + let mut stmt = self.prepare(&sql)?; + let mut rows = stmt.query([])?; + while let Some(result_row) = rows.next()? { + let row = result_row; + f(row)?; + } + Ok(()) + } + + /// Set a new value to `pragma_name`. + /// + /// Some pragmas will return the updated value which cannot be retrieved + /// with this method. 
+ pub fn pragma_update( + &self, + schema_name: Option>, + pragma_name: &str, + pragma_value: V, + ) -> Result<()> + where + V: ToSql, + { + let mut sql = Sql::new(); + sql.push_pragma(schema_name, pragma_name)?; + // The argument may be either in parentheses + // or it may be separated from the pragma name by an equal sign. + // The two syntaxes yield identical results. + sql.push_equal_sign(); + sql.push_value(&pragma_value)?; + self.execute_batch(&sql) + } + + /// Set a new value to `pragma_name` and return the updated value. + /// + /// Only few pragmas automatically return the updated value. + pub fn pragma_update_and_check( + &self, + schema_name: Option>, + pragma_name: &str, + pragma_value: V, + f: F, + ) -> Result + where + F: FnOnce(&Row<'_>) -> Result, + V: ToSql, + { + let mut sql = Sql::new(); + sql.push_pragma(schema_name, pragma_name)?; + // The argument may be either in parentheses + // or it may be separated from the pragma name by an equal sign. + // The two syntaxes yield identical results. 
+ sql.push_equal_sign(); + sql.push_value(&pragma_value)?; + self.query_row(&sql, [], f) + } +} + +fn is_identifier(s: &str) -> bool { + let chars = s.char_indices(); + for (i, ch) in chars { + if i == 0 { + if !is_identifier_start(ch) { + return false; + } + } else if !is_identifier_continue(ch) { + return false; + } + } + true +} + +fn is_identifier_start(c: char) -> bool { + ('A'..='Z').contains(&c) || c == '_' || ('a'..='z').contains(&c) || c > '\x7F' +} + +fn is_identifier_continue(c: char) -> bool { + c == '$' + || ('0'..='9').contains(&c) + || ('A'..='Z').contains(&c) + || c == '_' + || ('a'..='z').contains(&c) + || c > '\x7F' +} + +#[cfg(test)] +mod test { + use super::Sql; + use crate::pragma; + use crate::{Connection, DatabaseName, Result}; + + #[test] + fn pragma_query_value() -> Result<()> { + let db = Connection::open_in_memory()?; + let user_version: i32 = db.pragma_query_value(None, "user_version", |row| row.get(0))?; + assert_eq!(0, user_version); + Ok(()) + } + + #[test] + #[cfg(feature = "modern_sqlite")] + fn pragma_func_query_value() -> Result<()> { + let db = Connection::open_in_memory()?; + let user_version: i32 = + db.query_row("SELECT user_version FROM pragma_user_version", [], |row| { + row.get(0) + })?; + assert_eq!(0, user_version); + Ok(()) + } + + #[test] + fn pragma_query_no_schema() -> Result<()> { + let db = Connection::open_in_memory()?; + let mut user_version = -1; + db.pragma_query(None, "user_version", |row| { + user_version = row.get(0)?; + Ok(()) + })?; + assert_eq!(0, user_version); + Ok(()) + } + + #[test] + fn pragma_query_with_schema() -> Result<()> { + let db = Connection::open_in_memory()?; + let mut user_version = -1; + db.pragma_query(Some(DatabaseName::Main), "user_version", |row| { + user_version = row.get(0)?; + Ok(()) + })?; + assert_eq!(0, user_version); + Ok(()) + } + + #[test] + fn pragma() -> Result<()> { + let db = Connection::open_in_memory()?; + let mut columns = Vec::new(); + db.pragma(None, "table_info", 
&"sqlite_master", |row| { + let column: String = row.get(1)?; + columns.push(column); + Ok(()) + })?; + assert_eq!(5, columns.len()); + Ok(()) + } + + #[test] + #[cfg(feature = "modern_sqlite")] + fn pragma_func() -> Result<()> { + let db = Connection::open_in_memory()?; + let mut table_info = db.prepare("SELECT * FROM pragma_table_info(?)")?; + let mut columns = Vec::new(); + let mut rows = table_info.query(["sqlite_master"])?; + + while let Some(row) = rows.next()? { + let row = row; + let column: String = row.get(1)?; + columns.push(column); + } + assert_eq!(5, columns.len()); + Ok(()) + } + + #[test] + fn pragma_update() -> Result<()> { + let db = Connection::open_in_memory()?; + db.pragma_update(None, "user_version", 1) + } + + #[test] + fn pragma_update_and_check() -> Result<()> { + let db = Connection::open_in_memory()?; + let journal_mode: String = + db.pragma_update_and_check(None, "journal_mode", "OFF", |row| row.get(0))?; + assert!( + journal_mode == "off" || journal_mode == "memory", + "mode: {:?}", + journal_mode, + ); + // Sanity checks to ensure the move to a generic `ToSql` wasn't breaking + let mode = db + .pragma_update_and_check(None, "journal_mode", &"OFF", |row| row.get::<_, String>(0))?; + assert!(mode == "off" || mode == "memory", "mode: {:?}", mode); + + let param: &dyn crate::ToSql = &"OFF"; + let mode = + db.pragma_update_and_check(None, "journal_mode", param, |row| row.get::<_, String>(0))?; + assert!(mode == "off" || mode == "memory", "mode: {:?}", mode); + Ok(()) + } + + #[test] + fn is_identifier() { + assert!(pragma::is_identifier("full")); + assert!(pragma::is_identifier("r2d2")); + assert!(!pragma::is_identifier("sp ce")); + assert!(!pragma::is_identifier("semi;colon")); + } + + #[test] + fn double_quote() { + let mut sql = Sql::new(); + sql.push_schema_name(DatabaseName::Attached(r#"schema";--"#)); + assert_eq!(r#""schema"";--""#, sql.as_str()); + } + + #[test] + fn wrap_and_escape() { + let mut sql = Sql::new(); + 
sql.push_string_literal("value'; --"); + assert_eq!("'value''; --'", sql.as_str()); + } + + #[test] + fn locking_mode() -> Result<()> { + let db = Connection::open_in_memory()?; + let r = db.pragma_update(None, "locking_mode", &"exclusive"); + if cfg!(feature = "extra_check") { + r.unwrap_err(); + } else { + r?; + } + Ok(()) + } +} diff --git a/src/raw_statement.rs b/src/raw_statement.rs new file mode 100644 index 0000000..f057761 --- /dev/null +++ b/src/raw_statement.rs @@ -0,0 +1,241 @@ +use super::ffi; +use super::StatementStatus; +use crate::util::ParamIndexCache; +#[cfg(feature = "modern_sqlite")] +use crate::util::SqliteMallocString; +use std::ffi::CStr; +use std::os::raw::c_int; +use std::ptr; +use std::sync::Arc; + +// Private newtype for raw sqlite3_stmts that finalize themselves when dropped. +#[derive(Debug)] +pub struct RawStatement { + ptr: *mut ffi::sqlite3_stmt, + tail: usize, + // Cached indices of named parameters, computed on the fly. + cache: ParamIndexCache, + // Cached SQL (trimmed) that we use as the key when we're in the statement + // cache. This is None for statements which didn't come from the statement + // cache. + // + // This is probably the same as `self.sql()` in most cases, but we don't + // care either way -- It's a better cache key as it is anyway since it's the + // actual source we got from rust. + // + // One example of a case where the result of `sqlite_sql` and the value in + // `statement_cache_key` might differ is if the statement has a `tail`. 
+ statement_cache_key: Option>, +} + +impl RawStatement { + #[inline] + pub unsafe fn new(stmt: *mut ffi::sqlite3_stmt, tail: usize) -> RawStatement { + RawStatement { + ptr: stmt, + tail, + cache: ParamIndexCache::default(), + statement_cache_key: None, + } + } + + #[inline] + pub fn is_null(&self) -> bool { + self.ptr.is_null() + } + + #[inline] + pub(crate) fn set_statement_cache_key(&mut self, p: impl Into>) { + self.statement_cache_key = Some(p.into()); + } + + #[inline] + pub(crate) fn statement_cache_key(&self) -> Option> { + self.statement_cache_key.clone() + } + + #[inline] + pub unsafe fn ptr(&self) -> *mut ffi::sqlite3_stmt { + self.ptr + } + + #[inline] + pub fn column_count(&self) -> usize { + // Note: Can't cache this as it changes if the schema is altered. + unsafe { ffi::sqlite3_column_count(self.ptr) as usize } + } + + #[inline] + pub fn column_type(&self, idx: usize) -> c_int { + unsafe { ffi::sqlite3_column_type(self.ptr, idx as c_int) } + } + + #[inline] + #[cfg(feature = "column_decltype")] + pub fn column_decltype(&self, idx: usize) -> Option<&CStr> { + unsafe { + let decltype = ffi::sqlite3_column_decltype(self.ptr, idx as c_int); + if decltype.is_null() { + None + } else { + Some(CStr::from_ptr(decltype)) + } + } + } + + #[inline] + pub fn column_name(&self, idx: usize) -> Option<&CStr> { + let idx = idx as c_int; + if idx < 0 || idx >= self.column_count() as c_int { + return None; + } + unsafe { + let ptr = ffi::sqlite3_column_name(self.ptr, idx); + // If ptr is null here, it's an OOM, so there's probably nothing + // meaningful we can do. Just assert instead of returning None. + assert!( + !ptr.is_null(), + "Null pointer from sqlite3_column_name: Out of memory?" 
+ ); + Some(CStr::from_ptr(ptr)) + } + } + + #[inline] + #[cfg(not(feature = "unlock_notify"))] + pub fn step(&self) -> c_int { + unsafe { ffi::sqlite3_step(self.ptr) } + } + + #[cfg(feature = "unlock_notify")] + pub fn step(&self) -> c_int { + use crate::unlock_notify; + let mut db = ptr::null_mut::(); + loop { + unsafe { + let mut rc = ffi::sqlite3_step(self.ptr); + // Bail out early for success and errors unrelated to locking. We + // still need check `is_locked` after this, but checking now lets us + // avoid one or two (admittedly cheap) calls into SQLite that we + // don't need to make. + if (rc & 0xff) != ffi::SQLITE_LOCKED { + break rc; + } + if db.is_null() { + db = ffi::sqlite3_db_handle(self.ptr); + } + if !unlock_notify::is_locked(db, rc) { + break rc; + } + rc = unlock_notify::wait_for_unlock_notify(db); + if rc != ffi::SQLITE_OK { + break rc; + } + self.reset(); + } + } + } + + #[inline] + pub fn reset(&self) -> c_int { + unsafe { ffi::sqlite3_reset(self.ptr) } + } + + #[inline] + pub fn bind_parameter_count(&self) -> usize { + unsafe { ffi::sqlite3_bind_parameter_count(self.ptr) as usize } + } + + #[inline] + pub fn bind_parameter_index(&self, name: &str) -> Option { + self.cache.get_or_insert_with(name, |param_cstr| { + let r = unsafe { ffi::sqlite3_bind_parameter_index(self.ptr, param_cstr.as_ptr()) }; + match r { + 0 => None, + i => Some(i as usize), + } + }) + } + + #[inline] + pub fn bind_parameter_name(&self, index: i32) -> Option<&CStr> { + unsafe { + let name = ffi::sqlite3_bind_parameter_name(self.ptr, index); + if name.is_null() { + None + } else { + Some(CStr::from_ptr(name)) + } + } + } + + #[inline] + pub fn clear_bindings(&self) -> c_int { + unsafe { ffi::sqlite3_clear_bindings(self.ptr) } + } + + #[inline] + pub fn sql(&self) -> Option<&CStr> { + if self.ptr.is_null() { + None + } else { + Some(unsafe { CStr::from_ptr(ffi::sqlite3_sql(self.ptr)) }) + } + } + + #[inline] + pub fn finalize(mut self) -> c_int { + self.finalize_() + } + + 
#[inline] + fn finalize_(&mut self) -> c_int { + let r = unsafe { ffi::sqlite3_finalize(self.ptr) }; + self.ptr = ptr::null_mut(); + r + } + + // does not work for PRAGMA + #[inline] + #[cfg(all(feature = "extra_check", feature = "modern_sqlite"))] // 3.7.4 + pub fn readonly(&self) -> bool { + unsafe { ffi::sqlite3_stmt_readonly(self.ptr) != 0 } + } + + #[inline] + #[cfg(feature = "modern_sqlite")] // 3.14.0 + pub(crate) fn expanded_sql(&self) -> Option { + unsafe { SqliteMallocString::from_raw(ffi::sqlite3_expanded_sql(self.ptr)) } + } + + #[inline] + pub fn get_status(&self, status: StatementStatus, reset: bool) -> i32 { + assert!(!self.ptr.is_null()); + unsafe { ffi::sqlite3_stmt_status(self.ptr, status as i32, reset as i32) } + } + + #[inline] + #[cfg(feature = "extra_check")] + pub fn has_tail(&self) -> bool { + self.tail != 0 + } + + #[inline] + pub fn tail(&self) -> usize { + self.tail + } + + #[inline] + #[cfg(feature = "modern_sqlite")] // 3.28.0 + pub fn is_explain(&self) -> i32 { + unsafe { ffi::sqlite3_stmt_isexplain(self.ptr) } + } + + // TODO sqlite3_normalized_sql (https://sqlite.org/c3ref/expanded_sql.html) // 3.27.0 + SQLITE_ENABLE_NORMALIZE +} + +impl Drop for RawStatement { + fn drop(&mut self) { + self.finalize_(); + } +} diff --git a/src/row.rs b/src/row.rs new file mode 100644 index 0000000..221905a --- /dev/null +++ b/src/row.rs @@ -0,0 +1,559 @@ +use fallible_iterator::FallibleIterator; +use fallible_streaming_iterator::FallibleStreamingIterator; +use std::convert; + +use super::{Error, Result, Statement}; +use crate::types::{FromSql, FromSqlError, ValueRef}; + +/// An handle for the resulting rows of a query. 
+#[must_use = "Rows is lazy and will do nothing unless consumed"] +pub struct Rows<'stmt> { + pub(crate) stmt: Option<&'stmt Statement<'stmt>>, + row: Option>, +} + +impl<'stmt> Rows<'stmt> { + #[inline] + fn reset(&mut self) { + if let Some(stmt) = self.stmt.take() { + stmt.reset(); + } + } + + /// Attempt to get the next row from the query. Returns `Ok(Some(Row))` if + /// there is another row, `Err(...)` if there was an error + /// getting the next row, and `Ok(None)` if all rows have been retrieved. + /// + /// ## Note + /// + /// This interface is not compatible with Rust's `Iterator` trait, because + /// the lifetime of the returned row is tied to the lifetime of `self`. + /// This is a fallible "streaming iterator". For a more natural interface, + /// consider using [`query_map`](crate::Statement::query_map) or + /// [`query_and_then`](crate::Statement::query_and_then) instead, which + /// return types that implement `Iterator`. + #[allow(clippy::should_implement_trait)] // cannot implement Iterator + #[inline] + pub fn next(&mut self) -> Result>> { + self.advance()?; + Ok((*self).get()) + } + + /// Map over this `Rows`, converting it to a [`Map`], which + /// implements `FallibleIterator`. + /// ```rust,no_run + /// use fallible_iterator::FallibleIterator; + /// # use rusqlite::{Result, Statement}; + /// fn query(stmt: &mut Statement) -> Result> { + /// let rows = stmt.query([])?; + /// rows.map(|r| r.get(0)).collect() + /// } + /// ``` + // FIXME Hide FallibleStreamingIterator::map + #[inline] + pub fn map(self, f: F) -> Map<'stmt, F> + where + F: FnMut(&Row<'_>) -> Result, + { + Map { rows: self, f } + } + + /// Map over this `Rows`, converting it to a [`MappedRows`], which + /// implements `Iterator`. 
+ #[inline] + pub fn mapped(self, f: F) -> MappedRows<'stmt, F> + where + F: FnMut(&Row<'_>) -> Result, + { + MappedRows { rows: self, map: f } + } + + /// Map over this `Rows` with a fallible function, converting it to a + /// [`AndThenRows`], which implements `Iterator` (instead of + /// `FallibleStreamingIterator`). + #[inline] + pub fn and_then(self, f: F) -> AndThenRows<'stmt, F> + where + F: FnMut(&Row<'_>) -> Result, + { + AndThenRows { rows: self, map: f } + } + + /// Give access to the underlying statement + #[must_use] + pub fn as_ref(&self) -> Option<&Statement<'stmt>> { + self.stmt + } +} + +impl<'stmt> Rows<'stmt> { + #[inline] + pub(crate) fn new(stmt: &'stmt Statement<'stmt>) -> Rows<'stmt> { + Rows { + stmt: Some(stmt), + row: None, + } + } + + #[inline] + pub(crate) fn get_expected_row(&mut self) -> Result<&Row<'stmt>> { + match self.next()? { + Some(row) => Ok(row), + None => Err(Error::QueryReturnedNoRows), + } + } +} + +impl Drop for Rows<'_> { + #[inline] + fn drop(&mut self) { + self.reset(); + } +} + +/// `F` is used to transform the _streaming_ iterator into a _fallible_ +/// iterator. +#[must_use = "iterators are lazy and do nothing unless consumed"] +pub struct Map<'stmt, F> { + rows: Rows<'stmt>, + f: F, +} + +impl FallibleIterator for Map<'_, F> +where + F: FnMut(&Row<'_>) -> Result, +{ + type Error = Error; + type Item = B; + + #[inline] + fn next(&mut self) -> Result> { + match self.rows.next()? { + Some(v) => Ok(Some((self.f)(v)?)), + None => Ok(None), + } + } +} + +/// An iterator over the mapped resulting rows of a query. +/// +/// `F` is used to transform the _streaming_ iterator into a _standard_ +/// iterator. 
+#[must_use = "iterators are lazy and do nothing unless consumed"]
+pub struct MappedRows<'stmt, F> {
+    rows: Rows<'stmt>,
+    map: F,
+}
+
+impl<T, F> Iterator for MappedRows<'_, F>
+where
+    F: FnMut(&Row<'_>) -> Result<T>,
+{
+    type Item = Result<T>;
+
+    #[inline]
+    fn next(&mut self) -> Option<Result<T>> {
+        let map = &mut self.map;
+        self.rows
+            .next()
+            .transpose()
+            .map(|row_result| row_result.and_then(map))
+    }
+}
+
+/// An iterator over the mapped resulting rows of a query, whose error type
+/// `E` can be built from [`Error`] (`E: From<Error>`), so that stepping
+/// failures and mapping failures are reported through the same type.
+#[must_use = "iterators are lazy and do nothing unless consumed"]
+pub struct AndThenRows<'stmt, F> {
+    rows: Rows<'stmt>,
+    map: F,
+}
+
+impl<T, E, F> Iterator for AndThenRows<'_, F>
+where
+    E: From<Error>,
+    F: FnMut(&Row<'_>) -> Result<T, E>,
+{
+    type Item = Result<T, E>;
+
+    #[inline]
+    fn next(&mut self) -> Option<Self::Item> {
+        let map = &mut self.map;
+        self.rows
+            .next()
+            .transpose()
+            .map(|row_result| row_result.map_err(E::from).and_then(map))
+    }
+}
+
+/// `FallibleStreamingIterator` differs from the standard library's `Iterator`
+/// in two ways:
+/// * each call to `next` (`sqlite3_step`) can fail.
+/// * returned `Row` is valid until `next` is called again or `Statement` is
+///   reset or finalized.
+///
+/// While these iterators cannot be used with Rust `for` loops, `while let`
+/// loops offer a similar level of ergonomics:
+/// ```rust,no_run
+/// # use rusqlite::{Result, Statement};
+/// fn query(stmt: &mut Statement) -> Result<()> {
+///     let mut rows = stmt.query([])?;
+///     while let Some(row) = rows.next()? {
+///         // scan columns value
+///     }
+///     Ok(())
+/// }
+/// ```
+impl<'stmt> FallibleStreamingIterator for Rows<'stmt> {
+    type Error = Error;
+    type Item = Row<'stmt>;
+
+    #[inline]
+    fn advance(&mut self) -> Result<()> {
+        if let Some(stmt) = self.stmt {
+            match stmt.step() {
+                Ok(true) => {
+                    self.row = Some(Row { stmt });
+                    Ok(())
+                }
+                Ok(false) => {
+                    self.reset();
+                    self.row = None;
+                    Ok(())
+                }
+                Err(e) => {
+                    self.reset();
+                    self.row = None;
+                    Err(e)
+                }
+            }
+        } else {
+            self.row = None;
+            Ok(())
+        }
+    }
+
+    #[inline]
+    fn get(&self) -> Option<&Row<'stmt>> {
+        self.row.as_ref()
+    }
+}
+
+/// A single result row of a query.
+pub struct Row<'stmt> {
+    pub(crate) stmt: &'stmt Statement<'stmt>,
+}
+
+impl<'stmt> Row<'stmt> {
+    /// Get the value of a particular column of the result row.
+    ///
+    /// ## Failure
+    ///
+    /// Panics if calling [`row.get(idx)`](Row::get) would return an error,
+    /// including:
+    ///
+    /// * If the underlying SQLite column type is not a valid type as a source
+    ///   for `T`
+    /// * If the underlying SQLite integral value is outside the range
+    ///   representable by `T`
+    /// * If `idx` is outside the range of columns in the returned query
+    pub fn get_unwrap<I: RowIndex, T: FromSql>(&self, idx: I) -> T {
+        self.get(idx).unwrap()
+    }
+
+    /// Get the value of a particular column of the result row.
+    ///
+    /// ## Failure
+    ///
+    /// Returns an `Error::InvalidColumnType` if the underlying SQLite column
+    /// type is not a valid type as a source for `T`.
+    ///
+    /// Returns an `Error::InvalidColumnIndex` if `idx` is outside the valid
+    /// column range for this row.
+    ///
+    /// Returns an `Error::InvalidColumnName` if `idx` is not a valid column
+    /// name for this row.
+    ///
+    /// If the result type is i128 (which requires the `i128_blob` feature to be
+    /// enabled), and the underlying SQLite column is a blob whose size is not
+    /// 16 bytes, the size mismatch (`FromSqlError::InvalidBlobSize`) is
+    /// reported as `Error::FromSqlConversionFailure`.
+    pub fn get<I: RowIndex, T: FromSql>(&self, idx: I) -> Result<T> {
+        let idx = idx.idx(self.stmt)?;
+        let value = self.stmt.value_ref(idx);
+        FromSql::column_result(value).map_err(|err| match err {
+            FromSqlError::InvalidType => Error::InvalidColumnType(
+                idx,
+                self.stmt.column_name_unwrap(idx).into(),
+                value.data_type(),
+            ),
+            FromSqlError::OutOfRange(i) => Error::IntegralValueOutOfRange(idx, i),
+            FromSqlError::Other(err) => {
+                Error::FromSqlConversionFailure(idx, value.data_type(), err)
+            }
+            FromSqlError::InvalidBlobSize { .. } => {
+                Error::FromSqlConversionFailure(idx, value.data_type(), Box::new(err))
+            }
+        })
+    }
+
+    /// Get the value of a particular column of the result row as a `ValueRef`,
+    /// allowing data to be read out of a row without copying.
+    ///
+    /// This `ValueRef` is valid only as long as this Row, which is enforced by
+    /// its lifetime. This means that while this method is completely safe,
+    /// it can be somewhat difficult to use, and most callers will be better
+    /// served by [`get`](Row::get) or [`get_unwrap`](Row::get_unwrap).
+    ///
+    /// ## Failure
+    ///
+    /// Returns an `Error::InvalidColumnIndex` if `idx` is outside the valid
+    /// column range for this row.
+    ///
+    /// Returns an `Error::InvalidColumnName` if `idx` is not a valid column
+    /// name for this row.
+    pub fn get_ref<I: RowIndex>(&self, idx: I) -> Result<ValueRef<'_>> {
+        let idx = idx.idx(self.stmt)?;
+        // Narrowing from `ValueRef<'stmt>` (which `self.stmt.value_ref(idx)`
+        // returns) to the `ValueRef<'_>` tied to `&self` is needed because the
+        // value is only valid until the next call to sqlite3_step.
+        let val_ref = self.stmt.value_ref(idx);
+        Ok(val_ref)
+    }
+
+    /// Get the value of a particular column of the result row as a `ValueRef`,
+    /// allowing data to be read out of a row without copying.
+    ///
+    /// This `ValueRef` is valid only as long as this Row, which is enforced by
+    /// its lifetime. This means that while this method is completely safe,
+    /// it can be difficult to use, and most callers will be better served by
+    /// [`get`](Row::get) or [`get_unwrap`](Row::get_unwrap).
+    ///
+    /// ## Failure
+    ///
+    /// Panics if calling [`row.get_ref(idx)`](Row::get_ref) would return an
+    /// error, including:
+    ///
+    /// * If `idx` is outside the range of columns in the returned query.
+    /// * If `idx` is not a valid column name for this row.
+    pub fn get_ref_unwrap<I: RowIndex>(&self, idx: I) -> ValueRef<'_> {
+        self.get_ref(idx).unwrap()
+    }
+
+    /// Renamed to [`get_ref`](Row::get_ref).
+    #[deprecated = "Use [`get_ref`](Row::get_ref) instead."]
+    #[inline]
+    pub fn get_raw_checked<I: RowIndex>(&self, idx: I) -> Result<ValueRef<'_>> {
+        self.get_ref(idx)
+    }
+
+    /// Renamed to [`get_ref_unwrap`](Row::get_ref_unwrap).
+    #[deprecated = "Use [`get_ref_unwrap`](Row::get_ref_unwrap) instead."]
+    #[inline]
+    pub fn get_raw<I: RowIndex>(&self, idx: I) -> ValueRef<'_> {
+        self.get_ref_unwrap(idx)
+    }
+}
+
+impl<'stmt> AsRef<Statement<'stmt>> for Row<'stmt> {
+    fn as_ref(&self) -> &Statement<'stmt> {
+        self.stmt
+    }
+}
+
+mod sealed {
+    /// This trait exists just to ensure that the only impls of `trait RowIndex`
+    /// that are allowed are ones in this crate.
+    pub trait Sealed {}
+    impl Sealed for usize {}
+    impl Sealed for &str {}
+}
+
+/// A trait implemented by types that can index into columns of a row.
+///
+/// It is only implemented for `usize` and `&str`.
+pub trait RowIndex: sealed::Sealed {
+    /// Returns the index of the appropriate column, or an error if no such
+    /// column exists.
+    fn idx(&self, stmt: &Statement<'_>) -> Result<usize>;
+}
+
+impl RowIndex for usize {
+    #[inline]
+    fn idx(&self, stmt: &Statement<'_>) -> Result<usize> {
+        if *self >= stmt.column_count() {
+            Err(Error::InvalidColumnIndex(*self))
+        } else {
+            Ok(*self)
+        }
+    }
+}
+
+impl RowIndex for &'_ str {
+    #[inline]
+    fn idx(&self, stmt: &Statement<'_>) -> Result<usize> {
+        stmt.column_index(*self)
+    }
+}
+
+macro_rules! tuple_try_from_row {
+    ($($field:ident),*) => {
+        impl<'a, $($field,)*> convert::TryFrom<&'a Row<'a>> for ($($field,)*) where $($field: FromSql,)* {
+            type Error = crate::Error;
+
+            // we end with index += 1, which rustc warns about
+            // unused_variables and unused_mut are allowed for ()
+            #[allow(unused_assignments, unused_variables, unused_mut)]
+            fn try_from(row: &'a Row<'a>) -> Result<Self> {
+                let mut index = 0;
+                $(
+                    #[allow(non_snake_case)]
+                    let $field = row.get::<_, $field>(index)?;
+                    index += 1;
+                )*
+                Ok(($($field,)*))
+            }
+        }
+    }
+}
+
+macro_rules! tuples_try_from_row {
+    () => {
+        // not very useful, but maybe some other macro users will find this helpful
+        tuple_try_from_row!();
+    };
+    ($first:ident $(, $remaining:ident)*) => {
+        tuple_try_from_row!($first $(, $remaining)*);
+        tuples_try_from_row!($($remaining),*);
+    };
+}
+
+tuples_try_from_row!(A, B, C, D, E, F, G, H, I, J, K, L, M, N, O, P);
+
+#[cfg(test)]
+mod tests {
+    #![allow(clippy::redundant_closure)] // false positives due to lifetime issues; clippy issue #5594
+    use crate::{Connection, Result};
+
+    #[test]
+    fn test_try_from_row_for_tuple_1() -> Result<()> {
+        use crate::ToSql;
+        use std::convert::TryFrom;
+
+        let conn = Connection::open_in_memory()?;
+        conn.execute(
+            "CREATE TABLE test (a INTEGER)",
+            crate::params_from_iter(std::iter::empty::<&dyn ToSql>()),
+        )?;
+        conn.execute("INSERT INTO test VALUES (42)", [])?;
+        let val = conn.query_row("SELECT a FROM test", [], |row| <(u32,)>::try_from(row))?;
+        assert_eq!(val, (42,));
+        let fail = conn.query_row("SELECT a FROM test", [], |row| <(u32, u32)>::try_from(row));
+        assert!(fail.is_err());
+        Ok(())
+    }
+
+    #[test]
+    fn test_try_from_row_for_tuple_2() -> Result<()> {
+        use std::convert::TryFrom;
+
+        let conn = Connection::open_in_memory()?;
+        conn.execute("CREATE TABLE test (a INTEGER, b INTEGER)", [])?;
+        conn.execute("INSERT INTO test VALUES (42, 47)", [])?;
+        let val = conn.query_row("SELECT a, b FROM test", [], |row| {
+            <(u32, u32)>::try_from(row)
+        })?;
+        assert_eq!(val, (42, 47));
+        let fail = conn.query_row("SELECT a, b FROM test", [], |row| {
+            <(u32, u32, u32)>::try_from(row)
+        });
+        assert!(fail.is_err());
+        Ok(())
+    }
+
+    #[test]
+    fn test_try_from_row_for_tuple_16() -> Result<()> {
+        use std::convert::TryFrom;
+
+        let create_table = "CREATE TABLE test (
+            a INTEGER,
+            b INTEGER,
+            c INTEGER,
+            d INTEGER,
+            e INTEGER,
+            f INTEGER,
+            g INTEGER,
+            h INTEGER,
+            i INTEGER,
+            j INTEGER,
+            k INTEGER,
+            l INTEGER,
+            m INTEGER,
+            n INTEGER,
+            o INTEGER,
+            p INTEGER
+        )";
+
+        let insert_values = "INSERT INTO test VALUES (
+            0,
+            1,
+            2,
+            3,
+            4,
+            5,
+            6,
+            7,
+            8,
+            9,
+            10,
+            11,
+            12,
+            13,
+            14,
+            15
+        )";
+
+        type BigTuple = (
+            u32,
+            u32,
+            u32,
+            u32,
+            u32,
+            u32,
+            u32,
+            u32,
+            u32,
+            u32,
+            u32,
+            u32,
+            u32,
+            u32,
+            u32,
+            u32,
+        );
+
+        let conn = Connection::open_in_memory()?;
+        conn.execute(create_table, [])?;
+        conn.execute(insert_values, [])?;
+        let val = conn.query_row("SELECT * FROM test", [], |row| BigTuple::try_from(row))?;
+        // Debug is not implemented for tuples of 16
+        assert_eq!(val.0, 0);
+        assert_eq!(val.1, 1);
+        assert_eq!(val.2, 2);
+        assert_eq!(val.3, 3);
+        assert_eq!(val.4, 4);
+        assert_eq!(val.5, 5);
+        assert_eq!(val.6, 6);
+        assert_eq!(val.7, 7);
+        assert_eq!(val.8, 8);
+        assert_eq!(val.9, 9);
+        assert_eq!(val.10, 10);
+        assert_eq!(val.11, 11);
+        assert_eq!(val.12, 12);
+        assert_eq!(val.13, 13);
+        assert_eq!(val.14, 14);
+        assert_eq!(val.15, 15);
+
+        // We don't test one bigger because it's unimplemented
+        Ok(())
+    }
+}
diff --git a/src/session.rs b/src/session.rs
new file mode 100644
index 0000000..f8aa764
--- /dev/null
+++ b/src/session.rs
@@ -0,0 +1,938 @@
+//! [Session Extension](https://sqlite.org/sessionintro.html)
+#![allow(non_camel_case_types)]
+
+use std::ffi::CStr;
+use std::io::{Read, Write};
+use std::marker::PhantomData;
+use std::os::raw::{c_char, c_int, c_uchar, c_void};
+use std::panic::{catch_unwind, RefUnwindSafe};
+use std::ptr;
+use std::slice::{from_raw_parts, from_raw_parts_mut};
+
+use fallible_streaming_iterator::FallibleStreamingIterator;
+
+use crate::error::{check, error_from_sqlite_code};
+use crate::ffi;
+use crate::hooks::Action;
+use crate::types::ValueRef;
+use crate::{errmsg_to_string, str_to_cstring, Connection, DatabaseName, Result};
+
+// https://sqlite.org/session.html
+
+/// An instance of this object is a session that can be
+/// used to record changes to a database.
+pub struct Session<'conn> {
+    phantom: PhantomData<&'conn Connection>,
+    s: *mut ffi::sqlite3_session,
+    // Keeps the boxed table filter alive while SQLite holds a pointer to it.
+    filter: Option<Box<dyn Fn(&str) -> bool>>,
+}
+
+impl Session<'_> {
+    /// Create a new session object recording changes to the `main` database.
+    #[inline]
+    pub fn new(db: &Connection) -> Result<Session<'_>> {
+        Session::new_with_name(db, DatabaseName::Main)
+    }
+
+    /// Create a new session object recording changes to the database `name`.
+    #[inline]
+    pub fn new_with_name<'conn>(
+        db: &'conn Connection,
+        name: DatabaseName<'_>,
+    ) -> Result<Session<'conn>> {
+        let name = name.as_cstring()?;
+
+        let db = db.db.borrow_mut().db;
+
+        let mut s: *mut ffi::sqlite3_session = ptr::null_mut();
+        check(unsafe { ffi::sqlite3session_create(db, name.as_ptr(), &mut s) })?;
+
+        Ok(Session {
+            phantom: PhantomData,
+            s,
+            filter: None,
+        })
+    }
+
+    /// Set a table filter.
+    ///
+    /// With `Some(filter)`, only tables for which `filter` returns `true` are
+    /// tracked; a panic in `filter` or a non-UTF-8 table name counts as
+    /// `false`. `None` removes any previously installed filter.
+    pub fn table_filter<F>(&mut self, filter: Option<F>)
+    where
+        F: Fn(&str) -> bool + Send + RefUnwindSafe + 'static,
+    {
+        unsafe extern "C" fn call_boxed_closure<F>(
+            p_arg: *mut c_void,
+            tbl_str: *const c_char,
+        ) -> c_int
+        where
+            F: Fn(&str) -> bool + RefUnwindSafe,
+        {
+            use std::str;
+
+            let boxed_filter: *mut F = p_arg as *mut F;
+            let tbl_name = {
+                let c_slice = CStr::from_ptr(tbl_str).to_bytes();
+                str::from_utf8(c_slice)
+            };
+            // Panics (including the UTF-8 `expect`) must not unwind into SQLite.
+            if let Ok(true) =
+                catch_unwind(|| (*boxed_filter)(tbl_name.expect("non-utf8 table name")))
+            {
+                1
+            } else {
+                0
+            }
+        }
+
+        match filter {
+            Some(filter) => {
+                let boxed_filter = Box::new(filter);
+                unsafe {
+                    ffi::sqlite3session_table_filter(
+                        self.s,
+                        Some(call_boxed_closure::<F>),
+                        &*boxed_filter as *const F as *mut _,
+                    );
+                }
+                // Moving the Box does not move its heap allocation, so the
+                // pointer handed to SQLite above stays valid.
+                self.filter = Some(boxed_filter);
+            }
+            None => {
+                unsafe { ffi::sqlite3session_table_filter(self.s, None, ptr::null_mut()) }
+                self.filter = None;
+            }
+        }
+    }
+
+    /// Attach a table. `None` means all tables.
+    pub fn attach(&mut self, table: Option<&str>) -> Result<()> {
+        let table = table.map(str_to_cstring).transpose()?;
+        let table = table.as_ref().map(|s| s.as_ptr()).unwrap_or(ptr::null());
+        check(unsafe { ffi::sqlite3session_attach(self.s, table) })
+    }
+
+    /// Generate a Changeset
+    pub fn changeset(&mut self) -> Result<Changeset> {
+        let mut n = 0;
+        let mut cs: *mut c_void = ptr::null_mut();
+        check(unsafe { ffi::sqlite3session_changeset(self.s, &mut n, &mut cs) })?;
+        Ok(Changeset { cs, n })
+    }
+
+    /// Write the set of changes represented by this session to `output`.
+    #[inline]
+    pub fn changeset_strm(&mut self, output: &mut dyn Write) -> Result<()> {
+        let output_ref = &output;
+        check(unsafe {
+            ffi::sqlite3session_changeset_strm(
+                self.s,
+                Some(x_output),
+                output_ref as *const &mut dyn Write as *mut c_void,
+            )
+        })
+    }
+
+    /// Generate a Patchset
+    #[inline]
+    pub fn patchset(&mut self) -> Result<Changeset> {
+        let mut n = 0;
+        let mut ps: *mut c_void = ptr::null_mut();
+        check(unsafe { ffi::sqlite3session_patchset(self.s, &mut n, &mut ps) })?;
+        // TODO Validate: same struct
+        Ok(Changeset { cs: ps, n })
+    }
+
+    /// Write the set of patches represented by this session to `output`.
+    #[inline]
+    pub fn patchset_strm(&mut self, output: &mut dyn Write) -> Result<()> {
+        let output_ref = &output;
+        check(unsafe {
+            ffi::sqlite3session_patchset_strm(
+                self.s,
+                Some(x_output),
+                output_ref as *const &mut dyn Write as *mut c_void,
+            )
+        })
+    }
+
+    /// Load the difference between tables.
+    pub fn diff(&mut self, from: DatabaseName<'_>, table: &str) -> Result<()> {
+        let from = from.as_cstring()?;
+        let table = str_to_cstring(table)?;
+        let table = table.as_ptr();
+        unsafe {
+            let mut errmsg = ptr::null_mut();
+            let r =
+                ffi::sqlite3session_diff(self.s, from.as_ptr(), table, &mut errmsg as *mut *mut _);
+            if r != ffi::SQLITE_OK {
+                let errmsg: *mut c_char = errmsg;
+                // SQLite only *may* set the error message, so `errmsg` can still
+                // be null here; dereferencing it unconditionally would be UB.
+                let message = if errmsg.is_null() {
+                    None
+                } else {
+                    Some(errmsg_to_string(&*errmsg))
+                };
+                // `sqlite3_free(NULL)` is a harmless no-op.
+                ffi::sqlite3_free(errmsg as *mut c_void);
+                return Err(error_from_sqlite_code(r, message));
+            }
+        }
+        Ok(())
+    }
+
+    /// Test if a changeset has recorded any changes
+    #[inline]
+    pub fn is_empty(&self) -> bool {
+        unsafe { ffi::sqlite3session_isempty(self.s) != 0 }
+    }
+
+    /// Query the current state of the session
+    #[inline]
+    pub fn is_enabled(&self) -> bool {
+        // A negative argument queries the flag without changing it.
+        unsafe { ffi::sqlite3session_enable(self.s, -1) != 0 }
+    }
+
+    /// Enable or disable the recording of changes
+    #[inline]
+    pub fn set_enabled(&mut self, enabled: bool) {
+        unsafe {
+            ffi::sqlite3session_enable(self.s, c_int::from(enabled));
+        }
+    }
+
+    /// Query the current state of the indirect flag
+    #[inline]
+    pub fn is_indirect(&self) -> bool {
+        // A negative argument queries the flag without changing it.
+        unsafe { ffi::sqlite3session_indirect(self.s, -1) != 0 }
+    }
+
+    /// Set or clear the indirect change flag
+    #[inline]
+    pub fn set_indirect(&mut self, indirect: bool) {
+        unsafe {
+            ffi::sqlite3session_indirect(self.s, c_int::from(indirect));
+        }
+    }
+}
+
+impl Drop for Session<'_> {
+    #[inline]
+    fn drop(&mut self) {
+        // Unregister the filter callback before its boxed closure is freed.
+        if self.filter.is_some() {
+            self.table_filter(None::<fn(&str) -> bool>);
+        }
+        unsafe { ffi::sqlite3session_delete(self.s) };
+    }
+}
+
+/// Invert a changeset
+#[inline]
+pub fn invert_strm(input: &mut dyn Read, output: &mut dyn Write) -> Result<()> {
+    let input_ref = &input;
+    let output_ref = &output;
+    check(unsafe {
+        ffi::sqlite3changeset_invert_strm(
+            Some(x_input),
+            input_ref as *const &mut dyn Read as *mut c_void,
+            Some(x_output),
+            output_ref as *const &mut dyn Write as *mut c_void,
+        )
+    })
+}
+
+/// Combine two changesets
+#[inline]
+pub fn concat_strm(
+    input_a: &mut dyn Read,
+    input_b: &mut dyn Read,
+    output: &mut dyn Write,
+) -> Result<()> {
+    let input_a_ref = &input_a;
+    let input_b_ref = &input_b;
+    let output_ref = &output;
+    check(unsafe {
+        ffi::sqlite3changeset_concat_strm(
+            Some(x_input),
+            input_a_ref as *const &mut dyn Read as *mut c_void,
+            Some(x_input),
+            input_b_ref as *const &mut dyn Read as *mut c_void,
+            Some(x_output),
+            output_ref as *const &mut dyn Write as *mut c_void,
+        )
+    })
+}
+
+/// Changeset or Patchset
+pub struct Changeset {
+    cs: *mut c_void,
+    n: c_int,
+}
+
+impl Changeset {
+    /// Invert a changeset
+    #[inline]
+    pub fn invert(&self) -> Result<Changeset> {
+        let mut n = 0;
+        let mut cs = ptr::null_mut();
+        check(unsafe {
+            ffi::sqlite3changeset_invert(self.n, self.cs, &mut n, &mut cs as *mut *mut _)
+        })?;
+        Ok(Changeset { cs, n })
+    }
+
+    /// Create an iterator to traverse a changeset
+    #[inline]
+    pub fn iter(&self) -> Result<ChangesetIter<'_>> {
+        let mut it = ptr::null_mut();
+        check(unsafe { ffi::sqlite3changeset_start(&mut it as *mut *mut _, self.n, self.cs) })?;
+        Ok(ChangesetIter {
+            phantom: PhantomData,
+            it,
+            item: None,
+        })
+    }
+
+    /// Concatenate two changeset objects
+    #[inline]
+    pub fn concat(a: &Changeset, b: &Changeset) -> Result<Changeset> {
+        let mut n = 0;
+        let mut cs = ptr::null_mut();
+        check(unsafe {
+            ffi::sqlite3changeset_concat(a.n, a.cs, b.n, b.cs, &mut n, &mut cs as *mut *mut _)
+        })?;
+        Ok(Changeset { cs, n })
+    }
+}
+
+impl Drop for Changeset {
+    #[inline]
+    fn drop(&mut self) {
+        unsafe {
+            ffi::sqlite3_free(self.cs);
+        }
+    }
+}
+
+/// Cursor for iterating over the elements of a changeset
+/// or patchset.
+pub struct ChangesetIter<'changeset> {
+    phantom: PhantomData<&'changeset Changeset>,
+    it: *mut ffi::sqlite3_changeset_iter,
+    item: Option<ChangesetItem>,
+}
+
+impl ChangesetIter<'_> {
+    /// Create an iterator on `input`
+    ///
+    /// The returned iterator borrows `input` for `'input` because SQLite keeps
+    /// the pointer passed here and reads from it while iterating.
+    #[inline]
+    pub fn start_strm<'input>(input: &&'input mut dyn Read) -> Result<ChangesetIter<'input>> {
+        let mut it = ptr::null_mut();
+        check(unsafe {
+            ffi::sqlite3changeset_start_strm(
+                &mut it as *mut *mut _,
+                Some(x_input),
+                input as *const &mut dyn Read as *mut c_void,
+            )
+        })?;
+        Ok(ChangesetIter {
+            phantom: PhantomData,
+            it,
+            item: None,
+        })
+    }
+}
+
+impl FallibleStreamingIterator for ChangesetIter<'_> {
+    type Error = crate::error::Error;
+    type Item = ChangesetItem;
+
+    #[inline]
+    fn advance(&mut self) -> Result<()> {
+        let rc = unsafe { ffi::sqlite3changeset_next(self.it) };
+        match rc {
+            ffi::SQLITE_ROW => {
+                self.item = Some(ChangesetItem { it: self.it });
+                Ok(())
+            }
+            ffi::SQLITE_DONE => {
+                self.item = None;
+                Ok(())
+            }
+            code => {
+                // Don't keep handing out the previous change through `get`
+                // once stepping has failed.
+                self.item = None;
+                Err(error_from_sqlite_code(code, None))
+            }
+        }
+    }
+
+    #[inline]
+    fn get(&self) -> Option<&ChangesetItem> {
+        self.item.as_ref()
+    }
+}
+
+/// Operation
+pub struct Operation<'item> {
+    table_name: &'item str,
+    number_of_columns: i32,
+    code: Action,
+    indirect: bool,
+}
+
+impl Operation<'_> {
+    /// Returns the table name.
+    #[inline]
+    pub fn table_name(&self) -> &str {
+        self.table_name
+    }
+
+    /// Returns the number of columns in table
+    #[inline]
+    pub fn number_of_columns(&self) -> i32 {
+        self.number_of_columns
+    }
+
+    /// Returns the action code.
+    #[inline]
+    pub fn code(&self) -> Action {
+        self.code
+    }
+
+    /// Returns `true` for an 'indirect' change.
+    #[inline]
+    pub fn indirect(&self) -> bool {
+        self.indirect
+    }
+}
+
+impl Drop for ChangesetIter<'_> {
+    #[inline]
+    fn drop(&mut self) {
+        unsafe {
+            ffi::sqlite3changeset_finalize(self.it);
+        }
+    }
+}
+
+/// An item passed to a conflict-handler by
+/// [`Connection::apply`](crate::Connection::apply), or an item generated by
+/// [`ChangesetIter::next`](ChangesetIter::next).
+// TODO enum ? Delete, Insert, Update, ...
+pub struct ChangesetItem {
+    it: *mut ffi::sqlite3_changeset_iter,
+}
+
+impl ChangesetItem {
+    /// Obtain conflicting row values
+    ///
+    /// May only be called with an `SQLITE_CHANGESET_DATA` or
+    /// `SQLITE_CHANGESET_CONFLICT` conflict handler callback.
+    #[inline]
+    pub fn conflict(&self, col: usize) -> Result<ValueRef<'_>> {
+        unsafe {
+            let mut p_value: *mut ffi::sqlite3_value = ptr::null_mut();
+            check(ffi::sqlite3changeset_conflict(
+                self.it,
+                col as i32,
+                &mut p_value,
+            ))?;
+            Ok(ValueRef::from_value(p_value))
+        }
+    }
+
+    /// Determine the number of foreign key constraint violations
+    ///
+    /// May only be called with an `SQLITE_CHANGESET_FOREIGN_KEY` conflict
+    /// handler callback.
+    #[inline]
+    pub fn fk_conflicts(&self) -> Result<i32> {
+        unsafe {
+            let mut p_out = 0;
+            check(ffi::sqlite3changeset_fk_conflicts(self.it, &mut p_out))?;
+            Ok(p_out)
+        }
+    }
+
+    /// Obtain new.* Values
+    ///
+    /// May only be called if the type of change is either `SQLITE_UPDATE` or
+    /// `SQLITE_INSERT`.
+    #[inline]
+    pub fn new_value(&self, col: usize) -> Result<ValueRef<'_>> {
+        unsafe {
+            let mut p_value: *mut ffi::sqlite3_value = ptr::null_mut();
+            check(ffi::sqlite3changeset_new(self.it, col as i32, &mut p_value))?;
+            Ok(ValueRef::from_value(p_value))
+        }
+    }
+
+    /// Obtain old.* Values
+    ///
+    /// May only be called if the type of change is either `SQLITE_DELETE` or
+    /// `SQLITE_UPDATE`.
+    #[inline]
+    pub fn old_value(&self, col: usize) -> Result<ValueRef<'_>> {
+        unsafe {
+            let mut p_value: *mut ffi::sqlite3_value = ptr::null_mut();
+            check(ffi::sqlite3changeset_old(self.it, col as i32, &mut p_value))?;
+            Ok(ValueRef::from_value(p_value))
+        }
+    }
+
+    /// Obtain the current operation
+    ///
+    /// Fails if the table name reported by SQLite is not valid UTF-8.
+    #[inline]
+    pub fn op(&self) -> Result<Operation<'_>> {
+        let mut number_of_columns = 0;
+        let mut code = 0;
+        let mut indirect = 0;
+        let tab = unsafe {
+            let mut pz_tab: *const c_char = ptr::null();
+            check(ffi::sqlite3changeset_op(
+                self.it,
+                &mut pz_tab,
+                &mut number_of_columns,
+                &mut code,
+                &mut indirect,
+            ))?;
+            CStr::from_ptr(pz_tab)
+        };
+        let table_name = tab.to_str()?;
+        Ok(Operation {
+            table_name,
+            number_of_columns,
+            code: Action::from(code),
+            indirect: indirect != 0,
+        })
+    }
+
+    /// Obtain the primary key definition of a table
+    #[inline]
+    pub fn pk(&self) -> Result<&[u8]> {
+        let mut number_of_columns = 0;
+        unsafe {
+            let mut pks: *mut c_uchar = ptr::null_mut();
+            check(ffi::sqlite3changeset_pk(
+                self.it,
+                &mut pks,
+                &mut number_of_columns,
+            ))?;
+            Ok(from_raw_parts(pks, number_of_columns as usize))
+        }
+    }
+}
+
+/// Used to combine two or more changesets or
+/// patchsets
+pub struct Changegroup {
+    cg: *mut ffi::sqlite3_changegroup,
+}
+
+impl Changegroup {
+    /// Create a new change group.
+    #[inline]
+    pub fn new() -> Result<Self> {
+        let mut cg = ptr::null_mut();
+        check(unsafe { ffi::sqlite3changegroup_new(&mut cg) })?;
+        Ok(Changegroup { cg })
+    }
+
+    /// Add a changeset
+    #[inline]
+    pub fn add(&mut self, cs: &Changeset) -> Result<()> {
+        check(unsafe { ffi::sqlite3changegroup_add(self.cg, cs.n, cs.cs) })
+    }
+
+    /// Add a changeset read from `input` to this change group.
+    #[inline]
+    pub fn add_stream(&mut self, input: &mut dyn Read) -> Result<()> {
+        let input_ref = &input;
+        check(unsafe {
+            ffi::sqlite3changegroup_add_strm(
+                self.cg,
+                Some(x_input),
+                input_ref as *const &mut dyn Read as *mut c_void,
+            )
+        })
+    }
+
+    /// Obtain a composite Changeset
+    #[inline]
+    pub fn output(&mut self) -> Result<Changeset> {
+        let mut n = 0;
+        let mut output: *mut c_void = ptr::null_mut();
+        check(unsafe { ffi::sqlite3changegroup_output(self.cg, &mut n, &mut output) })?;
+        Ok(Changeset { cs: output, n })
+    }
+
+    /// Write the combined set of changes to `output`.
+    #[inline]
+    pub fn output_strm(&mut self, output: &mut dyn Write) -> Result<()> {
+        let output_ref = &output;
+        check(unsafe {
+            ffi::sqlite3changegroup_output_strm(
+                self.cg,
+                Some(x_output),
+                output_ref as *const &mut dyn Write as *mut c_void,
+            )
+        })
+    }
+}
+
+impl Drop for Changegroup {
+    #[inline]
+    fn drop(&mut self) {
+        unsafe {
+            ffi::sqlite3changegroup_delete(self.cg);
+        }
+    }
+}
+
+impl Connection {
+    /// Apply a changeset to a database
+    ///
+    /// `filter`, when given, decides per table whether its changes are
+    /// applied; `conflict` is called for every conflict and returns how it
+    /// should be resolved.
+    pub fn apply<F, C>(&self, cs: &Changeset, filter: Option<F>, conflict: C) -> Result<()>
+    where
+        F: Fn(&str) -> bool + Send + RefUnwindSafe + 'static,
+        C: Fn(ConflictType, ChangesetItem) -> ConflictAction + Send + RefUnwindSafe + 'static,
+    {
+        let db = self.db.borrow_mut().db;
+
+        let filtered = filter.is_some();
+        // Both callbacks receive a pointer to this tuple as their context.
+        let tuple = &mut (filter, conflict);
+        check(unsafe {
+            if filtered {
+                ffi::sqlite3changeset_apply(
+                    db,
+                    cs.n,
+                    cs.cs,
+                    Some(call_filter::<F, C>),
+                    Some(call_conflict::<F, C>),
+                    tuple as *mut (Option<F>, C) as *mut c_void,
+                )
+            } else {
+                ffi::sqlite3changeset_apply(
+                    db,
+                    cs.n,
+                    cs.cs,
+                    None,
+                    Some(call_conflict::<F, C>),
+                    tuple as *mut (Option<F>, C) as *mut c_void,
+                )
+            }
+        })
+    }
+
+    /// Apply a changeset read from `input` to a database
+    ///
+    /// `filter` and `conflict` behave as in [`Connection::apply`].
+    pub fn apply_strm<F, C>(
+        &self,
+        input: &mut dyn Read,
+        filter: Option<F>,
+        conflict: C,
+    ) -> Result<()>
+    where
+        F: Fn(&str) -> bool + Send + RefUnwindSafe + 'static,
+        C: Fn(ConflictType, ChangesetItem) -> ConflictAction + Send + RefUnwindSafe + 'static,
+    {
+        let input_ref = &input;
+        let db = self.db.borrow_mut().db;
+
+        let filtered = filter.is_some();
+        // Both callbacks receive a pointer to this tuple as their context.
+        let tuple = &mut (filter, conflict);
+        check(unsafe {
+            if filtered {
+                ffi::sqlite3changeset_apply_strm(
+                    db,
+                    Some(x_input),
+                    input_ref as *const &mut dyn Read as *mut c_void,
+                    Some(call_filter::<F, C>),
+                    Some(call_conflict::<F, C>),
+                    tuple as *mut (Option<F>, C) as *mut c_void,
+                )
+            } else {
+                ffi::sqlite3changeset_apply_strm(
+                    db,
+                    Some(x_input),
+                    input_ref as *const &mut dyn Read as *mut c_void,
+                    None,
+                    Some(call_conflict::<F, C>),
+                    tuple as *mut (Option<F>, C) as *mut c_void,
+                )
+            }
+        })
+    }
+}
+
+/// Constants passed to the conflict handler
+/// See [here](https://sqlite.org/session.html#SQLITE_CHANGESET_CONFLICT) for details.
+#[allow(missing_docs)]
+#[repr(i32)]
+#[derive(Clone, Copy, Debug, PartialEq, Eq)]
+#[non_exhaustive]
+#[allow(clippy::upper_case_acronyms)]
+pub enum ConflictType {
+    UNKNOWN = -1,
+    SQLITE_CHANGESET_DATA = ffi::SQLITE_CHANGESET_DATA,
+    SQLITE_CHANGESET_NOTFOUND = ffi::SQLITE_CHANGESET_NOTFOUND,
+    SQLITE_CHANGESET_CONFLICT = ffi::SQLITE_CHANGESET_CONFLICT,
+    SQLITE_CHANGESET_CONSTRAINT = ffi::SQLITE_CHANGESET_CONSTRAINT,
+    SQLITE_CHANGESET_FOREIGN_KEY = ffi::SQLITE_CHANGESET_FOREIGN_KEY,
+}
+impl From<i32> for ConflictType {
+    /// Map a raw SQLite conflict code; unrecognized codes become `UNKNOWN`.
+    fn from(code: i32) -> ConflictType {
+        match code {
+            ffi::SQLITE_CHANGESET_DATA => ConflictType::SQLITE_CHANGESET_DATA,
+            ffi::SQLITE_CHANGESET_NOTFOUND => ConflictType::SQLITE_CHANGESET_NOTFOUND,
+            ffi::SQLITE_CHANGESET_CONFLICT => ConflictType::SQLITE_CHANGESET_CONFLICT,
+            ffi::SQLITE_CHANGESET_CONSTRAINT => ConflictType::SQLITE_CHANGESET_CONSTRAINT,
+            ffi::SQLITE_CHANGESET_FOREIGN_KEY => ConflictType::SQLITE_CHANGESET_FOREIGN_KEY,
+            _ => ConflictType::UNKNOWN,
+        }
+    }
+}
+
+/// Constants returned by the conflict handler
+/// See [here](https://sqlite.org/session.html#SQLITE_CHANGESET_ABORT) for details.
+#[allow(missing_docs)]
+#[repr(i32)]
+#[derive(Clone, Copy, Debug, PartialEq, Eq)]
+#[non_exhaustive]
+#[allow(clippy::upper_case_acronyms)]
+pub enum ConflictAction {
+    SQLITE_CHANGESET_OMIT = ffi::SQLITE_CHANGESET_OMIT,
+    SQLITE_CHANGESET_REPLACE = ffi::SQLITE_CHANGESET_REPLACE,
+    SQLITE_CHANGESET_ABORT = ffi::SQLITE_CHANGESET_ABORT,
+}
+
+/// Table-filter trampoline for `apply`/`apply_strm`; `p_ctx` points to the
+/// `(Option<F>, C)` tuple owned by the caller.
+unsafe extern "C" fn call_filter<F, C>(p_ctx: *mut c_void, tbl_str: *const c_char) -> c_int
+where
+    F: Fn(&str) -> bool + Send + RefUnwindSafe + 'static,
+    C: Fn(ConflictType, ChangesetItem) -> ConflictAction + Send + RefUnwindSafe + 'static,
+{
+    use std::str;
+
+    let tuple: *mut (Option<F>, C) = p_ctx as *mut (Option<F>, C);
+    let tbl_name = {
+        let c_slice = CStr::from_ptr(tbl_str).to_bytes();
+        str::from_utf8(c_slice)
+    };
+    match *tuple {
+        (Some(ref filter), _) => {
+            if let Ok(true) = catch_unwind(|| filter(tbl_name.expect("illegal table name"))) {
+                1
+            } else {
+                0
+            }
+        }
+        // This callback is only registered when a filter was supplied. Should
+        // that ever change, accept every table (as SQLite does with no filter)
+        // instead of panicking across the `extern "C"` boundary.
+        _ => 1,
+    }
+}
+
+/// Conflict-handler trampoline; a panic in the user handler aborts the apply.
+unsafe extern "C" fn call_conflict<F, C>(
+    p_ctx: *mut c_void,
+    e_conflict: c_int,
+    p: *mut ffi::sqlite3_changeset_iter,
+) -> c_int
+where
+    F: Fn(&str) -> bool + Send + RefUnwindSafe + 'static,
+    C: Fn(ConflictType, ChangesetItem) -> ConflictAction + Send + RefUnwindSafe + 'static,
+{
+    let tuple: *mut (Option<F>, C) = p_ctx as *mut (Option<F>, C);
+    let conflict_type = ConflictType::from(e_conflict);
+    let item = ChangesetItem { it: p };
+    if let Ok(action) = catch_unwind(|| (*tuple).1(conflict_type, item)) {
+        action as c_int
+    } else {
+        ffi::SQLITE_CHANGESET_ABORT
+    }
+}
+
+/// xInput callback: `p_in` points to a `&mut dyn Read`; fills at most `*len`
+/// bytes and stores the number actually read back into `*len`.
+unsafe extern "C" fn x_input(p_in: *mut c_void, data: *mut c_void, len: *mut c_int) -> c_int {
+    if p_in.is_null() {
+        return ffi::SQLITE_MISUSE;
+    }
+    let bytes: &mut [u8] = from_raw_parts_mut(data as *mut u8, *len as usize);
+    let input = p_in as *mut &mut dyn Read;
+    match (*input).read(bytes) {
+        Ok(n) => {
+            *len = n as i32; // TODO Validate: n = 0 may not mean the reader will always no longer be able to
+                             // produce bytes.
+            ffi::SQLITE_OK
+        }
+        Err(_) => ffi::SQLITE_IOERR_READ, // TODO check if err is a (ru)sqlite Error => propagate
+    }
+}
+
+/// xOutput callback: `p_out` points to a `&mut dyn Write`; writes all `len`
+/// bytes of `data`.
+unsafe extern "C" fn x_output(p_out: *mut c_void, data: *const c_void, len: c_int) -> c_int {
+    if p_out.is_null() {
+        return ffi::SQLITE_MISUSE;
+    }
+    // The sessions module never invokes an xOutput callback with the third
+    // parameter set to a value less than or equal to zero.
+    let bytes: &[u8] = from_raw_parts(data as *const u8, len as usize);
+    let output = p_out as *mut &mut dyn Write;
+    match (*output).write_all(bytes) {
+        Ok(_) => ffi::SQLITE_OK,
+        Err(_) => ffi::SQLITE_IOERR_WRITE, // TODO check if err is a (ru)sqlite Error => propagate
+    }
+}
+
+#[cfg(test)]
+mod test {
+    use fallible_streaming_iterator::FallibleStreamingIterator;
+    use std::io::Read;
+    use std::sync::atomic::{AtomicBool, Ordering};
+
+    use super::{Changeset, ChangesetIter, ConflictAction, ConflictType, Session};
+    use crate::hooks::Action;
+    use crate::{Connection, Result};
+
+    fn one_changeset() -> Result<Changeset> {
+        let db = Connection::open_in_memory()?;
+        db.execute_batch("CREATE TABLE foo(t TEXT PRIMARY KEY NOT NULL);")?;
+
+        let mut session = Session::new(&db)?;
+        assert!(session.is_empty());
+
+        session.attach(None)?;
+        db.execute("INSERT INTO foo (t) VALUES (?);", ["bar"])?;
+
+        session.changeset()
+    }
+
+    fn one_changeset_strm() -> Result<Vec<u8>> {
+        let db = Connection::open_in_memory()?;
+        db.execute_batch("CREATE TABLE foo(t TEXT PRIMARY KEY NOT NULL);")?;
+
+        let mut session = Session::new(&db)?;
+        assert!(session.is_empty());
+
+        session.attach(None)?;
+        db.execute("INSERT INTO foo (t) VALUES (?);", ["bar"])?;
+
+        let mut output = Vec::new();
+        session.changeset_strm(&mut output)?;
+        Ok(output)
+    }
+
+    #[test]
+    fn test_changeset() -> Result<()> {
+        let changeset = one_changeset()?;
+        let mut iter = changeset.iter()?;
+        let item = iter.next()?;
+        assert!(item.is_some());
+
+        let item = item.unwrap();
+        let op = item.op()?;
+        assert_eq!("foo", op.table_name());
+        assert_eq!(1, op.number_of_columns());
+        assert_eq!(Action::SQLITE_INSERT, op.code());
+        assert!(!op.indirect());
+
+        let pk = item.pk()?;
+        assert_eq!(&[1], pk);
+
+        let new_value = item.new_value(0)?;
+        assert_eq!(Ok("bar"), new_value.as_str());
+        Ok(())
+    }
+
+    #[test]
+    fn test_changeset_strm() -> Result<()> {
+        let output = one_changeset_strm()?;
+        assert!(!output.is_empty());
+        assert_eq!(14, output.len());
+
+        let input: &mut dyn Read = &mut output.as_slice();
+        let mut iter = ChangesetIter::start_strm(&input)?;
+        let item = iter.next()?;
+        assert!(item.is_some());
+        Ok(())
+    }
+
+    #[test]
+    fn test_changeset_apply() -> Result<()> {
+        let changeset = one_changeset()?;
+
+        let db = Connection::open_in_memory()?;
+        db.execute_batch("CREATE TABLE foo(t TEXT PRIMARY KEY NOT NULL);")?;
+
+        static CALLED: AtomicBool = AtomicBool::new(false);
+        db.apply(
+            &changeset,
+            None::<fn(&str) -> bool>,
+            |_conflict_type, _item| {
+                CALLED.store(true, Ordering::Relaxed);
+                ConflictAction::SQLITE_CHANGESET_OMIT
+            },
+        )?;
+
+        assert!(!CALLED.load(Ordering::Relaxed));
+        let check = db.query_row("SELECT 1 FROM foo WHERE t = ?", ["bar"], |row| {
+            row.get::<_, i32>(0)
+        })?;
+        assert_eq!(1, check);
+
+        // conflict expected when same changeset applied again on the same db
+        db.apply(
+            &changeset,
+            None::<fn(&str) -> bool>,
+            |conflict_type, item| {
+                CALLED.store(true, Ordering::Relaxed);
+                assert_eq!(ConflictType::SQLITE_CHANGESET_CONFLICT, conflict_type);
+                let conflict = item.conflict(0).unwrap();
+                assert_eq!(Ok("bar"), conflict.as_str());
+                ConflictAction::SQLITE_CHANGESET_OMIT
+            },
+        )?;
+        assert!(CALLED.load(Ordering::Relaxed));
+        Ok(())
+    }
+
+    #[test]
+    fn test_changeset_apply_strm() -> Result<()> {
+        let output = one_changeset_strm()?;
+
+        let db = Connection::open_in_memory()?;
+        db.execute_batch("CREATE TABLE foo(t TEXT PRIMARY KEY NOT NULL);")?;
+
+        let mut input = output.as_slice();
+        db.apply_strm(
+            &mut input,
+            None::<fn(&str) -> bool>,
+            |_conflict_type, _item| ConflictAction::SQLITE_CHANGESET_OMIT,
+        )?;
+
+        let check = db.query_row("SELECT 1 FROM foo WHERE t = ?", ["bar"], |row| {
+            row.get::<_, i32>(0)
+        })?;
+        assert_eq!(1, check);
+        Ok(())
+    }
+
+    #[test]
+    fn test_session_empty() -> Result<()> {
+        let db = Connection::open_in_memory()?;
+        db.execute_batch("CREATE TABLE foo(t TEXT PRIMARY KEY NOT NULL);")?;
+
+        let mut session = Session::new(&db)?;
+        assert!(session.is_empty());
+
+        session.attach(None)?;
+        db.execute("INSERT INTO foo (t) VALUES (?);", ["bar"])?;
+
+        assert!(!session.is_empty());
+        Ok(())
+    }
+
+    #[test]
+    fn test_session_set_enabled() -> Result<()> {
+        let db = Connection::open_in_memory()?;
+
+        let mut session = Session::new(&db)?;
+        assert!(session.is_enabled());
+        session.set_enabled(false);
+        assert!(!session.is_enabled());
+        Ok(())
+    }
+
+    #[test]
+    fn test_session_set_indirect() -> Result<()> {
+        let db = Connection::open_in_memory()?;
+
+        let mut session = Session::new(&db)?;
+        assert!(!session.is_indirect());
+        session.set_indirect(true);
+        assert!(session.is_indirect());
+        Ok(())
+    }
+}
diff --git a/src/statement.rs b/src/statement.rs
new file mode 100644
index 0000000..ee5e220
--- /dev/null
+++ b/src/statement.rs
@@ -0,0 +1,1555 @@
+use std::iter::IntoIterator;
+use std::os::raw::{c_int, c_void};
+#[cfg(feature = "array")]
+use std::rc::Rc;
+use std::slice::from_raw_parts;
+use std::{fmt, mem, ptr, str};
+
+use super::ffi;
+use super::{len_as_c_int, str_for_sqlite};
+use super::{
+    AndThenRows, Connection, Error, MappedRows, Params, RawStatement, Result, Row, Rows, ValueRef,
+};
+use crate::types::{ToSql, ToSqlOutput};
+#[cfg(feature = "array")]
+use crate::vtab::array::{free_array, ARRAY_TYPE};
+
+/// A prepared statement.
+pub struct Statement<'conn> {
+    conn: &'conn Connection,
+    pub(crate) stmt: RawStatement,
+}
+
+impl Statement<'_> {
+    /// Execute the prepared statement.
+ /// + /// On success, returns the number of rows that were changed or inserted or + /// deleted (via `sqlite3_changes`). + /// + /// ## Example + /// + /// ### Use with positional parameters + /// + /// ```rust,no_run + /// # use rusqlite::{Connection, Result, params}; + /// fn update_rows(conn: &Connection) -> Result<()> { + /// let mut stmt = conn.prepare("UPDATE foo SET bar = 'baz' WHERE qux = ?")?; + /// // For a single parameter, or for parameters where all the values have + /// // the same type, just passing an array is simplest. + /// stmt.execute([2i32])?; + /// // The `rusqlite::params!` macro is mostly useful when the parameters do not + /// // all have the same type, or if there are more than 32 parameters + /// // at once, but it can be used in other cases. + /// stmt.execute(params![1i32])?; + /// // However, it's not required, many cases are fine as: + /// stmt.execute(&[&2i32])?; + /// // Or even: + /// stmt.execute([2i32])?; + /// // If you really want to, this is an option as well. + /// stmt.execute((2i32,))?; + /// Ok(()) + /// } + /// ``` + /// + /// #### Heterogeneous positional parameters + /// + /// ``` + /// use rusqlite::{Connection, Result}; + /// fn store_file(conn: &Connection, path: &str, data: &[u8]) -> Result<()> { + /// # // no need to do it for real. + /// # fn sha256(_: &[u8]) -> [u8; 32] { [0; 32] } + /// let query = "INSERT OR REPLACE INTO files(path, hash, data) VALUES (?, ?, ?)"; + /// let mut stmt = conn.prepare_cached(query)?; + /// let hash: [u8; 32] = sha256(data); + /// // The easiest way to pass positional parameters that have several + /// // different types is by using a tuple.
+ /// stmt.execute((path, hash, data))?; + /// // Using the `params!` macro also works, and supports longer parameter lists: + /// stmt.execute(rusqlite::params![path, hash, data])?; + /// Ok(()) + /// } + /// # let c = Connection::open_in_memory().unwrap(); + /// # c.execute_batch("CREATE TABLE files(path TEXT PRIMARY KEY, hash BLOB, data BLOB)").unwrap(); + /// # store_file(&c, "foo/bar.txt", b"bibble").unwrap(); + /// # store_file(&c, "foo/baz.txt", b"bobble").unwrap(); + /// ``` + /// + /// ### Use with named parameters + /// + /// ```rust,no_run + /// # use rusqlite::{Connection, Result, named_params}; + /// fn insert(conn: &Connection) -> Result<()> { + /// let mut stmt = conn.prepare("INSERT INTO test (key, value) VALUES (:key, :value)")?; + /// // The `rusqlite::named_params!` macro (like `params!`) is useful for heterogeneous + /// // sets of parameters (where not all parameters are the same type), or for queries + /// // with many (more than 32) statically known parameters. + /// stmt.execute(named_params! { ":key": "one", ":value": 2 })?; + /// // However, named parameters can also be passed like: + /// stmt.execute(&[(":key", "three"), (":value", "four")])?; + /// // Or even: (note that a &T is required for the value type, currently) + /// stmt.execute(&[(":key", &100), (":value", &200)])?; + /// Ok(()) + /// } + /// ``` + /// + /// ### Use without parameters + /// + /// ```rust,no_run + /// # use rusqlite::{Connection, Result, params}; + /// fn delete_all(conn: &Connection) -> Result<()> { + /// let mut stmt = conn.prepare("DELETE FROM users")?; + /// stmt.execute([])?; + /// Ok(()) + /// } + /// ``` + /// + /// # Failure + /// + /// Will return `Err` if binding parameters fails, the executed statement + /// returns rows (in which case `query` should be used instead), or the + /// underlying SQLite call fails.
+ #[inline] + pub fn execute(&mut self, params: P) -> Result { + params.__bind_in(self)?; + self.execute_with_bound_parameters() + } + + /// Execute the prepared statement with named parameter(s). + /// + /// Note: This function is deprecated in favor of [`Statement::execute`], + /// which can now take named parameters directly. + /// + /// If any parameters that were in the prepared statement are not included + /// in `params`, they will continue to use the most-recently bound value + /// from a previous call to `execute_named`, or `NULL` if they have never + /// been bound. + /// + /// On success, returns the number of rows that were changed or inserted or + /// deleted (via `sqlite3_changes`). + /// + /// # Failure + /// + /// Will return `Err` if binding parameters fails, the executed statement + /// returns rows (in which case `query` should be used instead), or the + /// underlying SQLite call fails. + #[doc(hidden)] + #[deprecated = "You can use `execute` with named params now."] + #[inline] + pub fn execute_named(&mut self, params: &[(&str, &dyn ToSql)]) -> Result { + self.execute(params) + } + + /// Execute an INSERT and return the ROWID. + /// + /// # Note + /// + /// This function is a convenience wrapper around + /// [`execute()`](Statement::execute) intended for queries that insert a + /// single item. It is possible to misuse this function in a way that it + /// cannot detect, such as by calling it on a statement which _updates_ + /// a single item rather than inserting one. Please don't do that. + /// + /// # Failure + /// + /// Will return `Err` if no row is inserted or many rows are inserted. + #[inline] + pub fn insert(&mut self, params: P) -> Result { + let changes = self.execute(params)?; + match changes { + 1 => Ok(self.conn.last_insert_rowid()), + _ => Err(Error::StatementChangedRows(changes)), + } + } + + /// Execute the prepared statement, returning a handle to the resulting + /// rows. 
+ /// + /// Due to lifetime restrictions, the rows handle returned by `query` does + /// not implement the `Iterator` trait. Consider using + /// [`query_map`](Statement::query_map) or + /// [`query_and_then`](Statement::query_and_then) instead, which do. + /// + /// ## Example + /// + /// ### Use without parameters + /// + /// ```rust,no_run + /// # use rusqlite::{Connection, Result}; + /// fn get_names(conn: &Connection) -> Result> { + /// let mut stmt = conn.prepare("SELECT name FROM people")?; + /// let mut rows = stmt.query([])?; + /// + /// let mut names = Vec::new(); + /// while let Some(row) = rows.next()? { + /// names.push(row.get(0)?); + /// } + /// + /// Ok(names) + /// } + /// ``` + /// + /// ### Use with positional parameters + /// + /// ```rust,no_run + /// # use rusqlite::{Connection, Result}; + /// fn query(conn: &Connection, name: &str) -> Result<()> { + /// let mut stmt = conn.prepare("SELECT * FROM test where name = ?")?; + /// let mut rows = stmt.query(rusqlite::params![name])?; + /// while let Some(row) = rows.next()? { + /// // ... + /// } + /// Ok(()) + /// } + /// ``` + /// + /// Or, equivalently (but without the [`params!`] macro). + /// + /// ```rust,no_run + /// # use rusqlite::{Connection, Result}; + /// fn query(conn: &Connection, name: &str) -> Result<()> { + /// let mut stmt = conn.prepare("SELECT * FROM test where name = ?")?; + /// let mut rows = stmt.query([name])?; + /// while let Some(row) = rows.next()? { + /// // ... + /// } + /// Ok(()) + /// } + /// ``` + /// + /// ### Use with named parameters + /// + /// ```rust,no_run + /// # use rusqlite::{Connection, Result}; + /// fn query(conn: &Connection) -> Result<()> { + /// let mut stmt = conn.prepare("SELECT * FROM test where name = :name")?; + /// let mut rows = stmt.query(&[(":name", "one")])?; + /// while let Some(row) = rows.next()? { + /// // ... 
+ /// } + /// Ok(()) + /// } + /// ``` + /// + /// Note, the `named_params!` macro is provided for syntactic convenience, + /// and so the above example could also be written as: + /// + /// ```rust,no_run + /// # use rusqlite::{Connection, Result, named_params}; + /// fn query(conn: &Connection) -> Result<()> { + /// let mut stmt = conn.prepare("SELECT * FROM test where name = :name")?; + /// let mut rows = stmt.query(named_params! { ":name": "one" })?; + /// while let Some(row) = rows.next()? { + /// // ... + /// } + /// Ok(()) + /// } + /// ``` + /// + /// ## Failure + /// + /// Will return `Err` if binding parameters fails. + #[inline] + pub fn query(&mut self, params: P) -> Result> { + params.__bind_in(self)?; + Ok(Rows::new(self)) + } + + /// Execute the prepared statement with named parameter(s), returning a + /// handle for the resulting rows. + /// + /// Note: This function is deprecated in favor of [`Statement::query`], + /// which can now take named parameters directly. + /// + /// If any parameters that were in the prepared statement are not included + /// in `params`, they will continue to use the most-recently bound value + /// from a previous call to `query_named`, or `NULL` if they have never been + /// bound. + /// + /// # Failure + /// + /// Will return `Err` if binding parameters fails. + #[doc(hidden)] + #[deprecated = "You can use `query` with named params now."] + pub fn query_named(&mut self, params: &[(&str, &dyn ToSql)]) -> Result> { + self.query(params) + } + + /// Executes the prepared statement and maps a function over the resulting + /// rows, returning an iterator over the mapped function results. + /// + /// `f` is used to transform the _streaming_ iterator into a _standard_ + /// iterator. + /// + /// This is equivalent to `stmt.query(params)?.mapped(f)`. 
+ /// + /// ## Example + /// + /// ### Use with positional params + /// + /// ```rust,no_run + /// # use rusqlite::{Connection, Result}; + /// fn get_names(conn: &Connection) -> Result> { + /// let mut stmt = conn.prepare("SELECT name FROM people")?; + /// let rows = stmt.query_map([], |row| row.get(0))?; + /// + /// let mut names = Vec::new(); + /// for name_result in rows { + /// names.push(name_result?); + /// } + /// + /// Ok(names) + /// } + /// ``` + /// + /// ### Use with named params + /// + /// ```rust,no_run + /// # use rusqlite::{Connection, Result}; + /// fn get_names(conn: &Connection) -> Result> { + /// let mut stmt = conn.prepare("SELECT name FROM people WHERE id = :id")?; + /// let rows = stmt.query_map(&[(":id", &"one")], |row| row.get(0))?; + /// + /// let mut names = Vec::new(); + /// for name_result in rows { + /// names.push(name_result?); + /// } + /// + /// Ok(names) + /// } + /// ``` + /// ## Failure + /// + /// Will return `Err` if binding parameters fails. + pub fn query_map(&mut self, params: P, f: F) -> Result> + where + P: Params, + F: FnMut(&Row<'_>) -> Result, + { + self.query(params).map(|rows| rows.mapped(f)) + } + + /// Execute the prepared statement with named parameter(s), returning an + /// iterator over the result of calling the mapping function over the + /// query's rows. + /// + /// Note: This function is deprecated in favor of [`Statement::query_map`], + /// which can now take named parameters directly. + /// + /// If any parameters that were in the prepared statement + /// are not included in `params`, they will continue to use the + /// most-recently bound value from a previous call to `query_named`, + /// or `NULL` if they have never been bound. + /// + /// `f` is used to transform the _streaming_ iterator into a _standard_ + /// iterator. + /// + /// ## Failure + /// + /// Will return `Err` if binding parameters fails. 
+ #[doc(hidden)] + #[deprecated = "You can use `query_map` with named params now."] + pub fn query_map_named( + &mut self, + params: &[(&str, &dyn ToSql)], + f: F, + ) -> Result> + where + F: FnMut(&Row<'_>) -> Result, + { + self.query_map(params, f) + } + + /// Executes the prepared statement and maps a function over the resulting + /// rows, where the function returns a `Result` with `Error` type + /// implementing `std::convert::From` (so errors can be unified). + /// + /// This is equivalent to `stmt.query(params)?.and_then(f)`. + /// + /// ## Example + /// + /// ### Use with named params + /// + /// ```rust,no_run + /// # use rusqlite::{Connection, Result}; + /// struct Person { + /// name: String, + /// }; + /// + /// fn name_to_person(name: String) -> Result { + /// // ... check for valid name + /// Ok(Person { name }) + /// } + /// + /// fn get_names(conn: &Connection) -> Result> { + /// let mut stmt = conn.prepare("SELECT name FROM people WHERE id = :id")?; + /// let rows = stmt.query_and_then(&[(":id", "one")], |row| name_to_person(row.get(0)?))?; + /// + /// let mut persons = Vec::new(); + /// for person_result in rows { + /// persons.push(person_result?); + /// } + /// + /// Ok(persons) + /// } + /// ``` + /// + /// ### Use with positional params + /// + /// ```rust,no_run + /// # use rusqlite::{Connection, Result}; + /// fn get_names(conn: &Connection) -> Result> { + /// let mut stmt = conn.prepare("SELECT name FROM people WHERE id = ?")?; + /// let rows = stmt.query_and_then(["one"], |row| row.get::<_, String>(0))?; + /// + /// let mut persons = Vec::new(); + /// for person_result in rows { + /// persons.push(person_result?); + /// } + /// + /// Ok(persons) + /// } + /// ``` + /// + /// # Failure + /// + /// Will return `Err` if binding parameters fails. 
+ #[inline] + pub fn query_and_then(&mut self, params: P, f: F) -> Result> + where + P: Params, + E: From, + F: FnMut(&Row<'_>) -> Result, + { + self.query(params).map(|rows| rows.and_then(f)) + } + + /// Execute the prepared statement with named parameter(s), returning an + /// iterator over the result of calling the mapping function over the + /// query's rows. + /// + /// Note: This function is deprecated in favor of + /// [`Statement::query_and_then`], which can now take named parameters + /// directly. + /// + /// If any parameters that were in the prepared statement are not included + /// in `params`, they will continue to use the most-recently bound value + /// from a previous call to `query_named`, or `NULL` if they have never been + /// bound. + /// + /// ## Failure + /// + /// Will return `Err` if binding parameters fails. + #[doc(hidden)] + #[deprecated = "You can use `query_and_then` with named params now."] + pub fn query_and_then_named( + &mut self, + params: &[(&str, &dyn ToSql)], + f: F, + ) -> Result> + where + E: From, + F: FnMut(&Row<'_>) -> Result, + { + self.query_and_then(params, f) + } + + /// Return `true` if a query in the SQL statement it executes returns one + /// or more rows and `false` if the SQL returns an empty set. + #[inline] + pub fn exists(&mut self, params: P) -> Result { + let mut rows = self.query(params)?; + let exists = rows.next()?.is_some(); + Ok(exists) + } + + /// Convenience method to execute a query that is expected to return a + /// single row. + /// + /// If the query returns more than one row, all rows except the first are + /// ignored. + /// + /// Returns `Err(QueryReturnedNoRows)` if no results are returned. If the + /// query truly is optional, you can call + /// [`.optional()`](crate::OptionalExtension::optional) on the result of + /// this to get a `Result>` (requires that the trait + /// `rusqlite::OptionalExtension` is imported). 
+ /// + /// # Failure + /// + /// Will return `Err` if the underlying SQLite call fails. + pub fn query_row(&mut self, params: P, f: F) -> Result + where + P: Params, + F: FnOnce(&Row<'_>) -> Result, + { + let mut rows = self.query(params)?; + + rows.get_expected_row().and_then(f) + } + + /// Convenience method to execute a query with named parameter(s) that is + /// expected to return a single row. + /// + /// Note: This function is deprecated in favor of + /// [`Statement::query_and_then`], which can now take named parameters + /// directly. + /// + /// If the query returns more than one row, all rows except the first are + /// ignored. + /// + /// Returns `Err(QueryReturnedNoRows)` if no results are returned. If the + /// query truly is optional, you can call + /// [`.optional()`](crate::OptionalExtension::optional) on the result of + /// this to get a `Result>` (requires that the trait + /// `rusqlite::OptionalExtension` is imported). + /// + /// # Failure + /// + /// Will return `Err` if `sql` cannot be converted to a C-compatible string + /// or if the underlying SQLite call fails. + #[doc(hidden)] + #[deprecated = "You can use `query_row` with named params now."] + pub fn query_row_named(&mut self, params: &[(&str, &dyn ToSql)], f: F) -> Result + where + F: FnOnce(&Row<'_>) -> Result, + { + self.query_row(params, f) + } + + /// Consumes the statement. + /// + /// Functionally equivalent to the `Drop` implementation, but allows + /// callers to see any errors that occur. + /// + /// # Failure + /// + /// Will return `Err` if the underlying SQLite call fails. + #[inline] + pub fn finalize(mut self) -> Result<()> { + self.finalize_() + } + + /// Return the (one-based) index of an SQL parameter given its name. + /// + /// Note that the initial ":" or "$" or "@" or "?" used to specify the + /// parameter is included as part of the name. 
+ /// + /// ```rust,no_run + /// # use rusqlite::{Connection, Result}; + /// fn example(conn: &Connection) -> Result<()> { + /// let stmt = conn.prepare("SELECT * FROM test WHERE name = :example")?; + /// let index = stmt.parameter_index(":example")?; + /// assert_eq!(index, Some(1)); + /// Ok(()) + /// } + /// ``` + /// + /// # Failure + /// + /// Will return Err if `name` is invalid. Will return Ok(None) if the name + /// is valid but not a bound parameter of this statement. + #[inline] + pub fn parameter_index(&self, name: &str) -> Result> { + Ok(self.stmt.bind_parameter_index(name)) + } + + /// Return the SQL parameter name given its (one-based) index (the inverse + /// of [`Statement::parameter_index`]). + /// + /// ```rust,no_run + /// # use rusqlite::{Connection, Result}; + /// fn example(conn: &Connection) -> Result<()> { + /// let stmt = conn.prepare("SELECT * FROM test WHERE name = :example")?; + /// let index = stmt.parameter_name(1); + /// assert_eq!(index, Some(":example")); + /// Ok(()) + /// } + /// ``` + /// + /// # Failure + /// + /// Will return `None` if the column index is out of bounds or if the + /// parameter is positional. + #[inline] + pub fn parameter_name(&self, index: usize) -> Option<&'_ str> { + self.stmt.bind_parameter_name(index as i32).map(|name| { + str::from_utf8(name.to_bytes()).expect("Invalid UTF-8 sequence in parameter name") + }) + } + + #[inline] + pub(crate) fn bind_parameters

(&mut self, params: P) -> Result<()> + where + P: IntoIterator, + P::Item: ToSql, + { + let expected = self.stmt.bind_parameter_count(); + let mut index = 0; + for p in params.into_iter() { + index += 1; // The leftmost SQL parameter has an index of 1. + if index > expected { + break; + } + self.bind_parameter(&p, index)?; + } + if index != expected { + Err(Error::InvalidParameterCount(index, expected)) + } else { + Ok(()) + } + } + + #[inline] + pub(crate) fn ensure_parameter_count(&self, n: usize) -> Result<()> { + let count = self.parameter_count(); + if count != n { + Err(Error::InvalidParameterCount(n, count)) + } else { + Ok(()) + } + } + + #[inline] + pub(crate) fn bind_parameters_named( + &mut self, + params: &[(&str, &T)], + ) -> Result<()> { + for &(name, value) in params { + if let Some(i) = self.parameter_index(name)? { + let ts: &dyn ToSql = &value; + self.bind_parameter(ts, i)?; + } else { + return Err(Error::InvalidParameterName(name.into())); + } + } + Ok(()) + } + + /// Return the number of parameters that can be bound to this statement. + #[inline] + pub fn parameter_count(&self) -> usize { + self.stmt.bind_parameter_count() + } + + /// Low level API to directly bind a parameter to a given index. + /// + /// Note that the index is one-based, that is, the first parameter index is + /// 1 and not 0. This is consistent with the SQLite API and the values given + /// to parameters bound as `?NNN`. + /// + /// The valid values for `one_based_col_index` begin at `1`, and end at + /// [`Statement::parameter_count`], inclusive. + /// + /// # Caveats + /// + /// This should not generally be used, but is available for special cases + /// such as: + /// + /// - binding parameters where a gap exists. + /// - binding named and positional parameters in the same query. + /// - separating parameter binding from query execution. 
+ /// + /// In general, statements that have had *any* parameters bound this way + /// should have *all* parameters bound this way, and be queried or executed + /// by [`Statement::raw_query`] or [`Statement::raw_execute`], other usage + /// is unsupported and will likely, probably in surprising ways. + /// + /// That is: Do not mix the "raw" statement functions with the rest of the + /// API, or the results may be surprising, and may even change in future + /// versions without comment. + /// + /// # Example + /// + /// ```rust,no_run + /// # use rusqlite::{Connection, Result}; + /// fn query(conn: &Connection) -> Result<()> { + /// let mut stmt = conn.prepare("SELECT * FROM test WHERE name = :name AND value > ?2")?; + /// let name_index = stmt.parameter_index(":name")?.expect("No such parameter"); + /// stmt.raw_bind_parameter(name_index, "foo")?; + /// stmt.raw_bind_parameter(2, 100)?; + /// let mut rows = stmt.raw_query(); + /// while let Some(row) = rows.next()? { + /// // ... + /// } + /// Ok(()) + /// } + /// ``` + #[inline] + pub fn raw_bind_parameter( + &mut self, + one_based_col_index: usize, + param: T, + ) -> Result<()> { + // This is the same as `bind_parameter` but slightly more ergonomic and + // correctly takes `&mut self`. + self.bind_parameter(¶m, one_based_col_index) + } + + /// Low level API to execute a statement given that all parameters were + /// bound explicitly with the [`Statement::raw_bind_parameter`] API. + /// + /// # Caveats + /// + /// Any unbound parameters will have `NULL` as their value. + /// + /// This should not generally be used outside of special cases, and + /// functions in the [`Statement::execute`] family should be preferred. + /// + /// # Failure + /// + /// Will return `Err` if the executed statement returns rows (in which case + /// `query` should be used instead), or the underlying SQLite call fails. 
+ #[inline] + pub fn raw_execute(&mut self) -> Result { + self.execute_with_bound_parameters() + } + + /// Low level API to get `Rows` for this query given that all parameters + /// were bound explicitly with the [`Statement::raw_bind_parameter`] API. + /// + /// # Caveats + /// + /// Any unbound parameters will have `NULL` as their value. + /// + /// This should not generally be used outside of special cases, and + /// functions in the [`Statement::query`] family should be preferred. + /// + /// Note that if the SQL does not return results, [`Statement::raw_execute`] + /// should be used instead. + #[inline] + pub fn raw_query(&mut self) -> Rows<'_> { + Rows::new(self) + } + + // generic because many of these branches can constant fold away. + fn bind_parameter(&self, param: &P, col: usize) -> Result<()> { + let value = param.to_sql()?; + + let ptr = unsafe { self.stmt.ptr() }; + let value = match value { + ToSqlOutput::Borrowed(v) => v, + ToSqlOutput::Owned(ref v) => ValueRef::from(v), + + #[cfg(feature = "blob")] + ToSqlOutput::ZeroBlob(len) => { + // TODO sqlite3_bind_zeroblob64 // 3.8.11 + return self + .conn + .decode_result(unsafe { ffi::sqlite3_bind_zeroblob(ptr, col as c_int, len) }); + } + #[cfg(feature = "array")] + ToSqlOutput::Array(a) => { + return self.conn.decode_result(unsafe { + ffi::sqlite3_bind_pointer( + ptr, + col as c_int, + Rc::into_raw(a) as *mut c_void, + ARRAY_TYPE, + Some(free_array), + ) + }); + } + }; + self.conn.decode_result(match value { + ValueRef::Null => unsafe { ffi::sqlite3_bind_null(ptr, col as c_int) }, + ValueRef::Integer(i) => unsafe { ffi::sqlite3_bind_int64(ptr, col as c_int, i) }, + ValueRef::Real(r) => unsafe { ffi::sqlite3_bind_double(ptr, col as c_int, r) }, + ValueRef::Text(s) => unsafe { + let (c_str, len, destructor) = str_for_sqlite(s)?; + // TODO sqlite3_bind_text64 // 3.8.7 + ffi::sqlite3_bind_text(ptr, col as c_int, c_str, len, destructor) + }, + ValueRef::Blob(b) => unsafe { + let length = 
len_as_c_int(b.len())?; + if length == 0 { + ffi::sqlite3_bind_zeroblob(ptr, col as c_int, 0) + } else { + // TODO sqlite3_bind_blob64 // 3.8.7 + ffi::sqlite3_bind_blob( + ptr, + col as c_int, + b.as_ptr().cast::(), + length, + ffi::SQLITE_TRANSIENT(), + ) + } + }, + }) + } + + #[inline] + fn execute_with_bound_parameters(&mut self) -> Result { + self.check_update()?; + let r = self.stmt.step(); + self.stmt.reset(); + match r { + ffi::SQLITE_DONE => Ok(self.conn.changes() as usize), + ffi::SQLITE_ROW => Err(Error::ExecuteReturnedResults), + _ => Err(self.conn.decode_result(r).unwrap_err()), + } + } + + #[inline] + fn finalize_(&mut self) -> Result<()> { + let mut stmt = unsafe { RawStatement::new(ptr::null_mut(), 0) }; + mem::swap(&mut stmt, &mut self.stmt); + self.conn.decode_result(stmt.finalize()) + } + + #[cfg(all(feature = "modern_sqlite", feature = "extra_check"))] + #[inline] + fn check_update(&self) -> Result<()> { + // sqlite3_column_count works for DML but not for DDL (ie ALTER) + if self.column_count() > 0 && self.stmt.readonly() { + return Err(Error::ExecuteReturnedResults); + } + Ok(()) + } + + #[cfg(all(not(feature = "modern_sqlite"), feature = "extra_check"))] + #[inline] + fn check_update(&self) -> Result<()> { + // sqlite3_column_count works for DML but not for DDL (ie ALTER) + if self.column_count() > 0 { + return Err(Error::ExecuteReturnedResults); + } + Ok(()) + } + + #[cfg(not(feature = "extra_check"))] + #[inline] + #[allow(clippy::unnecessary_wraps)] + fn check_update(&self) -> Result<()> { + Ok(()) + } + + /// Returns a string containing the SQL text of prepared statement with + /// bound parameters expanded. + #[cfg(feature = "modern_sqlite")] + #[cfg_attr(docsrs, doc(cfg(feature = "modern_sqlite")))] + pub fn expanded_sql(&self) -> Option { + self.stmt + .expanded_sql() + .map(|s| s.to_string_lossy().to_string()) + } + + /// Get the value for one of the status counters for this statement. 
+ #[inline] + pub fn get_status(&self, status: StatementStatus) -> i32 { + self.stmt.get_status(status, false) + } + + /// Reset the value of one of the status counters for this statement, + #[inline] + /// returning the value it had before resetting. + pub fn reset_status(&self, status: StatementStatus) -> i32 { + self.stmt.get_status(status, true) + } + + /// Returns 1 if the prepared statement is an EXPLAIN statement, + /// or 2 if the statement is an EXPLAIN QUERY PLAN, + /// or 0 if it is an ordinary statement or a NULL pointer. + #[inline] + #[cfg(feature = "modern_sqlite")] // 3.28.0 + #[cfg_attr(docsrs, doc(cfg(feature = "modern_sqlite")))] + pub fn is_explain(&self) -> i32 { + self.stmt.is_explain() + } + + #[cfg(feature = "extra_check")] + #[inline] + pub(crate) fn check_no_tail(&self) -> Result<()> { + if self.stmt.has_tail() { + Err(Error::MultipleStatement) + } else { + Ok(()) + } + } + + #[cfg(not(feature = "extra_check"))] + #[inline] + #[allow(clippy::unnecessary_wraps)] + pub(crate) fn check_no_tail(&self) -> Result<()> { + Ok(()) + } + + /// Safety: This is unsafe, because using `sqlite3_stmt` after the + /// connection has closed is illegal, but `RawStatement` does not enforce + /// this, as it loses our protective `'conn` lifetime bound. 
+ #[inline] + pub(crate) unsafe fn into_raw(mut self) -> RawStatement { + let mut stmt = RawStatement::new(ptr::null_mut(), 0); + mem::swap(&mut stmt, &mut self.stmt); + stmt + } +} + +impl fmt::Debug for Statement<'_> { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + let sql = if self.stmt.is_null() { + Ok("") + } else { + str::from_utf8(self.stmt.sql().unwrap().to_bytes()) + }; + f.debug_struct("Statement") + .field("conn", self.conn) + .field("stmt", &self.stmt) + .field("sql", &sql) + .finish() + } +} + +impl Drop for Statement<'_> { + #[allow(unused_must_use)] + #[inline] + fn drop(&mut self) { + self.finalize_(); + } +} + +impl Statement<'_> { + #[inline] + pub(super) fn new(conn: &Connection, stmt: RawStatement) -> Statement<'_> { + Statement { conn, stmt } + } + + pub(super) fn value_ref(&self, col: usize) -> ValueRef<'_> { + let raw = unsafe { self.stmt.ptr() }; + + match self.stmt.column_type(col) { + ffi::SQLITE_NULL => ValueRef::Null, + ffi::SQLITE_INTEGER => { + ValueRef::Integer(unsafe { ffi::sqlite3_column_int64(raw, col as c_int) }) + } + ffi::SQLITE_FLOAT => { + ValueRef::Real(unsafe { ffi::sqlite3_column_double(raw, col as c_int) }) + } + ffi::SQLITE_TEXT => { + let s = unsafe { + // Quoting from "Using SQLite" book: + // To avoid problems, an application should first extract the desired type using + // a sqlite3_column_xxx() function, and then call the + // appropriate sqlite3_column_bytes() function. 
+ let text = ffi::sqlite3_column_text(raw, col as c_int); + let len = ffi::sqlite3_column_bytes(raw, col as c_int); + assert!( + !text.is_null(), + "unexpected SQLITE_TEXT column type with NULL data" + ); + from_raw_parts(text.cast::(), len as usize) + }; + + ValueRef::Text(s) + } + ffi::SQLITE_BLOB => { + let (blob, len) = unsafe { + ( + ffi::sqlite3_column_blob(raw, col as c_int), + ffi::sqlite3_column_bytes(raw, col as c_int), + ) + }; + + assert!( + len >= 0, + "unexpected negative return from sqlite3_column_bytes" + ); + if len > 0 { + assert!( + !blob.is_null(), + "unexpected SQLITE_BLOB column type with NULL data" + ); + ValueRef::Blob(unsafe { from_raw_parts(blob.cast::(), len as usize) }) + } else { + // The return value from sqlite3_column_blob() for a zero-length BLOB + // is a NULL pointer. + ValueRef::Blob(&[]) + } + } + _ => unreachable!("sqlite3_column_type returned invalid value"), + } + } + + #[inline] + pub(super) fn step(&self) -> Result { + match self.stmt.step() { + ffi::SQLITE_ROW => Ok(true), + ffi::SQLITE_DONE => Ok(false), + code => Err(self.conn.decode_result(code).unwrap_err()), + } + } + + #[inline] + pub(super) fn reset(&self) -> c_int { + self.stmt.reset() + } +} + +/// Prepared statement status counters. +/// +/// See `https://www.sqlite.org/c3ref/c_stmtstatus_counter.html` +/// for explanations of each. +/// +/// Note that depending on your version of SQLite, all of these +/// may not be available. 
+#[repr(i32)] +#[derive(Clone, Copy, PartialEq, Eq)] +#[non_exhaustive] +pub enum StatementStatus { + /// Equivalent to SQLITE_STMTSTATUS_FULLSCAN_STEP + FullscanStep = 1, + /// Equivalent to SQLITE_STMTSTATUS_SORT + Sort = 2, + /// Equivalent to SQLITE_STMTSTATUS_AUTOINDEX + AutoIndex = 3, + /// Equivalent to SQLITE_STMTSTATUS_VM_STEP + VmStep = 4, + /// Equivalent to SQLITE_STMTSTATUS_REPREPARE (3.20.0) + RePrepare = 5, + /// Equivalent to SQLITE_STMTSTATUS_RUN (3.20.0) + Run = 6, + /// Equivalent to SQLITE_STMTSTATUS_FILTER_MISS + FilterMiss = 7, + /// Equivalent to SQLITE_STMTSTATUS_FILTER_HIT + FilterHit = 8, + /// Equivalent to SQLITE_STMTSTATUS_MEMUSED (3.20.0) + MemUsed = 99, +} + +#[cfg(test)] +mod test { + use crate::types::ToSql; + use crate::{params_from_iter, Connection, Error, Result}; + + #[test] + #[allow(deprecated)] + fn test_execute_named() -> Result<()> { + let db = Connection::open_in_memory()?; + db.execute_batch("CREATE TABLE foo(x INTEGER)")?; + + assert_eq!( + db.execute_named("INSERT INTO foo(x) VALUES (:x)", &[(":x", &1i32)])?, + 1 + ); + assert_eq!( + db.execute("INSERT INTO foo(x) VALUES (:x)", &[(":x", &2i32)])?, + 1 + ); + assert_eq!( + db.execute( + "INSERT INTO foo(x) VALUES (:x)", + crate::named_params! {":x": 3i32} + )?, + 1 + ); + + assert_eq!( + 6i32, + db.query_row_named::( + "SELECT SUM(x) FROM foo WHERE x > :x", + &[(":x", &0i32)], + |r| r.get(0) + )? + ); + assert_eq!( + 5i32, + db.query_row::( + "SELECT SUM(x) FROM foo WHERE x > :x", + &[(":x", &1i32)], + |r| r.get(0) + )?
+ ); + Ok(()) + } + + #[test] + #[allow(deprecated)] + fn test_stmt_execute_named() -> Result<()> { + let db = Connection::open_in_memory()?; + let sql = "CREATE TABLE test (id INTEGER PRIMARY KEY NOT NULL, name TEXT NOT NULL, flag \ + INTEGER)"; + db.execute_batch(sql)?; + + let mut stmt = db.prepare("INSERT INTO test (name) VALUES (:name)")?; + stmt.execute_named(&[(":name", &"one")])?; + + let mut stmt = db.prepare("SELECT COUNT(*) FROM test WHERE name = :name")?; + assert_eq!( + 1i32, + stmt.query_row_named::(&[(":name", &"one")], |r| r.get(0))? + ); + assert_eq!( + 1i32, + stmt.query_row::(&[(":name", "one")], |r| r.get(0))? + ); + Ok(()) + } + + #[test] + #[allow(deprecated)] + fn test_query_named() -> Result<()> { + let db = Connection::open_in_memory()?; + let sql = r#" + CREATE TABLE test (id INTEGER PRIMARY KEY NOT NULL, name TEXT NOT NULL, flag INTEGER); + INSERT INTO test(id, name) VALUES (1, "one"); + "#; + db.execute_batch(sql)?; + + let mut stmt = db.prepare("SELECT id FROM test where name = :name")?; + // legacy `_named` api + { + let mut rows = stmt.query_named(&[(":name", &"one")])?; + let id: Result = rows.next()?.unwrap().get(0); + assert_eq!(Ok(1), id); + } + + // plain api + { + let mut rows = stmt.query(&[(":name", "one")])?; + let id: Result = rows.next()?.unwrap().get(0); + assert_eq!(Ok(1), id); + } + Ok(()) + } + + #[test] + #[allow(deprecated)] + fn test_query_map_named() -> Result<()> { + let db = Connection::open_in_memory()?; + let sql = r#" + CREATE TABLE test (id INTEGER PRIMARY KEY NOT NULL, name TEXT NOT NULL, flag INTEGER); + INSERT INTO test(id, name) VALUES (1, "one"); + "#; + db.execute_batch(sql)?; + + let mut stmt = db.prepare("SELECT id FROM test where name = :name")?; + // legacy `_named` api + { + let mut rows = stmt.query_map_named(&[(":name", &"one")], |row| { + let id: Result = row.get(0); + id.map(|i| 2 * i) + })?; + + let doubled_id: i32 = rows.next().unwrap()?; + assert_eq!(2, doubled_id); + } + // plain api + { +
let mut rows = stmt.query_map(&[(":name", "one")], |row| { + let id: Result = row.get(0); + id.map(|i| 2 * i) + })?; + + let doubled_id: i32 = rows.next().unwrap()?; + assert_eq!(2, doubled_id); + } + Ok(()) + } + + #[test] + #[allow(deprecated)] + fn test_query_and_then_named() -> Result<()> { + let db = Connection::open_in_memory()?; + let sql = r#" + CREATE TABLE test (id INTEGER PRIMARY KEY NOT NULL, name TEXT NOT NULL, flag INTEGER); + INSERT INTO test(id, name) VALUES (1, "one"); + INSERT INTO test(id, name) VALUES (2, "one"); + "#; + db.execute_batch(sql)?; + + let mut stmt = db.prepare("SELECT id FROM test where name = :name ORDER BY id ASC")?; + let mut rows = stmt.query_and_then_named(&[(":name", &"one")], |row| { + let id: i32 = row.get(0)?; + if id == 1 { + Ok(id) + } else { + Err(Error::SqliteSingleThreadedMode) + } + })?; + + // first row should be Ok; `doubled_id` is the raw id here (the closure does not double it) + let doubled_id: i32 = rows.next().unwrap()?; + assert_eq!(1, doubled_id); + + // second row should be Err + #[allow(clippy::match_wild_err_arm)] + match rows.next().unwrap() { + Ok(_) => panic!("invalid Ok"), + Err(Error::SqliteSingleThreadedMode) => (), + Err(_) => panic!("invalid Err"), + } + Ok(()) + } + + #[test] + fn test_query_and_then_by_name() -> Result<()> { + let db = Connection::open_in_memory()?; + let sql = r#" + CREATE TABLE test (id INTEGER PRIMARY KEY NOT NULL, name TEXT NOT NULL, flag INTEGER); + INSERT INTO test(id, name) VALUES (1, "one"); + INSERT INTO test(id, name) VALUES (2, "one"); + "#; + db.execute_batch(sql)?; + + let mut stmt = db.prepare("SELECT id FROM test where name = :name ORDER BY id ASC")?; + let mut rows = stmt.query_and_then(&[(":name", "one")], |row| { + let id: i32 = row.get(0)?; + if id == 1 { + Ok(id) + } else { + Err(Error::SqliteSingleThreadedMode) + } + })?; + + // first row should be Ok; `doubled_id` is the raw id here (the closure does not double it) + let doubled_id: i32 = rows.next().unwrap()?; + assert_eq!(1, doubled_id); + + // second row should be Err + #[allow(clippy::match_wild_err_arm)] + match
rows.next().unwrap() { + Ok(_) => panic!("invalid Ok"), + Err(Error::SqliteSingleThreadedMode) => (), + Err(_) => panic!("invalid Err"), + } + Ok(()) + } + + #[test] + #[allow(deprecated)] + fn test_unbound_parameters_are_null() -> Result<()> { + let db = Connection::open_in_memory()?; + let sql = "CREATE TABLE test (x TEXT, y TEXT)"; + db.execute_batch(sql)?; + + let mut stmt = db.prepare("INSERT INTO test (x, y) VALUES (:x, :y)")?; + stmt.execute_named(&[(":x", &"one")])?; + + let result: Option = + db.query_row("SELECT y FROM test WHERE x = 'one'", [], |row| row.get(0))?; + assert!(result.is_none()); + Ok(()) + } + + #[test] + fn test_raw_binding() -> Result<()> { + let db = Connection::open_in_memory()?; + db.execute_batch("CREATE TABLE test (name TEXT, value INTEGER)")?; + { + let mut stmt = db.prepare("INSERT INTO test (name, value) VALUES (:name, ?3)")?; + + let name_idx = stmt.parameter_index(":name")?.unwrap(); + stmt.raw_bind_parameter(name_idx, "example")?; + stmt.raw_bind_parameter(3, 50i32)?; + let n = stmt.raw_execute()?; + assert_eq!(n, 1); + } + + { + let mut stmt = db.prepare("SELECT name, value FROM test WHERE value = ?2")?; + stmt.raw_bind_parameter(2, 50)?; + let mut rows = stmt.raw_query(); + { + let row = rows.next()?.unwrap(); + let name: String = row.get(0)?; + assert_eq!(name, "example"); + let value: i32 = row.get(1)?; + assert_eq!(value, 50); + } + assert!(rows.next()?.is_none()); + } + + Ok(()) + } + + #[test] + fn test_unbound_parameters_are_reused() -> Result<()> { + let db = Connection::open_in_memory()?; + let sql = "CREATE TABLE test (x TEXT, y TEXT)"; + db.execute_batch(sql)?; + + let mut stmt = db.prepare("INSERT INTO test (x, y) VALUES (:x, :y)")?; + stmt.execute(&[(":x", "one")])?; + stmt.execute(&[(":y", "two")])?; + + let result: String = + db.query_row("SELECT x FROM test WHERE y = 'two'", [], |row| row.get(0))?; + assert_eq!(result, "one"); + Ok(()) + } + + #[test] + fn test_insert() -> Result<()> { + let db =
Connection::open_in_memory()?; + db.execute_batch("CREATE TABLE foo(x INTEGER UNIQUE)")?; + let mut stmt = db.prepare("INSERT OR IGNORE INTO foo (x) VALUES (?)")?; + assert_eq!(stmt.insert([1i32])?, 1); + assert_eq!(stmt.insert([2i32])?, 2); + match stmt.insert([1i32]).unwrap_err() { + Error::StatementChangedRows(0) => (), + err => panic!("Unexpected error {}", err), + } + let mut multi = db.prepare("INSERT INTO foo (x) SELECT 3 UNION ALL SELECT 4")?; + match multi.insert([]).unwrap_err() { + Error::StatementChangedRows(2) => (), + err => panic!("Unexpected error {}", err), + } + Ok(()) + } + + #[test] + fn test_insert_different_tables() -> Result<()> { + // Test for https://github.com/rusqlite/rusqlite/issues/171 + let db = Connection::open_in_memory()?; + db.execute_batch( + r" + CREATE TABLE foo(x INTEGER); + CREATE TABLE bar(x INTEGER); + ", + )?; + + assert_eq!(db.prepare("INSERT INTO foo VALUES (10)")?.insert([])?, 1); + assert_eq!(db.prepare("INSERT INTO bar VALUES (10)")?.insert([])?, 1); + Ok(()) + } + + #[test] + fn test_exists() -> Result<()> { + let db = Connection::open_in_memory()?; + let sql = "BEGIN; + CREATE TABLE foo(x INTEGER); + INSERT INTO foo VALUES(1); + INSERT INTO foo VALUES(2); + END;"; + db.execute_batch(sql)?; + let mut stmt = db.prepare("SELECT 1 FROM foo WHERE x = ?")?; + assert!(stmt.exists([1i32])?); + assert!(stmt.exists([2i32])?); + assert!(!stmt.exists([0i32])?); + Ok(()) + } + #[test] + fn test_tuple_params() -> Result<()> { + let db = Connection::open_in_memory()?; + let s = db.query_row("SELECT printf('[%s]', ?)", ("abc",), |r| { + r.get::<_, String>(0) + })?; + assert_eq!(s, "[abc]"); + let s = db.query_row( + "SELECT printf('%d %s %d', ?, ?, ?)", + (1i32, "abc", 2i32), + |r| r.get::<_, String>(0), + )?; + assert_eq!(s, "1 abc 2"); + let s = db.query_row( + "SELECT printf('%d %s %d %d', ?, ?, ?, ?)", + (1, "abc", 2i32, 4i64), + |r| r.get::<_, String>(0), + )?; + assert_eq!(s, "1 abc 2 4"); + #[rustfmt::skip] + let bigtup = ( +
0, "a", 1, "b", 2, "c", 3, "d", + 4, "e", 5, "f", 6, "g", 7, "h", + ); + let query = "SELECT printf( + '%d %s | %d %s | %d %s | %d %s || %d %s | %d %s | %d %s | %d %s', + ?, ?, ?, ?, + ?, ?, ?, ?, + ?, ?, ?, ?, + ?, ?, ?, ? + )"; + let s = db.query_row(query, bigtup, |r| r.get::<_, String>(0))?; + assert_eq!(s, "0 a | 1 b | 2 c | 3 d || 4 e | 5 f | 6 g | 7 h"); + Ok(()) + } + + #[test] + fn test_query_row() -> Result<()> { + let db = Connection::open_in_memory()?; + let sql = "BEGIN; + CREATE TABLE foo(x INTEGER, y INTEGER); + INSERT INTO foo VALUES(1, 3); + INSERT INTO foo VALUES(2, 4); + END;"; + db.execute_batch(sql)?; + let mut stmt = db.prepare("SELECT y FROM foo WHERE x = ?")?; + let y: Result = stmt.query_row([1i32], |r| r.get(0)); + assert_eq!(3i64, y?); + Ok(()) + } + + #[test] + fn test_query_by_column_name() -> Result<()> { + let db = Connection::open_in_memory()?; + let sql = "BEGIN; + CREATE TABLE foo(x INTEGER, y INTEGER); + INSERT INTO foo VALUES(1, 3); + END;"; + db.execute_batch(sql)?; + let mut stmt = db.prepare("SELECT y FROM foo")?; + let y: Result = stmt.query_row([], |r| r.get("y")); + assert_eq!(3i64, y?); + Ok(()) + } + + #[test] + fn test_query_by_column_name_ignore_case() -> Result<()> { + let db = Connection::open_in_memory()?; + let sql = "BEGIN; + CREATE TABLE foo(x INTEGER, y INTEGER); + INSERT INTO foo VALUES(1, 3); + END;"; + db.execute_batch(sql)?; + let mut stmt = db.prepare("SELECT y as Y FROM foo")?; + let y: Result = stmt.query_row([], |r| r.get("y")); + assert_eq!(3i64, y?); + Ok(()) + } + + #[test] + #[cfg(feature = "modern_sqlite")] + fn test_expanded_sql() -> Result<()> { + let db = Connection::open_in_memory()?; + let stmt = db.prepare("SELECT ?")?; + stmt.bind_parameter(&1, 1)?; + assert_eq!(Some("SELECT 1".to_owned()), stmt.expanded_sql()); + Ok(()) + } + + #[test] + fn test_bind_parameters() -> Result<()> { + let db = Connection::open_in_memory()?; + // dynamic slice: + db.query_row( + "SELECT ?1, ?2, ?3", + &[&1u8 as
&dyn ToSql, &"one", &Some("one")], + |row| row.get::<_, u8>(0), + )?; + // existing collection: + let data = vec![1, 2, 3]; + db.query_row("SELECT ?1, ?2, ?3", params_from_iter(&data), |row| { + row.get::<_, u8>(0) + })?; + db.query_row( + "SELECT ?1, ?2, ?3", + params_from_iter(data.as_slice()), + |row| row.get::<_, u8>(0), + )?; + db.query_row("SELECT ?1, ?2, ?3", params_from_iter(data), |row| { + row.get::<_, u8>(0) + })?; + + use std::collections::BTreeSet; + let data: BTreeSet = ["one", "two", "three"] + .iter() + .map(|s| (*s).to_string()) + .collect(); + db.query_row("SELECT ?1, ?2, ?3", params_from_iter(&data), |row| { + row.get::<_, String>(0) + })?; + + let data = [0; 3]; + db.query_row("SELECT ?1, ?2, ?3", params_from_iter(&data), |row| { + row.get::<_, u8>(0) + })?; + db.query_row("SELECT ?1, ?2, ?3", params_from_iter(data.iter()), |row| { + row.get::<_, u8>(0) + })?; + Ok(()) + } + + #[test] + fn test_parameter_name() -> Result<()> { + let db = Connection::open_in_memory()?; + db.execute_batch("CREATE TABLE test (name TEXT, value INTEGER)")?; + let stmt = db.prepare("INSERT INTO test (name, value) VALUES (:name, ?3)")?; + assert_eq!(stmt.parameter_name(0), None); + assert_eq!(stmt.parameter_name(1), Some(":name")); + assert_eq!(stmt.parameter_name(2), None); + Ok(()) + } + + #[test] + fn test_empty_stmt() -> Result<()> { + let conn = Connection::open_in_memory()?; + let mut stmt = conn.prepare("")?; + assert_eq!(0, stmt.column_count()); + assert!(stmt.parameter_index("test").is_ok()); + assert!(stmt.step().is_err()); + stmt.reset(); + assert!(stmt.execute([]).is_err()); + Ok(()) + } + + #[test] + fn test_comment_stmt() -> Result<()> { + let conn = Connection::open_in_memory()?; + conn.prepare("/*SELECT 1;*/")?; + Ok(()) + } + + #[test] + fn test_comment_and_sql_stmt() -> Result<()> { + let conn = Connection::open_in_memory()?; + let stmt = conn.prepare("/*...*/ SELECT 1;")?; + assert_eq!(1, stmt.column_count()); + Ok(()) + } + + #[test] + fn
test_semi_colon_stmt() -> Result<()> { + let conn = Connection::open_in_memory()?; + let stmt = conn.prepare(";")?; + assert_eq!(0, stmt.column_count()); + Ok(()) + } + + #[test] + fn test_utf16_conversion() -> Result<()> { + let db = Connection::open_in_memory()?; + db.pragma_update(None, "encoding", &"UTF-16le")?; + let encoding: String = db.pragma_query_value(None, "encoding", |row| row.get(0))?; + assert_eq!("UTF-16le", encoding); + db.execute_batch("CREATE TABLE foo(x TEXT)")?; + let expected = "テスト"; + db.execute("INSERT INTO foo(x) VALUES (?)", &[&expected])?; + let actual: String = db.query_row("SELECT x FROM foo", [], |row| row.get(0))?; + assert_eq!(expected, actual); + Ok(()) + } + + #[test] + fn test_nul_byte() -> Result<()> { + let db = Connection::open_in_memory()?; + let expected = "a\x00b"; + let actual: String = db.query_row("SELECT ?", [expected], |row| row.get(0))?; + assert_eq!(expected, actual); + Ok(()) + } + + #[test] + #[cfg(feature = "modern_sqlite")] + fn is_explain() -> Result<()> { + let db = Connection::open_in_memory()?; + let stmt = db.prepare("SELECT 1;")?; + assert_eq!(0, stmt.is_explain()); + Ok(()) + } + + #[test] + #[cfg(all(feature = "modern_sqlite", not(feature = "bundled-sqlcipher")))] // SQLite >= 3.38.0 + fn test_error_offset() -> Result<()> { + use crate::ffi::ErrorCode; + let db = Connection::open_in_memory()?; + let r = db.execute_batch("SELECT CURRENT_TIMESTANP;"); + assert!(r.is_err()); + match r.unwrap_err() { + Error::SqlInputError { error, offset, .. } => { + assert_eq!(error.code, ErrorCode::Unknown); + assert_eq!(offset, 7); + } + err => panic!("Unexpected error {}", err), + } + Ok(()) + } +} diff --git a/src/trace.rs b/src/trace.rs new file mode 100644 index 0000000..7fc9090 --- /dev/null +++ b/src/trace.rs @@ -0,0 +1,184 @@ +//! Tracing and profiling functions. Error and warning log.
+ +use std::ffi::{CStr, CString}; +use std::mem; +use std::os::raw::{c_char, c_int, c_void}; +use std::panic::catch_unwind; +use std::ptr; +use std::time::Duration; + +use super::ffi; +use crate::error::error_from_sqlite_code; +use crate::{Connection, Result}; + +/// Set up the process-wide SQLite error logging callback. +/// +/// # Safety +/// +/// This function is marked unsafe for two reasons: +/// +/// * The function is not threadsafe. No other SQLite calls may be made while +/// `config_log` is running, and multiple threads may not call `config_log` +/// simultaneously. +/// * The provided `callback` function itself has two requirements: +/// * It must not invoke any SQLite calls. +/// * It must be threadsafe if SQLite is used in a multithreaded way. +/// +/// cf [The Error And Warning Log](http://sqlite.org/errlog.html). +pub unsafe fn config_log(callback: Option) -> Result<()> { + extern "C" fn log_callback(p_arg: *mut c_void, err: c_int, msg: *const c_char) { + let c_slice = unsafe { CStr::from_ptr(msg).to_bytes() }; + let callback: fn(c_int, &str) = unsafe { mem::transmute(p_arg) }; + + let s = String::from_utf8_lossy(c_slice); + drop(catch_unwind(|| callback(err, &s))); + } + + let rc = if let Some(f) = callback { + ffi::sqlite3_config( + ffi::SQLITE_CONFIG_LOG, + log_callback as extern "C" fn(_, _, _), + f as *mut c_void, + ) + } else { + let nullptr: *mut c_void = ptr::null_mut(); + ffi::sqlite3_config(ffi::SQLITE_CONFIG_LOG, nullptr, nullptr) + }; + + if rc == ffi::SQLITE_OK { + Ok(()) + } else { + Err(error_from_sqlite_code(rc, None)) + } +} + +/// Write a message into the error log established by +/// `config_log`. Panics if `msg` contains an interior NUL byte.
+#[inline] +pub fn log(err_code: c_int, msg: &str) { + let msg = CString::new(msg).expect("SQLite log messages cannot contain embedded zeroes"); + unsafe { + ffi::sqlite3_log(err_code, b"%s\0" as *const _ as *const c_char, msg.as_ptr()); + } +} + +impl Connection { + /// Register or clear a callback function that can be + /// used for tracing the execution of SQL statements. + /// + /// Prepared statement placeholders are replaced/logged with their assigned + /// values. There can only be a single tracer defined for each database + /// connection. Setting a new tracer clears the old one. + pub fn trace(&mut self, trace_fn: Option) { + unsafe extern "C" fn trace_callback(p_arg: *mut c_void, z_sql: *const c_char) { + let trace_fn: fn(&str) = mem::transmute(p_arg); + let c_slice = CStr::from_ptr(z_sql).to_bytes(); + let s = String::from_utf8_lossy(c_slice); + drop(catch_unwind(|| trace_fn(&s))); + } + + let c = self.db.borrow_mut(); + match trace_fn { + Some(f) => unsafe { + ffi::sqlite3_trace(c.db(), Some(trace_callback), f as *mut c_void); + }, + None => unsafe { + ffi::sqlite3_trace(c.db(), None, ptr::null_mut()); + }, + } + } + + /// Register or clear a callback function that can be + /// used for profiling the execution of SQL statements. + /// + /// There can only be a single profiler defined for each database + /// connection. Setting a new profiler clears the old one.
+ pub fn profile(&mut self, profile_fn: Option) { + unsafe extern "C" fn profile_callback( + p_arg: *mut c_void, + z_sql: *const c_char, + nanoseconds: u64, + ) { + let profile_fn: fn(&str, Duration) = mem::transmute(p_arg); + let c_slice = CStr::from_ptr(z_sql).to_bytes(); + let s = String::from_utf8_lossy(c_slice); + const NANOS_PER_SEC: u64 = 1_000_000_000; + + let duration = Duration::new( + nanoseconds / NANOS_PER_SEC, + (nanoseconds % NANOS_PER_SEC) as u32, + ); + drop(catch_unwind(|| profile_fn(&s, duration))); + } + + let c = self.db.borrow_mut(); + match profile_fn { + Some(f) => unsafe { + ffi::sqlite3_profile(c.db(), Some(profile_callback), f as *mut c_void) + }, + None => unsafe { ffi::sqlite3_profile(c.db(), None, ptr::null_mut()) }, + }; + } + + // TODO sqlite3_trace_v2 (https://sqlite.org/c3ref/trace_v2.html) // 3.14.0, #977 +} + +#[cfg(test)] +mod test { + use lazy_static::lazy_static; + use std::sync::Mutex; + use std::time::Duration; + + use crate::{Connection, Result}; + + #[test] + fn test_trace() -> Result<()> { + lazy_static! { + static ref TRACED_STMTS: Mutex> = Mutex::new(Vec::new()); + } + fn tracer(s: &str) { + let mut traced_stmts = TRACED_STMTS.lock().unwrap(); + traced_stmts.push(s.to_owned()); + } + + let mut db = Connection::open_in_memory()?; + db.trace(Some(tracer)); + { + let _ = db.query_row("SELECT ?", [1i32], |_| Ok(())); + let _ = db.query_row("SELECT ?", ["hello"], |_| Ok(())); + } + db.trace(None); + { + let _ = db.query_row("SELECT ?", [2i32], |_| Ok(())); + let _ = db.query_row("SELECT ?", ["goodbye"], |_| Ok(())); + } + + let traced_stmts = TRACED_STMTS.lock().unwrap(); + assert_eq!(traced_stmts.len(), 2); + assert_eq!(traced_stmts[0], "SELECT 1"); + assert_eq!(traced_stmts[1], "SELECT 'hello'"); + Ok(()) + } + + #[test] + fn test_profile() -> Result<()> { + lazy_static!
{ + static ref PROFILED: Mutex> = Mutex::new(Vec::new()); + } + fn profiler(s: &str, d: Duration) { + let mut profiled = PROFILED.lock().unwrap(); + profiled.push((s.to_owned(), d)); + } + + let mut db = Connection::open_in_memory()?; + db.profile(Some(profiler)); + db.execute_batch("PRAGMA application_id = 1")?; + db.profile(None); + db.execute_batch("PRAGMA application_id = 2")?; + + let profiled = PROFILED.lock().unwrap(); + assert_eq!(profiled.len(), 1); + assert_eq!(profiled[0].0, "PRAGMA application_id = 1"); + Ok(()) + } +} diff --git a/src/transaction.rs b/src/transaction.rs new file mode 100644 index 0000000..2c4c6c0 --- /dev/null +++ b/src/transaction.rs @@ -0,0 +1,759 @@ +use crate::{Connection, Result}; +use std::ops::Deref; + +/// Options for transaction behavior. See [BEGIN +/// TRANSACTION](http://www.sqlite.org/lang_transaction.html) for details. +#[derive(Copy, Clone)] +#[non_exhaustive] +pub enum TransactionBehavior { + /// DEFERRED means that the transaction does not actually start until the + /// database is first accessed. + Deferred, + /// IMMEDIATE causes the database connection to start a new write + /// immediately, without waiting for a write statement. + Immediate, + /// EXCLUSIVE prevents other database connections from reading the database + /// while the transaction is underway. + Exclusive, +} + +/// Options for how a Transaction or Savepoint should behave when it is dropped. +#[derive(Copy, Clone, Debug, PartialEq, Eq)] +#[non_exhaustive] +pub enum DropBehavior { + /// Roll back the changes. This is the default. + Rollback, + + /// Commit the changes. + Commit, + + /// Do not commit or roll back changes - this will leave the transaction or + /// savepoint open, so should be used with care. + Ignore, + + /// Panic. Used to enforce intentional behavior during development. + Panic, +} + +/// Represents a transaction on a database connection. +/// +/// ## Note +/// +/// Transactions will roll back by default.
Use `commit` method to explicitly +/// commit the transaction, or use `set_drop_behavior` to change what happens +/// when the transaction is dropped. +/// +/// ## Example +/// +/// ```rust,no_run +/// # use rusqlite::{Connection, Result}; +/// # fn do_queries_part_1(_conn: &Connection) -> Result<()> { Ok(()) } +/// # fn do_queries_part_2(_conn: &Connection) -> Result<()> { Ok(()) } +/// fn perform_queries(conn: &mut Connection) -> Result<()> { +/// let tx = conn.transaction()?; +/// +/// do_queries_part_1(&tx)?; // tx causes rollback if this fails +/// do_queries_part_2(&tx)?; // tx causes rollback if this fails +/// +/// tx.commit() +/// } +/// ``` +#[derive(Debug)] +pub struct Transaction<'conn> { + conn: &'conn Connection, + drop_behavior: DropBehavior, +} + +/// Represents a savepoint on a database connection. +/// +/// ## Note +/// +/// Savepoints will roll back by default. Use `commit` method to explicitly +/// commit the savepoint, or use `set_drop_behavior` to change what happens +/// when the savepoint is dropped. +/// +/// ## Example +/// +/// ```rust,no_run +/// # use rusqlite::{Connection, Result}; +/// # fn do_queries_part_1(_conn: &Connection) -> Result<()> { Ok(()) } +/// # fn do_queries_part_2(_conn: &Connection) -> Result<()> { Ok(()) } +/// fn perform_queries(conn: &mut Connection) -> Result<()> { +/// let sp = conn.savepoint()?; +/// +/// do_queries_part_1(&sp)?; // sp causes rollback if this fails +/// do_queries_part_2(&sp)?; // sp causes rollback if this fails +/// +/// sp.commit() +/// } +/// ``` +#[derive(Debug)] +pub struct Savepoint<'conn> { + conn: &'conn Connection, + name: String, + depth: u32, + drop_behavior: DropBehavior, + committed: bool, +} + +impl Transaction<'_> { + /// Begin a new transaction. Cannot be nested; see `savepoint` for nested + /// transactions. + /// + /// Even though we don't mutate the connection, we take a `&mut Connection` + /// so as to prevent nested transactions on the same connection.
For cases + /// where this is unacceptable, [`Transaction::new_unchecked`] is available. + #[inline] + pub fn new(conn: &mut Connection, behavior: TransactionBehavior) -> Result> { + Self::new_unchecked(conn, behavior) + } + + /// Begin a new transaction, failing if a transaction is open. + /// + /// If a transaction is already open, this will return an error. Where + /// possible, [`Transaction::new`] should be preferred, as it provides a + /// compile-time guarantee that transactions are not nested. + #[inline] + pub fn new_unchecked( + conn: &Connection, + behavior: TransactionBehavior, + ) -> Result> { + let query = match behavior { + TransactionBehavior::Deferred => "BEGIN DEFERRED", + TransactionBehavior::Immediate => "BEGIN IMMEDIATE", + TransactionBehavior::Exclusive => "BEGIN EXCLUSIVE", + }; + conn.execute_batch(query).map(move |_| Transaction { + conn, + drop_behavior: DropBehavior::Rollback, + }) + } + + /// Starts a new [savepoint](http://www.sqlite.org/lang_savepoint.html), allowing nested + /// transactions. + /// + /// ## Note + /// + /// Just like outer level transactions, savepoint transactions rollback by + /// default. + /// + /// ## Example + /// + /// ```rust,no_run + /// # use rusqlite::{Connection, Result}; + /// # fn perform_queries_part_1_succeeds(_conn: &Connection) -> bool { true } + /// fn perform_queries(conn: &mut Connection) -> Result<()> { + /// let mut tx = conn.transaction()?; + /// + /// { + /// let sp = tx.savepoint()?; + /// if perform_queries_part_1_succeeds(&sp) { + /// sp.commit()?; + /// } + /// // otherwise, sp will rollback + /// } + /// + /// tx.commit() + /// } + /// ``` + #[inline] + pub fn savepoint(&mut self) -> Result> { + Savepoint::with_depth(self.conn, 1) + } + + /// Create a new savepoint with a custom savepoint name. See `savepoint()`.
+ #[inline] + pub fn savepoint_with_name>(&mut self, name: T) -> Result> { + Savepoint::with_depth_and_name(self.conn, 1, name) + } + + /// Get the current setting for what happens to the transaction when it is + /// dropped. + #[inline] + #[must_use] + pub fn drop_behavior(&self) -> DropBehavior { + self.drop_behavior + } + + /// Configure the transaction to perform the specified action when it is + /// dropped. + #[inline] + pub fn set_drop_behavior(&mut self, drop_behavior: DropBehavior) { + self.drop_behavior = drop_behavior; + } + + /// A convenience method which consumes and commits a transaction. + #[inline] + pub fn commit(mut self) -> Result<()> { + self.commit_() + } + + #[inline] + fn commit_(&mut self) -> Result<()> { + self.conn.execute_batch("COMMIT")?; + Ok(()) + } + + /// A convenience method which consumes and rolls back a transaction. + #[inline] + pub fn rollback(mut self) -> Result<()> { + self.rollback_() + } + + #[inline] + fn rollback_(&mut self) -> Result<()> { + self.conn.execute_batch("ROLLBACK")?; + Ok(()) + } + + /// Consumes the transaction, committing or rolling back according to the + /// current setting (see `drop_behavior`). + /// + /// Functionally equivalent to the `Drop` implementation, but allows + /// callers to see any errors that occur.
+ #[inline] + pub fn finish(mut self) -> Result<()> { + self.finish_() + } + + #[inline] + fn finish_(&mut self) -> Result<()> { + if self.conn.is_autocommit() { + return Ok(()); + } + match self.drop_behavior() { + DropBehavior::Commit => self.commit_().or_else(|_| self.rollback_()), + DropBehavior::Rollback => self.rollback_(), + DropBehavior::Ignore => Ok(()), + DropBehavior::Panic => panic!("Transaction dropped unexpectedly."), + } + } +} + +impl Deref for Transaction<'_> { + type Target = Connection; + + #[inline] + fn deref(&self) -> &Connection { + self.conn + } +} + +#[allow(unused_must_use)] +impl Drop for Transaction<'_> { + #[inline] + fn drop(&mut self) { + self.finish_(); + } +} + +impl Savepoint<'_> { + #[inline] + fn with_depth_and_name>( + conn: &Connection, + depth: u32, + name: T, + ) -> Result> { + let name = name.into(); + conn.execute_batch(&format!("SAVEPOINT {}", name)) + .map(|_| Savepoint { + conn, + name, + depth, + drop_behavior: DropBehavior::Rollback, + committed: false, + }) + } + + #[inline] + fn with_depth(conn: &Connection, depth: u32) -> Result> { + let name = format!("_rusqlite_sp_{}", depth); + Savepoint::with_depth_and_name(conn, depth, name) + } + + /// Begin a new savepoint. Can be nested. + #[inline] + pub fn new(conn: &mut Connection) -> Result> { + Savepoint::with_depth(conn, 0) + } + + /// Begin a new savepoint with a user-provided savepoint name. + #[inline] + pub fn with_name>(conn: &mut Connection, name: T) -> Result> { + Savepoint::with_depth_and_name(conn, 0, name) + } + + /// Begin a nested savepoint. + #[inline] + pub fn savepoint(&mut self) -> Result> { + Savepoint::with_depth(self.conn, self.depth + 1) + } + + /// Begin a nested savepoint with a user-provided savepoint name. + #[inline] + pub fn savepoint_with_name>(&mut self, name: T) -> Result> { + Savepoint::with_depth_and_name(self.conn, self.depth + 1, name) + } + + /// Get the current setting for what happens to the savepoint when it is + /// dropped.
+ #[inline] + #[must_use] + pub fn drop_behavior(&self) -> DropBehavior { + self.drop_behavior + } + + /// Configure the savepoint to perform the specified action when it is + /// dropped. + #[inline] + pub fn set_drop_behavior(&mut self, drop_behavior: DropBehavior) { + self.drop_behavior = drop_behavior; + } + + /// A convenience method which consumes and commits a savepoint. + #[inline] + pub fn commit(mut self) -> Result<()> { + self.commit_() + } + + #[inline] + fn commit_(&mut self) -> Result<()> { + self.conn.execute_batch(&format!("RELEASE {}", self.name))?; + self.committed = true; + Ok(()) + } + + /// A convenience method which rolls back a savepoint. + /// + /// ## Note + /// + /// Unlike `Transaction`s, savepoints remain active after they have been + /// rolled back, and can be rolled back again or committed. + #[inline] + pub fn rollback(&mut self) -> Result<()> { + self.conn + .execute_batch(&format!("ROLLBACK TO {}", self.name)) + } + + /// Consumes the savepoint, committing or rolling back according to the + /// current setting (see `drop_behavior`); a rollback also releases the savepoint. + /// + /// Functionally equivalent to the `Drop` implementation, but allows + /// callers to see any errors that occur.
+ #[inline] + pub fn finish(mut self) -> Result<()> { + self.finish_() + } + + #[inline] + fn finish_(&mut self) -> Result<()> { + if self.committed { + return Ok(()); + } + match self.drop_behavior() { // ROLLBACK TO leaves the savepoint active, so RELEASE it afterwards via `commit_` + DropBehavior::Commit => self.commit_().or_else(|_| self.rollback().and_then(|_| self.commit_())), + DropBehavior::Rollback => self.rollback().and_then(|_| self.commit_()), + DropBehavior::Ignore => Ok(()), + DropBehavior::Panic => panic!("Savepoint dropped unexpectedly."), + } + } +} + +impl Deref for Savepoint<'_> { + type Target = Connection; + + #[inline] + fn deref(&self) -> &Connection { + self.conn + } +} + +#[allow(unused_must_use)] +impl Drop for Savepoint<'_> { + #[inline] + fn drop(&mut self) { + self.finish_(); + } +} + +/// Transaction state of a database +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +#[non_exhaustive] +#[cfg(feature = "modern_sqlite")] // 3.37.0 +#[cfg_attr(docsrs, doc(cfg(feature = "modern_sqlite")))] +pub enum TransactionState { + /// Equivalent to SQLITE_TXN_NONE + None, + /// Equivalent to SQLITE_TXN_READ + Read, + /// Equivalent to SQLITE_TXN_WRITE + Write, +} + +impl Connection { + /// Begin a new transaction with the default behavior (DEFERRED). + /// + /// The transaction defaults to rolling back when it is dropped. If you + /// want the transaction to commit, you must call + /// [`commit`](Transaction::commit) or + /// [`set_drop_behavior(DropBehavior::Commit)`](Transaction::set_drop_behavior).
+ /// + /// ## Example + /// + /// ```rust,no_run + /// # use rusqlite::{Connection, Result}; + /// # fn do_queries_part_1(_conn: &Connection) -> Result<()> { Ok(()) } + /// # fn do_queries_part_2(_conn: &Connection) -> Result<()> { Ok(()) } + /// fn perform_queries(conn: &mut Connection) -> Result<()> { + /// let tx = conn.transaction()?; + /// + /// do_queries_part_1(&tx)?; // tx causes rollback if this fails + /// do_queries_part_2(&tx)?; // tx causes rollback if this fails + /// + /// tx.commit() + /// } + /// ``` + /// + /// # Failure + /// + /// Will return `Err` if the underlying SQLite call fails. + #[inline] + pub fn transaction(&mut self) -> Result> { + Transaction::new(self, TransactionBehavior::Deferred) + } + + /// Begin a new transaction with a specified behavior. + /// + /// See [`transaction`](Connection::transaction). + /// + /// # Failure + /// + /// Will return `Err` if the underlying SQLite call fails. + #[inline] + pub fn transaction_with_behavior( + &mut self, + behavior: TransactionBehavior, + ) -> Result> { + Transaction::new(self, behavior) + } + + /// Begin a new transaction with the default behavior (DEFERRED). + /// + /// Attempting to open a nested transaction will result in a SQLite error. + /// `Connection::transaction` prevents this at compile time by taking `&mut + /// self`, but `Connection::unchecked_transaction()` may be used to defer + /// the checking until runtime. + /// + /// See [`Connection::transaction`] and [`Transaction::new_unchecked`] + /// (which can be used if the default transaction behavior is undesirable).
+ /// + /// ## Example + /// + /// ```rust,no_run + /// # use rusqlite::{Connection, Result}; + /// # use std::rc::Rc; + /// # fn do_queries_part_1(_conn: &Connection) -> Result<()> { Ok(()) } + /// # fn do_queries_part_2(_conn: &Connection) -> Result<()> { Ok(()) } + /// fn perform_queries(conn: Rc) -> Result<()> { + /// let tx = conn.unchecked_transaction()?; + /// + /// do_queries_part_1(&tx)?; // tx causes rollback if this fails + /// do_queries_part_2(&tx)?; // tx causes rollback if this fails + /// + /// tx.commit() + /// } + /// ``` + /// + /// # Failure + /// + /// Will return `Err` if the underlying SQLite call fails. The specific + /// error returned if transactions are nested is currently unspecified. + pub fn unchecked_transaction(&self) -> Result> { + Transaction::new_unchecked(self, TransactionBehavior::Deferred) + } + + /// Begin a new savepoint with the default behavior (DEFERRED). + /// + /// The savepoint defaults to rolling back when it is dropped. If you want + /// the savepoint to commit, you must call [`commit`](Savepoint::commit) or + /// [`set_drop_behavior(DropBehavior::Commit)`](Savepoint::set_drop_behavior). + /// A savepoint rolled back on drop is also released. + /// + /// ## Example + /// + /// ```rust,no_run + /// # use rusqlite::{Connection, Result}; + /// # fn do_queries_part_1(_conn: &Connection) -> Result<()> { Ok(()) } + /// # fn do_queries_part_2(_conn: &Connection) -> Result<()> { Ok(()) } + /// fn perform_queries(conn: &mut Connection) -> Result<()> { + /// let sp = conn.savepoint()?; + /// + /// do_queries_part_1(&sp)?; // sp causes rollback if this fails + /// do_queries_part_2(&sp)?; // sp causes rollback if this fails + /// + /// sp.commit() + /// } + /// ``` + /// + /// # Failure + /// + /// Will return `Err` if the underlying SQLite call fails. + #[inline] + pub fn savepoint(&mut self) -> Result> { + Savepoint::new(self) + } + + /// Begin a new savepoint with a specified name. + /// + /// See [`savepoint`](Connection::savepoint).
+ /// + /// # Failure + /// + /// Will return `Err` if the underlying SQLite call fails. + #[inline] + pub fn savepoint_with_name>(&mut self, name: T) -> Result> { + Savepoint::with_name(self, name) + } + + /// Determine the transaction state of a database + #[cfg(feature = "modern_sqlite")] // 3.37.0 + #[cfg_attr(docsrs, doc(cfg(feature = "modern_sqlite")))] + pub fn transaction_state( + &self, + db_name: Option>, + ) -> Result { + self.db.borrow().txn_state(db_name) + } +} + +#[cfg(test)] +mod test { + use super::DropBehavior; + use crate::{Connection, Error, Result}; + + fn checked_memory_handle() -> Result { + let db = Connection::open_in_memory()?; + db.execute_batch("CREATE TABLE foo (x INTEGER)")?; + Ok(db) + } + + #[test] + fn test_drop() -> Result<()> { + let mut db = checked_memory_handle()?; + { + let tx = db.transaction()?; + tx.execute_batch("INSERT INTO foo VALUES(1)")?; + // default: rollback + } + { + let mut tx = db.transaction()?; + tx.execute_batch("INSERT INTO foo VALUES(2)")?; + tx.set_drop_behavior(DropBehavior::Commit) + } + { + let tx = db.transaction()?; + assert_eq!( + 2i32, + tx.query_row::("SELECT SUM(x) FROM foo", [], |r| r.get(0))? + ); + } + Ok(()) + } + fn assert_nested_tx_error(e: Error) { + if let Error::SqliteFailure(e, Some(m)) = &e { + assert_eq!(e.extended_code, crate::ffi::SQLITE_ERROR); + // FIXME: Not ideal...
+ assert_eq!(e.code, crate::ErrorCode::Unknown); + assert!(m.contains("transaction")); + } else { + panic!("Unexpected error type: {:?}", e); + } + } + + #[test] + fn test_unchecked_nesting() -> Result<()> { + let db = checked_memory_handle()?; + + { + let tx = db.unchecked_transaction()?; + let e = tx.unchecked_transaction().unwrap_err(); + assert_nested_tx_error(e); + // default: rollback + } + { + let tx = db.unchecked_transaction()?; + tx.execute_batch("INSERT INTO foo VALUES(1)")?; + // Ensure this doesn't interfere with ongoing transaction + let e = tx.unchecked_transaction().unwrap_err(); + assert_nested_tx_error(e); + + tx.execute_batch("INSERT INTO foo VALUES(1)")?; + tx.commit()?; + } + + assert_eq!( + 2i32, + db.query_row::("SELECT SUM(x) FROM foo", [], |r| r.get(0))? + ); + Ok(()) + } + + #[test] + fn test_explicit_rollback_commit() -> Result<()> { + let mut db = checked_memory_handle()?; + { + let mut tx = db.transaction()?; + { + let mut sp = tx.savepoint()?; + sp.execute_batch("INSERT INTO foo VALUES(1)")?; + sp.rollback()?; + sp.execute_batch("INSERT INTO foo VALUES(2)")?; + sp.commit()?; + } + tx.commit()?; + } + { + let tx = db.transaction()?; + tx.execute_batch("INSERT INTO foo VALUES(4)")?; + tx.commit()?; + } + { + let tx = db.transaction()?; + assert_eq!( + 6i32, + tx.query_row::("SELECT SUM(x) FROM foo", [], |r| r.get(0))? 
+ ); + } + Ok(()) + } + + #[test] + fn test_savepoint() -> Result<()> { + let mut db = checked_memory_handle()?; + { + let mut tx = db.transaction()?; + tx.execute_batch("INSERT INTO foo VALUES(1)")?; + assert_current_sum(1, &tx)?; + tx.set_drop_behavior(DropBehavior::Commit); + { + let mut sp1 = tx.savepoint()?; + sp1.execute_batch("INSERT INTO foo VALUES(2)")?; + assert_current_sum(3, &sp1)?; + // will rollback sp1 + { + let mut sp2 = sp1.savepoint()?; + sp2.execute_batch("INSERT INTO foo VALUES(4)")?; + assert_current_sum(7, &sp2)?; + // will rollback sp2 + { + let sp3 = sp2.savepoint()?; + sp3.execute_batch("INSERT INTO foo VALUES(8)")?; + assert_current_sum(15, &sp3)?; + sp3.commit()?; + // committed sp3, but will be erased by sp2 rollback + } + assert_current_sum(15, &sp2)?; + } + assert_current_sum(3, &sp1)?; + } + assert_current_sum(1, &tx)?; + } + assert_current_sum(1, &db)?; + Ok(()) + } + + #[test] + fn test_ignore_drop_behavior() -> Result<()> { + let mut db = checked_memory_handle()?; + + let mut tx = db.transaction()?; + { + let mut sp1 = tx.savepoint()?; + insert(1, &sp1)?; + sp1.rollback()?; + insert(2, &sp1)?; + { + let mut sp2 = sp1.savepoint()?; + sp2.set_drop_behavior(DropBehavior::Ignore); + insert(4, &sp2)?; + } + assert_current_sum(6, &sp1)?; + sp1.commit()?; + } + assert_current_sum(6, &tx)?; + Ok(()) + } + + #[test] + fn test_savepoint_names() -> Result<()> { + let mut db = checked_memory_handle()?; + + { + let mut sp1 = db.savepoint_with_name("my_sp")?; + insert(1, &sp1)?; + assert_current_sum(1, &sp1)?; + { + let mut sp2 = sp1.savepoint_with_name("my_sp")?; + sp2.set_drop_behavior(DropBehavior::Commit); + insert(2, &sp2)?; + assert_current_sum(3, &sp2)?; + sp2.rollback()?; + assert_current_sum(1, &sp2)?; + insert(4, &sp2)?; + } + assert_current_sum(5, &sp1)?; + sp1.rollback()?; + { + let mut sp2 = sp1.savepoint_with_name("my_sp")?; + sp2.set_drop_behavior(DropBehavior::Ignore); + insert(8, &sp2)?; + } + assert_current_sum(8, &sp1)?; + 
sp1.commit()?; + } + assert_current_sum(8, &db)?; + Ok(()) + } + + #[test] + fn test_rc() -> Result<()> { + use std::rc::Rc; + let mut conn = Connection::open_in_memory()?; + let rc_txn = Rc::new(conn.transaction()?); + + // This will compile only if Transaction is Debug + Rc::try_unwrap(rc_txn).unwrap(); + Ok(()) + } + + fn insert(x: i32, conn: &Connection) -> Result { + conn.execute("INSERT INTO foo VALUES(?)", [x]) + } + + fn assert_current_sum(x: i32, conn: &Connection) -> Result<()> { + let i = conn.query_row::("SELECT SUM(x) FROM foo", [], |r| r.get(0))?; + assert_eq!(x, i); + Ok(()) + } + + #[test] + #[cfg(feature = "modern_sqlite")] + fn txn_state() -> Result<()> { + use super::TransactionState; + use crate::DatabaseName; + let db = Connection::open_in_memory()?; + assert_eq!( + TransactionState::None, + db.transaction_state(Some(DatabaseName::Main))? + ); + assert_eq!(TransactionState::None, db.transaction_state(None)?); + db.execute_batch("BEGIN")?; + assert_eq!(TransactionState::None, db.transaction_state(None)?); + let _: i32 = db.pragma_query_value(None, "user_version", |row| row.get(0))?; + assert_eq!(TransactionState::Read, db.transaction_state(None)?); + db.pragma_update(None, "user_version", 1)?; + assert_eq!(TransactionState::Write, db.transaction_state(None)?); + db.execute_batch("ROLLBACK")?; + Ok(()) + } +} diff --git a/src/types/chrono.rs b/src/types/chrono.rs new file mode 100644 index 0000000..6bfc2f4 --- /dev/null +++ b/src/types/chrono.rs @@ -0,0 +1,323 @@ +//! Convert most of the [Time Strings](http://sqlite.org/lang_datefunc.html) to chrono types. 
+ +use chrono::{DateTime, FixedOffset, Local, NaiveDate, NaiveDateTime, NaiveTime, TimeZone, Utc}; + +use crate::types::{FromSql, FromSqlError, FromSqlResult, ToSql, ToSqlOutput, ValueRef}; +use crate::Result; + +/// ISO 8601 calendar date without timezone => "YYYY-MM-DD" +impl ToSql for NaiveDate { + #[inline] + fn to_sql(&self) -> Result> { + let date_str = self.format("%F").to_string(); + Ok(ToSqlOutput::from(date_str)) + } +} + +/// "YYYY-MM-DD" => ISO 8601 calendar date without timezone. +impl FromSql for NaiveDate { + #[inline] + fn column_result(value: ValueRef<'_>) -> FromSqlResult { + value + .as_str() + .and_then(|s| match NaiveDate::parse_from_str(s, "%F") { + Ok(dt) => Ok(dt), + Err(err) => Err(FromSqlError::Other(Box::new(err))), + }) + } +} + +/// ISO 8601 time without timezone => "HH:MM:SS.SSS" +impl ToSql for NaiveTime { + #[inline] + fn to_sql(&self) -> Result> { + let date_str = self.format("%T%.f").to_string(); + Ok(ToSqlOutput::from(date_str)) + } +} + +/// "HH:MM"/"HH:MM:SS"/"HH:MM:SS.SSS" => ISO 8601 time without timezone. +impl FromSql for NaiveTime { + fn column_result(value: ValueRef<'_>) -> FromSqlResult { + value.as_str().and_then(|s| { + let fmt = match s.len() { + 5 => "%H:%M", + 8 => "%T", + _ => "%T%.f", + }; + match NaiveTime::parse_from_str(s, fmt) { + Ok(dt) => Ok(dt), + Err(err) => Err(FromSqlError::Other(Box::new(err))), + } + }) + } +} + +/// ISO 8601 combined date and time without timezone => +/// "YYYY-MM-DD HH:MM:SS.SSS" +impl ToSql for NaiveDateTime { + #[inline] + fn to_sql(&self) -> Result> { + let date_str = self.format("%F %T%.f").to_string(); + Ok(ToSqlOutput::from(date_str)) + } +} + +/// "YYYY-MM-DD HH:MM:SS"/"YYYY-MM-DD HH:MM:SS.SSS" => ISO 8601 combined date +/// and time without timezone. 
("YYYY-MM-DDTHH:MM:SS"/"YYYY-MM-DDTHH:MM:SS.SSS" +/// also supported) +impl FromSql for NaiveDateTime { + fn column_result(value: ValueRef<'_>) -> FromSqlResult { + value.as_str().and_then(|s| { + let fmt = if s.len() >= 11 && s.as_bytes()[10] == b'T' { + "%FT%T%.f" + } else { + "%F %T%.f" + }; + + match NaiveDateTime::parse_from_str(s, fmt) { + Ok(dt) => Ok(dt), + Err(err) => Err(FromSqlError::Other(Box::new(err))), + } + }) + } +} + +/// UTC time => UTC RFC3339 timestamp +/// ("YYYY-MM-DD HH:MM:SS.SSS+00:00"). +impl ToSql for DateTime { + #[inline] + fn to_sql(&self) -> Result> { + let date_str = self.format("%F %T%.f%:z").to_string(); + Ok(ToSqlOutput::from(date_str)) + } +} + +/// Local time => UTC RFC3339 timestamp +/// ("YYYY-MM-DD HH:MM:SS.SSS+00:00"). +impl ToSql for DateTime { + #[inline] + fn to_sql(&self) -> Result> { + let date_str = self.with_timezone(&Utc).format("%F %T%.f%:z").to_string(); + Ok(ToSqlOutput::from(date_str)) + } +} + +/// Date and time with time zone => RFC3339 timestamp +/// ("YYYY-MM-DD HH:MM:SS.SSS[+-]HH:MM"). +impl ToSql for DateTime { + #[inline] + fn to_sql(&self) -> Result> { + let date_str = self.format("%F %T%.f%:z").to_string(); + Ok(ToSqlOutput::from(date_str)) + } +} + +/// RFC3339 ("YYYY-MM-DD HH:MM:SS.SSS[+-]HH:MM") into `DateTime`. +impl FromSql for DateTime { + fn column_result(value: ValueRef<'_>) -> FromSqlResult { + { + // Try to parse value as rfc3339 first. + let s = value.as_str()?; + + let fmt = if s.len() >= 11 && s.as_bytes()[10] == b'T' { + "%FT%T%.f%#z" + } else { + "%F %T%.f%#z" + }; + + if let Ok(dt) = DateTime::parse_from_str(s, fmt) { + return Ok(dt.with_timezone(&Utc)); + } + } + + // Couldn't parse as rfc3339 - fall back to NaiveDateTime. + NaiveDateTime::column_result(value).map(|dt| Utc.from_utc_datetime(&dt)) + } +} + +/// RFC3339 ("YYYY-MM-DD HH:MM:SS.SSS[+-]HH:MM") into `DateTime`. 
+impl FromSql for DateTime { + #[inline] + fn column_result(value: ValueRef<'_>) -> FromSqlResult { + let utc_dt = DateTime::::column_result(value)?; + Ok(utc_dt.with_timezone(&Local)) + } +} + +/// RFC3339 ("YYYY-MM-DD HH:MM:SS.SSS[+-]HH:MM") into `DateTime`. +impl FromSql for DateTime { + #[inline] + fn column_result(value: ValueRef<'_>) -> FromSqlResult { + let s = String::column_result(value)?; + Self::parse_from_rfc3339(s.as_str()) + .or_else(|_| Self::parse_from_str(s.as_str(), "%F %T%.f%:z")) + .map_err(|e| FromSqlError::Other(Box::new(e))) + } +} + +#[cfg(test)] +mod test { + use crate::{ + types::{FromSql, ValueRef}, + Connection, Result, + }; + use chrono::{ + DateTime, Duration, FixedOffset, Local, NaiveDate, NaiveDateTime, NaiveTime, TimeZone, Utc, + }; + + fn checked_memory_handle() -> Result { + let db = Connection::open_in_memory()?; + db.execute_batch("CREATE TABLE foo (t TEXT, i INTEGER, f FLOAT, b BLOB)")?; + Ok(db) + } + + #[test] + fn test_naive_date() -> Result<()> { + let db = checked_memory_handle()?; + let date = NaiveDate::from_ymd(2016, 2, 23); + db.execute("INSERT INTO foo (t) VALUES (?)", [date])?; + + let s: String = db.query_row("SELECT t FROM foo", [], |r| r.get(0))?; + assert_eq!("2016-02-23", s); + let t: NaiveDate = db.query_row("SELECT t FROM foo", [], |r| r.get(0))?; + assert_eq!(date, t); + Ok(()) + } + + #[test] + fn test_naive_time() -> Result<()> { + let db = checked_memory_handle()?; + let time = NaiveTime::from_hms(23, 56, 4); + db.execute("INSERT INTO foo (t) VALUES (?)", [time])?; + + let s: String = db.query_row("SELECT t FROM foo", [], |r| r.get(0))?; + assert_eq!("23:56:04", s); + let v: NaiveTime = db.query_row("SELECT t FROM foo", [], |r| r.get(0))?; + assert_eq!(time, v); + Ok(()) + } + + #[test] + fn test_naive_date_time() -> Result<()> { + let db = checked_memory_handle()?; + let date = NaiveDate::from_ymd(2016, 2, 23); + let time = NaiveTime::from_hms(23, 56, 4); + let dt = NaiveDateTime::new(date, time); + + 
db.execute("INSERT INTO foo (t) VALUES (?)", [dt])?; + + let s: String = db.query_row("SELECT t FROM foo", [], |r| r.get(0))?; + assert_eq!("2016-02-23 23:56:04", s); + let v: NaiveDateTime = db.query_row("SELECT t FROM foo", [], |r| r.get(0))?; + assert_eq!(dt, v); + + db.execute("UPDATE foo set b = datetime(t)", [])?; // "YYYY-MM-DD HH:MM:SS" + let hms: NaiveDateTime = db.query_row("SELECT b FROM foo", [], |r| r.get(0))?; + assert_eq!(dt, hms); + Ok(()) + } + + #[test] + fn test_date_time_utc() -> Result<()> { + let db = checked_memory_handle()?; + let date = NaiveDate::from_ymd(2016, 2, 23); + let time = NaiveTime::from_hms_milli(23, 56, 4, 789); + let dt = NaiveDateTime::new(date, time); + let utc = Utc.from_utc_datetime(&dt); + + db.execute("INSERT INTO foo (t) VALUES (?)", [utc])?; + + let s: String = db.query_row("SELECT t FROM foo", [], |r| r.get(0))?; + assert_eq!("2016-02-23 23:56:04.789+00:00", s); + + let v1: DateTime = db.query_row("SELECT t FROM foo", [], |r| r.get(0))?; + assert_eq!(utc, v1); + + let v2: DateTime = + db.query_row("SELECT '2016-02-23 23:56:04.789'", [], |r| r.get(0))?; + assert_eq!(utc, v2); + + let v3: DateTime = db.query_row("SELECT '2016-02-23 23:56:04'", [], |r| r.get(0))?; + assert_eq!(utc - Duration::milliseconds(789), v3); + + let v4: DateTime = + db.query_row("SELECT '2016-02-23 23:56:04.789+00:00'", [], |r| r.get(0))?; + assert_eq!(utc, v4); + Ok(()) + } + + #[test] + fn test_date_time_local() -> Result<()> { + let db = checked_memory_handle()?; + let date = NaiveDate::from_ymd(2016, 2, 23); + let time = NaiveTime::from_hms_milli(23, 56, 4, 789); + let dt = NaiveDateTime::new(date, time); + let local = Local.from_local_datetime(&dt).single().unwrap(); + + db.execute("INSERT INTO foo (t) VALUES (?)", [local])?; + + // Stored string should be in UTC + let s: String = db.query_row("SELECT t FROM foo", [], |r| r.get(0))?; + assert!(s.ends_with("+00:00")); + + let v: DateTime = db.query_row("SELECT t FROM foo", [], |r| r.get(0))?; 
+ assert_eq!(local, v); + Ok(()) + } + + #[test] + fn test_date_time_fixed() -> Result<()> { + let db = checked_memory_handle()?; + let time = DateTime::parse_from_rfc3339("2020-04-07T11:23:45+04:00").unwrap(); + + db.execute("INSERT INTO foo (t) VALUES (?)", [time])?; + + // Stored string should preserve timezone offset + let s: String = db.query_row("SELECT t FROM foo", [], |r| r.get(0))?; + assert!(s.ends_with("+04:00")); + + let v: DateTime = db.query_row("SELECT t FROM foo", [], |r| r.get(0))?; + assert_eq!(time.offset(), v.offset()); + assert_eq!(time, v); + Ok(()) + } + + #[test] + fn test_sqlite_functions() -> Result<()> { + let db = checked_memory_handle()?; + let result: Result = db.query_row("SELECT CURRENT_TIME", [], |r| r.get(0)); + assert!(result.is_ok()); + let result: Result = db.query_row("SELECT CURRENT_DATE", [], |r| r.get(0)); + assert!(result.is_ok()); + let result: Result = + db.query_row("SELECT CURRENT_TIMESTAMP", [], |r| r.get(0)); + assert!(result.is_ok()); + let result: Result> = + db.query_row("SELECT CURRENT_TIMESTAMP", [], |r| r.get(0)); + assert!(result.is_ok()); + Ok(()) + } + + #[test] + fn test_naive_date_time_param() -> Result<()> { + let db = checked_memory_handle()?; + let result: Result = db.query_row("SELECT 1 WHERE ? BETWEEN datetime('now', '-1 minute') AND datetime('now', '+1 minute')", [Utc::now().naive_utc()], |r| r.get(0)); + assert!(result.is_ok()); + Ok(()) + } + + #[test] + fn test_date_time_param() -> Result<()> { + let db = checked_memory_handle()?; + let result: Result = db.query_row("SELECT 1 WHERE ? 
BETWEEN datetime('now', '-1 minute') AND datetime('now', '+1 minute')", [Utc::now()], |r| r.get(0)); + assert!(result.is_ok()); + Ok(()) + } + + #[test] + fn test_lenient_parse_timezone() { + assert!(DateTime::::column_result(ValueRef::Text(b"1970-01-01T00:00:00Z")).is_ok()); + assert!(DateTime::::column_result(ValueRef::Text(b"1970-01-01T00:00:00+00")).is_ok()); + } +} diff --git a/src/types/from_sql.rs b/src/types/from_sql.rs new file mode 100644 index 0000000..b95a378 --- /dev/null +++ b/src/types/from_sql.rs @@ -0,0 +1,276 @@ +use super::{Value, ValueRef}; +use std::convert::TryInto; +use std::error::Error; +use std::fmt; + +/// Enum listing possible errors from [`FromSql`] trait. +#[derive(Debug)] +#[non_exhaustive] +pub enum FromSqlError { + /// Error when an SQLite value is requested, but the type of the result + /// cannot be converted to the requested Rust type. + InvalidType, + + /// Error when the i64 value returned by SQLite cannot be stored into the + /// requested type. + OutOfRange(i64), + + /// Error when the blob result returned by SQLite cannot be stored into the + /// requested type due to a size mismatch. + InvalidBlobSize { + /// The expected size of the blob. + expected_size: usize, + /// The actual size of the blob that was returned. + blob_size: usize, + }, + + /// An error case available for implementors of the [`FromSql`] trait. + Other(Box), +} + +impl PartialEq for FromSqlError { + fn eq(&self, other: &FromSqlError) -> bool { + match (self, other) { + (FromSqlError::InvalidType, FromSqlError::InvalidType) => true, + (FromSqlError::OutOfRange(n1), FromSqlError::OutOfRange(n2)) => n1 == n2, + ( + FromSqlError::InvalidBlobSize { + expected_size: es1, + blob_size: bs1, + }, + FromSqlError::InvalidBlobSize { + expected_size: es2, + blob_size: bs2, + }, + ) => es1 == es2 && bs1 == bs2, + (..) 
=> false, + } + } +} + +impl fmt::Display for FromSqlError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match *self { + FromSqlError::InvalidType => write!(f, "Invalid type"), + FromSqlError::OutOfRange(i) => write!(f, "Value {} out of range", i), + FromSqlError::InvalidBlobSize { + expected_size, + blob_size, + } => { + write!( + f, + "Cannot read {} byte value out of {} byte blob", + expected_size, blob_size + ) + } + FromSqlError::Other(ref err) => err.fmt(f), + } + } +} + +impl Error for FromSqlError { + fn source(&self) -> Option<&(dyn Error + 'static)> { + if let FromSqlError::Other(ref err) = self { + Some(&**err) + } else { + None + } + } +} + +/// Result type for implementors of the [`FromSql`] trait. +pub type FromSqlResult = Result; + +/// A trait for types that can be created from a SQLite value. +pub trait FromSql: Sized { + /// Converts SQLite value into Rust value. + fn column_result(value: ValueRef<'_>) -> FromSqlResult; +} + +macro_rules! from_sql_integral( + ($t:ident) => ( + impl FromSql for $t { + #[inline] + fn column_result(value: ValueRef<'_>) -> FromSqlResult { + let i = i64::column_result(value)?; + i.try_into().map_err(|_| FromSqlError::OutOfRange(i)) + } + } + ) +); + +from_sql_integral!(i8); +from_sql_integral!(i16); +from_sql_integral!(i32); +// from_sql_integral!(i64); // Not needed because the native type is i64. 
+from_sql_integral!(isize); +from_sql_integral!(u8); +from_sql_integral!(u16); +from_sql_integral!(u32); +from_sql_integral!(u64); +from_sql_integral!(usize); + +impl FromSql for i64 { + #[inline] + fn column_result(value: ValueRef<'_>) -> FromSqlResult { + value.as_i64() + } +} + +impl FromSql for f32 { + #[inline] + fn column_result(value: ValueRef<'_>) -> FromSqlResult { + match value { + ValueRef::Integer(i) => Ok(i as f32), + ValueRef::Real(f) => Ok(f as f32), + _ => Err(FromSqlError::InvalidType), + } + } +} + +impl FromSql for f64 { + #[inline] + fn column_result(value: ValueRef<'_>) -> FromSqlResult { + match value { + ValueRef::Integer(i) => Ok(i as f64), + ValueRef::Real(f) => Ok(f), + _ => Err(FromSqlError::InvalidType), + } + } +} + +impl FromSql for bool { + #[inline] + fn column_result(value: ValueRef<'_>) -> FromSqlResult { + i64::column_result(value).map(|i| i != 0) + } +} + +impl FromSql for String { + #[inline] + fn column_result(value: ValueRef<'_>) -> FromSqlResult { + value.as_str().map(ToString::to_string) + } +} + +impl FromSql for Box { + #[inline] + fn column_result(value: ValueRef<'_>) -> FromSqlResult { + value.as_str().map(Into::into) + } +} + +impl FromSql for std::rc::Rc { + #[inline] + fn column_result(value: ValueRef<'_>) -> FromSqlResult { + value.as_str().map(Into::into) + } +} + +impl FromSql for std::sync::Arc { + #[inline] + fn column_result(value: ValueRef<'_>) -> FromSqlResult { + value.as_str().map(Into::into) + } +} + +impl FromSql for Vec { + #[inline] + fn column_result(value: ValueRef<'_>) -> FromSqlResult { + value.as_blob().map(<[u8]>::to_vec) + } +} + +impl FromSql for [u8; N] { + #[inline] + fn column_result(value: ValueRef<'_>) -> FromSqlResult { + let slice = value.as_blob()?; + slice.try_into().map_err(|_| FromSqlError::InvalidBlobSize { + expected_size: N, + blob_size: slice.len(), + }) + } +} + +#[cfg(feature = "i128_blob")] +#[cfg_attr(docsrs, doc(cfg(feature = "i128_blob")))] +impl FromSql for i128 { + 
#[inline] + fn column_result(value: ValueRef<'_>) -> FromSqlResult { + let bytes = <[u8; 16]>::column_result(value)?; + Ok(i128::from_be_bytes(bytes) ^ (1_i128 << 127)) + } +} + +#[cfg(feature = "uuid")] +#[cfg_attr(docsrs, doc(cfg(feature = "uuid")))] +impl FromSql for uuid::Uuid { + #[inline] + fn column_result(value: ValueRef<'_>) -> FromSqlResult { + let bytes = <[u8; 16]>::column_result(value)?; + Ok(uuid::Uuid::from_u128(u128::from_be_bytes(bytes))) + } +} + +impl FromSql for Option { + #[inline] + fn column_result(value: ValueRef<'_>) -> FromSqlResult { + match value { + ValueRef::Null => Ok(None), + _ => FromSql::column_result(value).map(Some), + } + } +} + +impl FromSql for Value { + #[inline] + fn column_result(value: ValueRef<'_>) -> FromSqlResult { + Ok(value.into()) + } +} + +#[cfg(test)] +mod test { + use super::FromSql; + use crate::{Connection, Error, Result}; + + #[test] + fn test_integral_ranges() -> Result<()> { + let db = Connection::open_in_memory()?; + + fn check_ranges(db: &Connection, out_of_range: &[i64], in_range: &[i64]) + where + T: Into + FromSql + std::fmt::Debug, + { + for n in out_of_range { + let err = db + .query_row("SELECT ?", &[n], |r| r.get::<_, T>(0)) + .unwrap_err(); + match err { + Error::IntegralValueOutOfRange(_, value) => assert_eq!(*n, value), + _ => panic!("unexpected error: {}", err), + } + } + for n in in_range { + assert_eq!( + *n, + db.query_row("SELECT ?", &[n], |r| r.get::<_, T>(0)) + .unwrap() + .into() + ); + } + } + + check_ranges::(&db, &[-129, 128], &[-128, 0, 1, 127]); + check_ranges::(&db, &[-32769, 32768], &[-32768, -1, 0, 1, 32767]); + check_ranges::( + &db, + &[-2_147_483_649, 2_147_483_648], + &[-2_147_483_648, -1, 0, 1, 2_147_483_647], + ); + check_ranges::(&db, &[-2, -1, 256], &[0, 1, 255]); + check_ranges::(&db, &[-2, -1, 65536], &[0, 1, 65535]); + check_ranges::(&db, &[-2, -1, 4_294_967_296], &[0, 1, 4_294_967_295]); + Ok(()) + } +} diff --git a/src/types/mod.rs b/src/types/mod.rs new file mode 
100644 index 0000000..4000ae2 --- /dev/null +++ b/src/types/mod.rs @@ -0,0 +1,449 @@ +//! Traits dealing with SQLite data types. +//! +//! SQLite uses a [dynamic type system](https://www.sqlite.org/datatype3.html). Implementations of +//! the [`ToSql`] and [`FromSql`] traits are provided for the basic types that +//! SQLite provides methods for: +//! +//! * Strings (`String` and `&str`) +//! * Blobs (`Vec` and `&[u8]`) +//! * Numbers +//! +//! The number situation is a little complicated due to the fact that all +//! numbers in SQLite are stored as `INTEGER` (`i64`) or `REAL` (`f64`). +//! +//! [`ToSql`] and [`FromSql`] are implemented for all primitive number types. +//! [`FromSql`] has different behaviour depending on the SQL and Rust types, and +//! the value. +//! +//! * `INTEGER` to integer: returns an +//! [`Error::IntegralValueOutOfRange`](crate::Error::IntegralValueOutOfRange) +//! error if the value does not fit in the Rust type. +//! * `REAL` to integer: always returns an +//! [`Error::InvalidColumnType`](crate::Error::InvalidColumnType) error. +//! * `INTEGER` to float: casts using `as` operator. Never fails. +//! * `REAL` to float: casts using `as` operator. Never fails. +//! +//! [`ToSql`] always succeeds except when storing a `u64` or `usize` value that +//! cannot fit in an `INTEGER` (`i64`). Also note that SQLite ignores column +//! types, so if you store an `i64` in a column with type `REAL` it will be +//! stored as an `INTEGER`, not a `REAL`. +//! +//! If the `time` feature is enabled, implementations are +//! provided for `time::OffsetDateTime` that use the RFC 3339 date/time format, +//! `"%Y-%m-%dT%H:%M:%S.%fZ"`, to store time values as strings. These values +//! can be parsed by SQLite's builtin +//! [datetime](https://www.sqlite.org/lang_datefunc.html) functions. If you +//! want different storage for datetimes, you can use a newtype. 
+#![cfg_attr( + feature = "time", + doc = r##" +For example, to store datetimes as `i64`s counting the number of seconds since +the Unix epoch: + +``` +use rusqlite::types::{FromSql, FromSqlError, FromSqlResult, ToSql, ToSqlOutput, ValueRef}; +use rusqlite::Result; + +pub struct DateTimeSql(pub time::OffsetDateTime); + +impl FromSql for DateTimeSql { + fn column_result(value: ValueRef) -> FromSqlResult { + i64::column_result(value).and_then(|as_i64| { + time::OffsetDateTime::from_unix_timestamp(as_i64) + .map(|odt| DateTimeSql(odt)) + .map_err(|err| FromSqlError::Other(Box::new(err))) + }) + } +} + +impl ToSql for DateTimeSql { + fn to_sql(&self) -> Result { + Ok(self.0.unix_timestamp().into()) + } +} +``` + +"## +)] +//! [`ToSql`] and [`FromSql`] are also implemented for `Option` where `T` +//! implements [`ToSql`] or [`FromSql`] for the cases where you want to know if +//! a value was NULL (which gets translated to `None`). + +pub use self::from_sql::{FromSql, FromSqlError, FromSqlResult}; +pub use self::to_sql::{ToSql, ToSqlOutput}; +pub use self::value::Value; +pub use self::value_ref::ValueRef; + +use std::fmt; + +#[cfg(feature = "chrono")] +#[cfg_attr(docsrs, doc(cfg(feature = "chrono")))] +mod chrono; +mod from_sql; +#[cfg(feature = "serde_json")] +#[cfg_attr(docsrs, doc(cfg(feature = "serde_json")))] +mod serde_json; +#[cfg(feature = "time")] +#[cfg_attr(docsrs, doc(cfg(feature = "time")))] +mod time; +mod to_sql; +#[cfg(feature = "url")] +#[cfg_attr(docsrs, doc(cfg(feature = "url")))] +mod url; +mod value; +mod value_ref; + +/// Empty struct that can be used to fill in a query parameter as `NULL`. +/// +/// ## Example +/// +/// ```rust,no_run +/// # use rusqlite::{Connection, Result}; +/// # use rusqlite::types::{Null}; +/// +/// fn insert_null(conn: &Connection) -> Result { +/// conn.execute("INSERT INTO people (name) VALUES (?)", [Null]) +/// } +/// ``` +#[derive(Copy, Clone)] +pub struct Null; + +/// SQLite data types. 
+/// See [Fundamental Datatypes](https://sqlite.org/c3ref/c_blob.html). +#[derive(Clone, Debug, PartialEq, Eq)] +pub enum Type { + /// NULL + Null, + /// 64-bit signed integer + Integer, + /// 64-bit IEEE floating point number + Real, + /// String + Text, + /// BLOB + Blob, +} + +impl fmt::Display for Type { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match *self { + Type::Null => f.pad("Null"), + Type::Integer => f.pad("Integer"), + Type::Real => f.pad("Real"), + Type::Text => f.pad("Text"), + Type::Blob => f.pad("Blob"), + } + } +} + +#[cfg(test)] +mod test { + use super::Value; + use crate::{params, Connection, Error, Result, Statement}; + use std::os::raw::{c_double, c_int}; + + fn checked_memory_handle() -> Result { + let db = Connection::open_in_memory()?; + db.execute_batch("CREATE TABLE foo (b BLOB, t TEXT, i INTEGER, f FLOAT, n)")?; + Ok(db) + } + + #[test] + fn test_blob() -> Result<()> { + let db = checked_memory_handle()?; + + let v1234 = vec![1u8, 2, 3, 4]; + db.execute("INSERT INTO foo(b) VALUES (?)", &[&v1234])?; + + let v: Vec = db.query_row("SELECT b FROM foo", [], |r| r.get(0))?; + assert_eq!(v, v1234); + Ok(()) + } + + #[test] + fn test_empty_blob() -> Result<()> { + let db = checked_memory_handle()?; + + let empty = vec![]; + db.execute("INSERT INTO foo(b) VALUES (?)", &[&empty])?; + + let v: Vec = db.query_row("SELECT b FROM foo", [], |r| r.get(0))?; + assert_eq!(v, empty); + Ok(()) + } + + #[test] + fn test_str() -> Result<()> { + let db = checked_memory_handle()?; + + let s = "hello, world!"; + db.execute("INSERT INTO foo(t) VALUES (?)", &[&s])?; + + let from: String = db.query_row("SELECT t FROM foo", [], |r| r.get(0))?; + assert_eq!(from, s); + Ok(()) + } + + #[test] + fn test_string() -> Result<()> { + let db = checked_memory_handle()?; + + let s = "hello, world!"; + db.execute("INSERT INTO foo(t) VALUES (?)", [s.to_owned()])?; + + let from: String = db.query_row("SELECT t FROM foo", [], |r| r.get(0))?; + 
assert_eq!(from, s); + Ok(()) + } + + #[test] + fn test_value() -> Result<()> { + let db = checked_memory_handle()?; + + db.execute("INSERT INTO foo(i) VALUES (?)", [Value::Integer(10)])?; + + assert_eq!( + 10i64, + db.query_row::("SELECT i FROM foo", [], |r| r.get(0))? + ); + Ok(()) + } + + #[test] + fn test_option() -> Result<()> { + let db = checked_memory_handle()?; + + let s = Some("hello, world!"); + let b = Some(vec![1u8, 2, 3, 4]); + + db.execute("INSERT INTO foo(t) VALUES (?)", &[&s])?; + db.execute("INSERT INTO foo(b) VALUES (?)", &[&b])?; + + let mut stmt = db.prepare("SELECT t, b FROM foo ORDER BY ROWID ASC")?; + let mut rows = stmt.query([])?; + + { + let row1 = rows.next()?.unwrap(); + let s1: Option = row1.get_unwrap(0); + let b1: Option> = row1.get_unwrap(1); + assert_eq!(s.unwrap(), s1.unwrap()); + assert!(b1.is_none()); + } + + { + let row2 = rows.next()?.unwrap(); + let s2: Option = row2.get_unwrap(0); + let b2: Option> = row2.get_unwrap(1); + assert!(s2.is_none()); + assert_eq!(b, b2); + } + Ok(()) + } + + #[test] + #[allow(clippy::cognitive_complexity)] + fn test_mismatched_types() -> Result<()> { + fn is_invalid_column_type(err: Error) -> bool { + matches!(err, Error::InvalidColumnType(..)) + } + + let db = checked_memory_handle()?; + + db.execute( + "INSERT INTO foo(b, t, i, f) VALUES (X'0102', 'text', 1, 1.5)", + [], + )?; + + let mut stmt = db.prepare("SELECT b, t, i, f, n FROM foo")?; + let mut rows = stmt.query([])?; + + let row = rows.next()?.unwrap(); + + // check the correct types come back as expected + assert_eq!(vec![1, 2], row.get::<_, Vec>(0)?); + assert_eq!("text", row.get::<_, String>(1)?); + assert_eq!(1, row.get::<_, c_int>(2)?); + assert!((1.5 - row.get::<_, c_double>(3)?).abs() < f64::EPSILON); + assert_eq!(row.get::<_, Option>(4)?, None); + assert_eq!(row.get::<_, Option>(4)?, None); + assert_eq!(row.get::<_, Option>(4)?, None); + + // check some invalid types + + // 0 is actually a blob (Vec) + 
assert!(is_invalid_column_type(row.get::<_, c_int>(0).unwrap_err())); + assert!(is_invalid_column_type(row.get::<_, c_int>(0).unwrap_err())); + assert!(is_invalid_column_type(row.get::<_, i64>(0).err().unwrap())); + assert!(is_invalid_column_type( + row.get::<_, c_double>(0).unwrap_err() + )); + assert!(is_invalid_column_type(row.get::<_, String>(0).unwrap_err())); + #[cfg(feature = "time")] + assert!(is_invalid_column_type( + row.get::<_, time::OffsetDateTime>(0).unwrap_err() + )); + assert!(is_invalid_column_type( + row.get::<_, Option>(0).unwrap_err() + )); + + // 1 is actually a text (String) + assert!(is_invalid_column_type(row.get::<_, c_int>(1).unwrap_err())); + assert!(is_invalid_column_type(row.get::<_, i64>(1).err().unwrap())); + assert!(is_invalid_column_type( + row.get::<_, c_double>(1).unwrap_err() + )); + assert!(is_invalid_column_type( + row.get::<_, Vec>(1).unwrap_err() + )); + assert!(is_invalid_column_type( + row.get::<_, Option>(1).unwrap_err() + )); + + // 2 is actually an integer + assert!(is_invalid_column_type(row.get::<_, String>(2).unwrap_err())); + assert!(is_invalid_column_type( + row.get::<_, Vec>(2).unwrap_err() + )); + assert!(is_invalid_column_type( + row.get::<_, Option>(2).unwrap_err() + )); + + // 3 is actually a float (c_double) + assert!(is_invalid_column_type(row.get::<_, c_int>(3).unwrap_err())); + assert!(is_invalid_column_type(row.get::<_, i64>(3).err().unwrap())); + assert!(is_invalid_column_type(row.get::<_, String>(3).unwrap_err())); + assert!(is_invalid_column_type( + row.get::<_, Vec>(3).unwrap_err() + )); + assert!(is_invalid_column_type( + row.get::<_, Option>(3).unwrap_err() + )); + + // 4 is actually NULL + assert!(is_invalid_column_type(row.get::<_, c_int>(4).unwrap_err())); + assert!(is_invalid_column_type(row.get::<_, i64>(4).err().unwrap())); + assert!(is_invalid_column_type( + row.get::<_, c_double>(4).unwrap_err() + )); + assert!(is_invalid_column_type(row.get::<_, String>(4).unwrap_err())); + 
assert!(is_invalid_column_type( + row.get::<_, Vec>(4).unwrap_err() + )); + #[cfg(feature = "time")] + assert!(is_invalid_column_type( + row.get::<_, time::OffsetDateTime>(4).unwrap_err() + )); + Ok(()) + } + + #[test] + fn test_dynamic_type() -> Result<()> { + use super::Value; + let db = checked_memory_handle()?; + + db.execute( + "INSERT INTO foo(b, t, i, f) VALUES (X'0102', 'text', 1, 1.5)", + [], + )?; + + let mut stmt = db.prepare("SELECT b, t, i, f, n FROM foo")?; + let mut rows = stmt.query([])?; + + let row = rows.next()?.unwrap(); + assert_eq!(Value::Blob(vec![1, 2]), row.get::<_, Value>(0)?); + assert_eq!(Value::Text(String::from("text")), row.get::<_, Value>(1)?); + assert_eq!(Value::Integer(1), row.get::<_, Value>(2)?); + match row.get::<_, Value>(3)? { + Value::Real(val) => assert!((1.5 - val).abs() < f64::EPSILON), + x => panic!("Invalid Value {:?}", x), + } + assert_eq!(Value::Null, row.get::<_, Value>(4)?); + Ok(()) + } + + macro_rules! test_conversion { + ($db_etc:ident, $insert_value:expr, $get_type:ty,expect $expected_value:expr) => { + $db_etc.insert_statement.execute(params![$insert_value])?; + let res = $db_etc + .query_statement + .query_row([], |row| row.get::<_, $get_type>(0)); + assert_eq!(res?, $expected_value); + $db_etc.delete_statement.execute([])?; + }; + ($db_etc:ident, $insert_value:expr, $get_type:ty,expect_from_sql_error) => { + $db_etc.insert_statement.execute(params![$insert_value])?; + let res = $db_etc + .query_statement + .query_row([], |row| row.get::<_, $get_type>(0)); + res.unwrap_err(); + $db_etc.delete_statement.execute([])?; + }; + ($db_etc:ident, $insert_value:expr, $get_type:ty,expect_to_sql_error) => { + $db_etc + .insert_statement + .execute(params![$insert_value]) + .unwrap_err(); + }; + } + + #[test] + fn test_numeric_conversions() -> Result<()> { + #![allow(clippy::float_cmp)] + + // Test what happens when we store an f32 and retrieve an i32 etc. 
+ let db = Connection::open_in_memory()?; + db.execute_batch("CREATE TABLE foo (x)")?; + + // SQLite actually ignores the column types, so we just need to test + // different numeric values. + + struct DbEtc<'conn> { + insert_statement: Statement<'conn>, + query_statement: Statement<'conn>, + delete_statement: Statement<'conn>, + } + + let mut db_etc = DbEtc { + insert_statement: db.prepare("INSERT INTO foo VALUES (?1)")?, + query_statement: db.prepare("SELECT x FROM foo")?, + delete_statement: db.prepare("DELETE FROM foo")?, + }; + + // Basic non-converting test. + test_conversion!(db_etc, 0u8, u8, expect 0u8); + + // In-range integral conversions. + test_conversion!(db_etc, 100u8, i8, expect 100i8); + test_conversion!(db_etc, 200u8, u8, expect 200u8); + test_conversion!(db_etc, 100u16, i8, expect 100i8); + test_conversion!(db_etc, 200u16, u8, expect 200u8); + test_conversion!(db_etc, u32::MAX, u64, expect u32::MAX as u64); + test_conversion!(db_etc, i64::MIN, i64, expect i64::MIN); + test_conversion!(db_etc, i64::MAX, i64, expect i64::MAX); + test_conversion!(db_etc, i64::MAX, u64, expect i64::MAX as u64); + test_conversion!(db_etc, 100usize, usize, expect 100usize); + test_conversion!(db_etc, 100u64, u64, expect 100u64); + test_conversion!(db_etc, i64::MAX as u64, u64, expect i64::MAX as u64); + + // Out-of-range integral conversions. + test_conversion!(db_etc, 200u8, i8, expect_from_sql_error); + test_conversion!(db_etc, 400u16, i8, expect_from_sql_error); + test_conversion!(db_etc, 400u16, u8, expect_from_sql_error); + test_conversion!(db_etc, -1i8, u8, expect_from_sql_error); + test_conversion!(db_etc, i64::MIN, u64, expect_from_sql_error); + test_conversion!(db_etc, u64::MAX, i64, expect_to_sql_error); + test_conversion!(db_etc, u64::MAX, u64, expect_to_sql_error); + test_conversion!(db_etc, i64::MAX as u64 + 1, u64, expect_to_sql_error); + + // FromSql integer to float, always works. 
+ test_conversion!(db_etc, i64::MIN, f32, expect i64::MIN as f32); + test_conversion!(db_etc, i64::MAX, f32, expect i64::MAX as f32); + test_conversion!(db_etc, i64::MIN, f64, expect i64::MIN as f64); + test_conversion!(db_etc, i64::MAX, f64, expect i64::MAX as f64); + + // FromSql float to int conversion, never works even if the actual value + // is an integer. + test_conversion!(db_etc, 0f64, i64, expect_from_sql_error); + Ok(()) + } +} diff --git a/src/types/serde_json.rs b/src/types/serde_json.rs new file mode 100644 index 0000000..a9761bd --- /dev/null +++ b/src/types/serde_json.rs @@ -0,0 +1,53 @@ +//! [`ToSql`] and [`FromSql`] implementation for JSON `Value`. + +use serde_json::Value; + +use crate::types::{FromSql, FromSqlError, FromSqlResult, ToSql, ToSqlOutput, ValueRef}; +use crate::Result; + +/// Serialize JSON `Value` to text. +impl ToSql for Value { + #[inline] + fn to_sql(&self) -> Result> { + Ok(ToSqlOutput::from(serde_json::to_string(self).unwrap())) + } +} + +/// Deserialize text/blob to JSON `Value`. 
+impl FromSql for Value { + #[inline] + fn column_result(value: ValueRef<'_>) -> FromSqlResult { + let bytes = value.as_bytes()?; + serde_json::from_slice(bytes).map_err(|err| FromSqlError::Other(Box::new(err))) + } +} + +#[cfg(test)] +mod test { + use crate::types::ToSql; + use crate::{Connection, Result}; + + fn checked_memory_handle() -> Result { + let db = Connection::open_in_memory()?; + db.execute_batch("CREATE TABLE foo (t TEXT, b BLOB)")?; + Ok(db) + } + + #[test] + fn test_json_value() -> Result<()> { + let db = checked_memory_handle()?; + + let json = r#"{"foo": 13, "bar": "baz"}"#; + let data: serde_json::Value = serde_json::from_str(json).unwrap(); + db.execute( + "INSERT INTO foo (t, b) VALUES (?, ?)", + &[&data as &dyn ToSql, &json.as_bytes()], + )?; + + let t: serde_json::Value = db.query_row("SELECT t FROM foo", [], |r| r.get(0))?; + assert_eq!(data, t); + let b: serde_json::Value = db.query_row("SELECT b FROM foo", [], |r| r.get(0))?; + assert_eq!(data, b); + Ok(()) + } +} diff --git a/src/types/time.rs b/src/types/time.rs new file mode 100644 index 0000000..4e2811e --- /dev/null +++ b/src/types/time.rs @@ -0,0 +1,168 @@ +//! [`ToSql`] and [`FromSql`] implementation for [`time::OffsetDateTime`]. 
+use crate::types::{FromSql, FromSqlError, FromSqlResult, ToSql, ToSqlOutput, ValueRef}; +use crate::{Error, Result}; +use time::format_description::well_known::Rfc3339; +use time::format_description::FormatItem; +use time::macros::format_description; +use time::{OffsetDateTime, PrimitiveDateTime, UtcOffset}; + +const PRIMITIVE_SHORT_DATE_TIME_FORMAT: &[FormatItem<'_>] = + format_description!("[year]-[month]-[day] [hour]:[minute]:[second]"); +const PRIMITIVE_DATE_TIME_FORMAT: &[FormatItem<'_>] = + format_description!("[year]-[month]-[day] [hour]:[minute]:[second].[subsecond]"); +const PRIMITIVE_DATE_TIME_Z_FORMAT: &[FormatItem<'_>] = + format_description!("[year]-[month]-[day] [hour]:[minute]:[second].[subsecond]Z"); +const OFFSET_SHORT_DATE_TIME_FORMAT: &[FormatItem<'_>] = format_description!( + "[year]-[month]-[day] [hour]:[minute]:[second][offset_hour sign:mandatory]:[offset_minute]" +); +const OFFSET_DATE_TIME_FORMAT: &[FormatItem<'_>] = format_description!( + "[year]-[month]-[day] [hour]:[minute]:[second].[subsecond][offset_hour sign:mandatory]:[offset_minute]" +); +const LEGACY_DATE_TIME_FORMAT: &[FormatItem<'_>] = format_description!( + "[year]-[month]-[day] [hour]:[minute]:[second]:[subsecond] [offset_hour sign:mandatory]:[offset_minute]" +); + +impl ToSql for OffsetDateTime { + #[inline] + fn to_sql(&self) -> Result> { + // FIXME keep original offset + let time_string = self + .to_offset(UtcOffset::UTC) + .format(&PRIMITIVE_DATE_TIME_Z_FORMAT) + .map_err(|err| Error::ToSqlConversionFailure(err.into()))?; + Ok(ToSqlOutput::from(time_string)) + } +} + +impl FromSql for OffsetDateTime { + fn column_result(value: ValueRef<'_>) -> FromSqlResult { + value.as_str().and_then(|s| { + if s.len() > 10 && s.as_bytes()[10] == b'T' { + // YYYY-MM-DDTHH:MM:SS.SSS[+-]HH:MM + return OffsetDateTime::parse(s, &Rfc3339) + .map_err(|err| FromSqlError::Other(Box::new(err))); + } + let s = s.strip_suffix('Z').unwrap_or(s); + match s.len() { + len if len <= 19 => { + // TODO 
YYYY-MM-DDTHH:MM:SS + PrimitiveDateTime::parse(s, &PRIMITIVE_SHORT_DATE_TIME_FORMAT) + .map(PrimitiveDateTime::assume_utc) + } + _ if s.as_bytes()[19] == b':' => { + // legacy + OffsetDateTime::parse(s, &LEGACY_DATE_TIME_FORMAT) + } + _ if s.as_bytes()[19] == b'.' => OffsetDateTime::parse(s, &OFFSET_DATE_TIME_FORMAT) + .or_else(|err| { + PrimitiveDateTime::parse(s, &PRIMITIVE_DATE_TIME_FORMAT) + .map(PrimitiveDateTime::assume_utc) + .map_err(|_| err) + }), + _ => OffsetDateTime::parse(s, &OFFSET_SHORT_DATE_TIME_FORMAT), + } + .map_err(|err| FromSqlError::Other(Box::new(err))) + }) + } +} + +#[cfg(test)] +mod test { + use crate::{Connection, Result}; + use time::format_description::well_known::Rfc3339; + use time::OffsetDateTime; + + #[test] + fn test_offset_date_time() -> Result<()> { + let db = Connection::open_in_memory()?; + db.execute_batch("CREATE TABLE foo (t TEXT, i INTEGER, f FLOAT)")?; + + let mut ts_vec = vec![]; + + let make_datetime = |secs: i128, nanos: i128| { + OffsetDateTime::from_unix_timestamp_nanos(1_000_000_000 * secs + nanos).unwrap() + }; + + ts_vec.push(make_datetime(10_000, 0)); //January 1, 1970 2:46:40 AM + ts_vec.push(make_datetime(10_000, 1000)); //January 1, 1970 2:46:40 AM (and one microsecond) + ts_vec.push(make_datetime(1_500_391_124, 1_000_000)); //July 18, 2017 + ts_vec.push(make_datetime(2_000_000_000, 2_000_000)); //May 18, 2033 + ts_vec.push(make_datetime(3_000_000_000, 999_999_999)); //January 24, 2065 + ts_vec.push(make_datetime(10_000_000_000, 0)); //November 20, 2286 + + for ts in ts_vec { + db.execute("INSERT INTO foo(t) VALUES (?)", [ts])?; + + let from: OffsetDateTime = db.query_row("SELECT t FROM foo", [], |r| r.get(0))?; + + db.execute("DELETE FROM foo", [])?; + + assert_eq!(from, ts); + } + Ok(()) + } + + #[test] + fn test_string_values() -> Result<()> { + let db = Connection::open_in_memory()?; + for (s, t) in vec![ + ( + "2013-10-07 08:23:19", + Ok(OffsetDateTime::parse("2013-10-07T08:23:19Z", &Rfc3339).unwrap()), + 
), + ( + "2013-10-07 08:23:19Z", + Ok(OffsetDateTime::parse("2013-10-07T08:23:19Z", &Rfc3339).unwrap()), + ), + ( + "2013-10-07T08:23:19Z", + Ok(OffsetDateTime::parse("2013-10-07T08:23:19Z", &Rfc3339).unwrap()), + ), + ( + "2013-10-07 08:23:19.120", + Ok(OffsetDateTime::parse("2013-10-07T08:23:19.120Z", &Rfc3339).unwrap()), + ), + ( + "2013-10-07 08:23:19.120Z", + Ok(OffsetDateTime::parse("2013-10-07T08:23:19.120Z", &Rfc3339).unwrap()), + ), + ( + "2013-10-07T08:23:19.120Z", + Ok(OffsetDateTime::parse("2013-10-07T08:23:19.120Z", &Rfc3339).unwrap()), + ), + ( + "2013-10-07 04:23:19-04:00", + Ok(OffsetDateTime::parse("2013-10-07T04:23:19-04:00", &Rfc3339).unwrap()), + ), + ( + "2013-10-07 04:23:19.120-04:00", + Ok(OffsetDateTime::parse("2013-10-07T04:23:19.120-04:00", &Rfc3339).unwrap()), + ), + ( + "2013-10-07T04:23:19.120-04:00", + Ok(OffsetDateTime::parse("2013-10-07T04:23:19.120-04:00", &Rfc3339).unwrap()), + ), + ] { + let result: Result = db.query_row("SELECT ?", [s], |r| r.get(0)); + assert_eq!(result, t); + } + Ok(()) + } + + #[test] + fn test_sqlite_functions() -> Result<()> { + let db = Connection::open_in_memory()?; + let result: Result = + db.query_row("SELECT CURRENT_TIMESTAMP", [], |r| r.get(0)); + assert!(result.is_ok()); + Ok(()) + } + + #[test] + fn test_param() -> Result<()> { + let db = Connection::open_in_memory()?; + let result: Result = db.query_row("SELECT 1 WHERE ? 
BETWEEN datetime('now', '-1 minute') AND datetime('now', '+1 minute')", [OffsetDateTime::now_utc()], |r| r.get(0)); + assert!(result.is_ok()); + Ok(()) + } +} diff --git a/src/types/to_sql.rs b/src/types/to_sql.rs new file mode 100644 index 0000000..4e0d882 --- /dev/null +++ b/src/types/to_sql.rs @@ -0,0 +1,429 @@ +use super::{Null, Value, ValueRef}; +#[cfg(feature = "array")] +use crate::vtab::array::Array; +use crate::{Error, Result}; +use std::borrow::Cow; +use std::convert::TryFrom; + +/// `ToSqlOutput` represents the possible output types for implementers of the +/// [`ToSql`] trait. +#[derive(Clone, Debug, PartialEq)] +#[non_exhaustive] +pub enum ToSqlOutput<'a> { + /// A borrowed SQLite-representable value. + Borrowed(ValueRef<'a>), + + /// An owned SQLite-representable value. + Owned(Value), + + /// A BLOB of the given length that is filled with + /// zeroes. + #[cfg(feature = "blob")] + #[cfg_attr(docsrs, doc(cfg(feature = "blob")))] + ZeroBlob(i32), + + /// `feature = "array"` + #[cfg(feature = "array")] + #[cfg_attr(docsrs, doc(cfg(feature = "array")))] + Array(Array), +} + +// Generically allow any type that can be converted into a ValueRef +// to be converted into a ToSqlOutput as well. +impl<'a, T: ?Sized> From<&'a T> for ToSqlOutput<'a> +where + &'a T: Into>, +{ + #[inline] + fn from(t: &'a T) -> Self { + ToSqlOutput::Borrowed(t.into()) + } +} + +// We cannot also generically allow any type that can be converted +// into a Value to be converted into a ToSqlOutput because of +// coherence rules (https://github.com/rust-lang/rust/pull/46192), +// so we'll manually implement it for all the types we know can +// be converted into Values. +macro_rules! 
from_value( + ($t:ty) => ( + impl From<$t> for ToSqlOutput<'_> { + #[inline] + fn from(t: $t) -> Self { ToSqlOutput::Owned(t.into())} + } + ) +); +from_value!(String); +from_value!(Null); +from_value!(bool); +from_value!(i8); +from_value!(i16); +from_value!(i32); +from_value!(i64); +from_value!(isize); +from_value!(u8); +from_value!(u16); +from_value!(u32); +from_value!(f32); +from_value!(f64); +from_value!(Vec); + +// It would be nice if we could avoid the heap allocation (of the `Vec`) that +// `i128` needs in `Into`, but it's probably fine for the moment, and not +// worth adding another case to Value. +#[cfg(feature = "i128_blob")] +#[cfg_attr(docsrs, doc(cfg(feature = "i128_blob")))] +from_value!(i128); + +#[cfg(feature = "uuid")] +#[cfg_attr(docsrs, doc(cfg(feature = "uuid")))] +from_value!(uuid::Uuid); + +impl ToSql for ToSqlOutput<'_> { + #[inline] + fn to_sql(&self) -> Result> { + Ok(match *self { + ToSqlOutput::Borrowed(v) => ToSqlOutput::Borrowed(v), + ToSqlOutput::Owned(ref v) => ToSqlOutput::Borrowed(ValueRef::from(v)), + + #[cfg(feature = "blob")] + ToSqlOutput::ZeroBlob(i) => ToSqlOutput::ZeroBlob(i), + #[cfg(feature = "array")] + ToSqlOutput::Array(ref a) => ToSqlOutput::Array(a.clone()), + }) + } +} + +/// A trait for types that can be converted into SQLite values. Returns +/// [`Error::ToSqlConversionFailure`] if the conversion fails. 
+pub trait ToSql { + /// Converts Rust value to SQLite value + fn to_sql(&self) -> Result>; +} + +impl ToSql for Cow<'_, T> { + #[inline] + fn to_sql(&self) -> Result> { + self.as_ref().to_sql() + } +} + +impl ToSql for Box { + #[inline] + fn to_sql(&self) -> Result> { + self.as_ref().to_sql() + } +} + +impl ToSql for std::rc::Rc { + #[inline] + fn to_sql(&self) -> Result> { + self.as_ref().to_sql() + } +} + +impl ToSql for std::sync::Arc { + #[inline] + fn to_sql(&self) -> Result> { + self.as_ref().to_sql() + } +} + +// We should be able to use a generic impl like this: +// +// impl ToSql for T where T: Into { +// fn to_sql(&self) -> Result { +// Ok(ToSqlOutput::from((*self).into())) +// } +// } +// +// instead of the following macro, but this runs afoul of +// https://github.com/rust-lang/rust/issues/30191 and reports conflicting +// implementations even when there aren't any. + +macro_rules! to_sql_self( + ($t:ty) => ( + impl ToSql for $t { + #[inline] + fn to_sql(&self) -> Result> { + Ok(ToSqlOutput::from(*self)) + } + } + ) +); + +to_sql_self!(Null); +to_sql_self!(bool); +to_sql_self!(i8); +to_sql_self!(i16); +to_sql_self!(i32); +to_sql_self!(i64); +to_sql_self!(isize); +to_sql_self!(u8); +to_sql_self!(u16); +to_sql_self!(u32); +to_sql_self!(f32); +to_sql_self!(f64); + +#[cfg(feature = "i128_blob")] +#[cfg_attr(docsrs, doc(cfg(feature = "i128_blob")))] +to_sql_self!(i128); + +#[cfg(feature = "uuid")] +#[cfg_attr(docsrs, doc(cfg(feature = "uuid")))] +to_sql_self!(uuid::Uuid); + +macro_rules! to_sql_self_fallible( + ($t:ty) => ( + impl ToSql for $t { + #[inline] + fn to_sql(&self) -> Result> { + Ok(ToSqlOutput::Owned(Value::Integer( + i64::try_from(*self).map_err( + // TODO: Include the values in the error message. + |err| Error::ToSqlConversionFailure(err.into()) + )? + ))) + } + } + ) +); + +// Special implementations for usize and u64 because these conversions can fail. 
+to_sql_self_fallible!(u64); +to_sql_self_fallible!(usize); + +impl ToSql for &'_ T +where + T: ToSql, +{ + #[inline] + fn to_sql(&self) -> Result> { + (*self).to_sql() + } +} + +impl ToSql for String { + #[inline] + fn to_sql(&self) -> Result> { + Ok(ToSqlOutput::from(self.as_str())) + } +} + +impl ToSql for str { + #[inline] + fn to_sql(&self) -> Result> { + Ok(ToSqlOutput::from(self)) + } +} + +impl ToSql for Vec { + #[inline] + fn to_sql(&self) -> Result> { + Ok(ToSqlOutput::from(self.as_slice())) + } +} + +impl ToSql for [u8; N] { + #[inline] + fn to_sql(&self) -> Result> { + Ok(ToSqlOutput::from(&self[..])) + } +} + +impl ToSql for [u8] { + #[inline] + fn to_sql(&self) -> Result> { + Ok(ToSqlOutput::from(self)) + } +} + +impl ToSql for Value { + #[inline] + fn to_sql(&self) -> Result> { + Ok(ToSqlOutput::from(self)) + } +} + +impl ToSql for Option { + #[inline] + fn to_sql(&self) -> Result> { + match *self { + None => Ok(ToSqlOutput::from(Null)), + Some(ref t) => t.to_sql(), + } + } +} + +#[cfg(test)] +mod test { + use super::ToSql; + + fn is_to_sql() {} + + #[test] + fn test_integral_types() { + is_to_sql::(); + is_to_sql::(); + is_to_sql::(); + is_to_sql::(); + is_to_sql::(); + is_to_sql::(); + is_to_sql::(); + } + + #[test] + fn test_u8_array() { + let a: [u8; 99] = [0u8; 99]; + let _a: &[&dyn ToSql] = crate::params![a]; + let r = ToSql::to_sql(&a); + + assert!(r.is_ok()); + } + + #[test] + fn test_cow_str() { + use std::borrow::Cow; + let s = "str"; + let cow: Cow = Cow::Borrowed(s); + let r = cow.to_sql(); + assert!(r.is_ok()); + let cow: Cow = Cow::Owned::(String::from(s)); + let r = cow.to_sql(); + assert!(r.is_ok()); + // Ensure this compiles. 
+ let _p: &[&dyn ToSql] = crate::params![cow]; + } + + #[test] + fn test_box_dyn() { + let s: Box = Box::new("Hello world!"); + let _s: &[&dyn ToSql] = crate::params![s]; + let r = ToSql::to_sql(&s); + + assert!(r.is_ok()); + } + + #[test] + fn test_box_deref() { + let s: Box = "Hello world!".into(); + let _s: &[&dyn ToSql] = crate::params![s]; + let r = s.to_sql(); + + assert!(r.is_ok()); + } + + #[test] + fn test_box_direct() { + let s: Box = "Hello world!".into(); + let _s: &[&dyn ToSql] = crate::params![s]; + let r = ToSql::to_sql(&s); + + assert!(r.is_ok()); + } + + #[test] + fn test_cells() { + use std::{rc::Rc, sync::Arc}; + + let source_str: Box = "Hello world!".into(); + + let s: Rc> = Rc::new(source_str.clone()); + let _s: &[&dyn ToSql] = crate::params![s]; + let r = s.to_sql(); + assert!(r.is_ok()); + + let s: Arc> = Arc::new(source_str.clone()); + let _s: &[&dyn ToSql] = crate::params![s]; + let r = s.to_sql(); + assert!(r.is_ok()); + + let s: Arc = Arc::from(&*source_str); + let _s: &[&dyn ToSql] = crate::params![s]; + let r = s.to_sql(); + assert!(r.is_ok()); + + let s: Arc = Arc::new(source_str.clone()); + let _s: &[&dyn ToSql] = crate::params![s]; + let r = s.to_sql(); + assert!(r.is_ok()); + + let s: Rc = Rc::from(&*source_str); + let _s: &[&dyn ToSql] = crate::params![s]; + let r = s.to_sql(); + assert!(r.is_ok()); + + let s: Rc = Rc::new(source_str); + let _s: &[&dyn ToSql] = crate::params![s]; + let r = s.to_sql(); + assert!(r.is_ok()); + } + + #[cfg(feature = "i128_blob")] + #[test] + fn test_i128() -> crate::Result<()> { + use crate::Connection; + let db = Connection::open_in_memory()?; + db.execute_batch("CREATE TABLE foo (i128 BLOB, desc TEXT)")?; + db.execute( + " + INSERT INTO foo(i128, desc) VALUES + (?, 'zero'), + (?, 'neg one'), (?, 'neg two'), + (?, 'pos one'), (?, 'pos two'), + (?, 'min'), (?, 'max')", + [0i128, -1i128, -2i128, 1i128, 2i128, i128::MIN, i128::MAX], + )?; + + let mut stmt = db.prepare("SELECT i128, desc FROM foo ORDER 
BY i128 ASC")?; + + let res = stmt + .query_map([], |row| { + Ok((row.get::<_, i128>(0)?, row.get::<_, String>(1)?)) + })? + .collect::, _>>()?; + + assert_eq!( + res, + &[ + (i128::MIN, "min".to_owned()), + (-2, "neg two".to_owned()), + (-1, "neg one".to_owned()), + (0, "zero".to_owned()), + (1, "pos one".to_owned()), + (2, "pos two".to_owned()), + (i128::MAX, "max".to_owned()), + ] + ); + Ok(()) + } + + #[cfg(feature = "uuid")] + #[test] + fn test_uuid() -> crate::Result<()> { + use crate::{params, Connection}; + use uuid::Uuid; + + let db = Connection::open_in_memory()?; + db.execute_batch("CREATE TABLE foo (id BLOB CHECK(length(id) = 16), label TEXT);")?; + + let id = Uuid::new_v4(); + + db.execute( + "INSERT INTO foo (id, label) VALUES (?, ?)", + params![id, "target"], + )?; + + let mut stmt = db.prepare("SELECT id, label FROM foo WHERE id = ?")?; + + let mut rows = stmt.query(params![id])?; + let row = rows.next()?.unwrap(); + + let found_id: Uuid = row.get_unwrap(0); + let found_label: String = row.get_unwrap(1); + + assert_eq!(found_id, id); + assert_eq!(found_label, "target"); + Ok(()) + } +} diff --git a/src/types/url.rs b/src/types/url.rs new file mode 100644 index 0000000..fea8500 --- /dev/null +++ b/src/types/url.rs @@ -0,0 +1,82 @@ +//! [`ToSql`] and [`FromSql`] implementation for [`url::Url`]. +use crate::types::{FromSql, FromSqlError, FromSqlResult, ToSql, ToSqlOutput, ValueRef}; +use crate::Result; +use url::Url; + +/// Serialize `Url` to text. +impl ToSql for Url { + #[inline] + fn to_sql(&self) -> Result> { + Ok(ToSqlOutput::from(self.as_str())) + } +} + +/// Deserialize text to `Url`. 
+impl FromSql for Url { + #[inline] + fn column_result(value: ValueRef<'_>) -> FromSqlResult { + match value { + ValueRef::Text(s) => { + let s = std::str::from_utf8(s).map_err(|e| FromSqlError::Other(Box::new(e)))?; + Url::parse(s).map_err(|e| FromSqlError::Other(Box::new(e))) + } + _ => Err(FromSqlError::InvalidType), + } + } +} + +#[cfg(test)] +mod test { + use crate::{params, Connection, Error, Result}; + use url::{ParseError, Url}; + + fn checked_memory_handle() -> Result { + let db = Connection::open_in_memory()?; + db.execute_batch("CREATE TABLE urls (i INTEGER, v TEXT)")?; + Ok(db) + } + + fn get_url(db: &Connection, id: i64) -> Result { + db.query_row("SELECT v FROM urls WHERE i = ?", [id], |r| r.get(0)) + } + + #[test] + fn test_sql_url() -> Result<()> { + let db = &checked_memory_handle()?; + + let url0 = Url::parse("http://www.example1.com").unwrap(); + let url1 = Url::parse("http://www.example1.com/👌").unwrap(); + let url2 = "http://www.example2.com/👌"; + + db.execute( + "INSERT INTO urls (i, v) VALUES (0, ?), (1, ?), (2, ?), (3, ?)", + // also insert a non-hex encoded url (which might be present if it was + // inserted separately) + params![url0, url1, url2, "illegal"], + )?; + + assert_eq!(get_url(db, 0)?, url0); + + assert_eq!(get_url(db, 1)?, url1); + + // Should successfully read it, even though it wasn't inserted as an + // escaped url. + let out_url2: Url = get_url(db, 2)?; + assert_eq!(out_url2, Url::parse(url2).unwrap()); + + // Make sure the conversion error comes through correctly. 
+ let err = get_url(db, 3).unwrap_err(); + match err { + Error::FromSqlConversionFailure(_, _, e) => { + assert_eq!( + *e.downcast::().unwrap(), + ParseError::RelativeUrlWithoutBase, + ); + } + e => { + panic!("Expected conversion failure, got {}", e); + } + } + Ok(()) + } +} diff --git a/src/types/value.rs b/src/types/value.rs new file mode 100644 index 0000000..ca3ee9f --- /dev/null +++ b/src/types/value.rs @@ -0,0 +1,142 @@ +use super::{Null, Type}; + +/// Owning [dynamic type value](http://sqlite.org/datatype3.html). Value's type is typically +/// dictated by SQLite (not by the caller). +/// +/// See [`ValueRef`](crate::types::ValueRef) for a non-owning dynamic type +/// value. +#[derive(Clone, Debug, PartialEq)] +pub enum Value { + /// The value is a `NULL` value. + Null, + /// The value is a signed integer. + Integer(i64), + /// The value is a floating point number. + Real(f64), + /// The value is a text string. + Text(String), + /// The value is a blob of data + Blob(Vec), +} + +impl From for Value { + #[inline] + fn from(_: Null) -> Value { + Value::Null + } +} + +impl From for Value { + #[inline] + fn from(i: bool) -> Value { + Value::Integer(i as i64) + } +} + +impl From for Value { + #[inline] + fn from(i: isize) -> Value { + Value::Integer(i as i64) + } +} + +#[cfg(feature = "i128_blob")] +#[cfg_attr(docsrs, doc(cfg(feature = "i128_blob")))] +impl From for Value { + #[inline] + fn from(i: i128) -> Value { + // We store these biased (e.g. with the most significant bit flipped) + // so that comparisons with negative numbers work properly. + Value::Blob(i128::to_be_bytes(i ^ (1_i128 << 127)).to_vec()) + } +} + +#[cfg(feature = "uuid")] +#[cfg_attr(docsrs, doc(cfg(feature = "uuid")))] +impl From for Value { + #[inline] + fn from(id: uuid::Uuid) -> Value { + Value::Blob(id.as_bytes().to_vec()) + } +} + +macro_rules! 
from_i64( + ($t:ty) => ( + impl From<$t> for Value { + #[inline] + fn from(i: $t) -> Value { + Value::Integer(i64::from(i)) + } + } + ) +); + +from_i64!(i8); +from_i64!(i16); +from_i64!(i32); +from_i64!(u8); +from_i64!(u16); +from_i64!(u32); + +impl From for Value { + #[inline] + fn from(i: i64) -> Value { + Value::Integer(i) + } +} + +impl From for Value { + #[inline] + fn from(f: f32) -> Value { + Value::Real(f.into()) + } +} + +impl From for Value { + #[inline] + fn from(f: f64) -> Value { + Value::Real(f) + } +} + +impl From for Value { + #[inline] + fn from(s: String) -> Value { + Value::Text(s) + } +} + +impl From> for Value { + #[inline] + fn from(v: Vec) -> Value { + Value::Blob(v) + } +} + +impl From> for Value +where + T: Into, +{ + #[inline] + fn from(v: Option) -> Value { + match v { + Some(x) => x.into(), + None => Value::Null, + } + } +} + +impl Value { + /// Returns SQLite fundamental datatype. + #[inline] + #[must_use] + pub fn data_type(&self) -> Type { + match *self { + Value::Null => Type::Null, + Value::Integer(_) => Type::Integer, + Value::Real(_) => Type::Real, + Value::Text(_) => Type::Text, + Value::Blob(_) => Type::Blob, + } + } +} diff --git a/src/types/value_ref.rs b/src/types/value_ref.rs new file mode 100644 index 0000000..12806f8 --- /dev/null +++ b/src/types/value_ref.rs @@ -0,0 +1,263 @@ +use super::{Type, Value}; +use crate::types::{FromSqlError, FromSqlResult}; + +/// A non-owning [dynamic type value](http://sqlite.org/datatype3.html). Typically the +/// memory backing this value is owned by SQLite. +/// +/// See [`Value`](Value) for an owning dynamic type value. +#[derive(Copy, Clone, Debug, PartialEq)] +pub enum ValueRef<'a> { + /// The value is a `NULL` value. + Null, + /// The value is a signed integer. + Integer(i64), + /// The value is a floating point number. + Real(f64), + /// The value is a text string. 
+ Text(&'a [u8]), + /// The value is a blob of data + Blob(&'a [u8]), +} + +impl ValueRef<'_> { + /// Returns SQLite fundamental datatype. + #[inline] + #[must_use] + pub fn data_type(&self) -> Type { + match *self { + ValueRef::Null => Type::Null, + ValueRef::Integer(_) => Type::Integer, + ValueRef::Real(_) => Type::Real, + ValueRef::Text(_) => Type::Text, + ValueRef::Blob(_) => Type::Blob, + } + } +} + +impl<'a> ValueRef<'a> { + /// If `self` is case `Integer`, returns the integral value. Otherwise, + /// returns [`Err(Error::InvalidColumnType)`](crate::Error:: + /// InvalidColumnType). + #[inline] + pub fn as_i64(&self) -> FromSqlResult { + match *self { + ValueRef::Integer(i) => Ok(i), + _ => Err(FromSqlError::InvalidType), + } + } + + /// If `self` is case `Null` returns None. + /// If `self` is case `Integer`, returns the integral value. + /// Otherwise returns [`Err(Error::InvalidColumnType)`](crate::Error:: + /// InvalidColumnType). + #[inline] + pub fn as_i64_or_null(&self) -> FromSqlResult> { + match *self { + ValueRef::Null => Ok(None), + ValueRef::Integer(i) => Ok(Some(i)), + _ => Err(FromSqlError::InvalidType), + } + } + + /// If `self` is case `Real`, returns the floating point value. Otherwise, + /// returns [`Err(Error::InvalidColumnType)`](crate::Error:: + /// InvalidColumnType). + #[inline] + pub fn as_f64(&self) -> FromSqlResult { + match *self { + ValueRef::Real(f) => Ok(f), + _ => Err(FromSqlError::InvalidType), + } + } + + /// If `self` is case `Null` returns None. + /// If `self` is case `Real`, returns the floating point value. + /// Otherwise returns [`Err(Error::InvalidColumnType)`](crate::Error:: + /// InvalidColumnType). + #[inline] + pub fn as_f64_or_null(&self) -> FromSqlResult> { + match *self { + ValueRef::Null => Ok(None), + ValueRef::Real(f) => Ok(Some(f)), + _ => Err(FromSqlError::InvalidType), + } + } + + /// If `self` is case `Text`, returns the string value. 
Otherwise, returns + /// [`Err(Error::InvalidColumnType)`](crate::Error::InvalidColumnType). + #[inline] + pub fn as_str(&self) -> FromSqlResult<&'a str> { + match *self { + ValueRef::Text(t) => { + std::str::from_utf8(t).map_err(|e| FromSqlError::Other(Box::new(e))) + } + _ => Err(FromSqlError::InvalidType), + } + } + + /// If `self` is case `Null` returns None. + /// If `self` is case `Text`, returns the string value. + /// Otherwise returns [`Err(Error::InvalidColumnType)`](crate::Error:: + /// InvalidColumnType). + #[inline] + pub fn as_str_or_null(&self) -> FromSqlResult> { + match *self { + ValueRef::Null => Ok(None), + ValueRef::Text(t) => std::str::from_utf8(t) + .map_err(|e| FromSqlError::Other(Box::new(e))) + .map(Some), + _ => Err(FromSqlError::InvalidType), + } + } + + /// If `self` is case `Blob`, returns the byte slice. Otherwise, returns + /// [`Err(Error::InvalidColumnType)`](crate::Error::InvalidColumnType). + #[inline] + pub fn as_blob(&self) -> FromSqlResult<&'a [u8]> { + match *self { + ValueRef::Blob(b) => Ok(b), + _ => Err(FromSqlError::InvalidType), + } + } + + /// If `self` is case `Null` returns None. + /// If `self` is case `Blob`, returns the byte slice. + /// Otherwise returns [`Err(Error::InvalidColumnType)`](crate::Error:: + /// InvalidColumnType). + #[inline] + pub fn as_blob_or_null(&self) -> FromSqlResult> { + match *self { + ValueRef::Null => Ok(None), + ValueRef::Blob(b) => Ok(Some(b)), + _ => Err(FromSqlError::InvalidType), + } + } + + /// Returns the byte slice that makes up this ValueRef if it's either + /// [`ValueRef::Blob`] or [`ValueRef::Text`]. + #[inline] + pub fn as_bytes(&self) -> FromSqlResult<&'a [u8]> { + match self { + ValueRef::Text(s) | ValueRef::Blob(s) => Ok(s), + _ => Err(FromSqlError::InvalidType), + } + } + + /// If `self` is case `Null` returns None. 
+ /// If `self` is [`ValueRef::Blob`] or [`ValueRef::Text`] returns the byte + /// slice that makes up this value + #[inline] + pub fn as_bytes_or_null(&self) -> FromSqlResult> { + match *self { + ValueRef::Null => Ok(None), + ValueRef::Text(s) | ValueRef::Blob(s) => Ok(Some(s)), + _ => Err(FromSqlError::InvalidType), + } + } +} + +impl From> for Value { + #[inline] + fn from(borrowed: ValueRef<'_>) -> Value { + match borrowed { + ValueRef::Null => Value::Null, + ValueRef::Integer(i) => Value::Integer(i), + ValueRef::Real(r) => Value::Real(r), + ValueRef::Text(s) => { + let s = std::str::from_utf8(s).expect("invalid UTF-8"); + Value::Text(s.to_string()) + } + ValueRef::Blob(b) => Value::Blob(b.to_vec()), + } + } +} + +impl<'a> From<&'a str> for ValueRef<'a> { + #[inline] + fn from(s: &str) -> ValueRef<'_> { + ValueRef::Text(s.as_bytes()) + } +} + +impl<'a> From<&'a [u8]> for ValueRef<'a> { + #[inline] + fn from(s: &[u8]) -> ValueRef<'_> { + ValueRef::Blob(s) + } +} + +impl<'a> From<&'a Value> for ValueRef<'a> { + #[inline] + fn from(value: &'a Value) -> ValueRef<'a> { + match *value { + Value::Null => ValueRef::Null, + Value::Integer(i) => ValueRef::Integer(i), + Value::Real(r) => ValueRef::Real(r), + Value::Text(ref s) => ValueRef::Text(s.as_bytes()), + Value::Blob(ref b) => ValueRef::Blob(b), + } + } +} + +impl<'a, T> From> for ValueRef<'a> +where + T: Into>, +{ + #[inline] + fn from(s: Option) -> ValueRef<'a> { + match s { + Some(x) => x.into(), + None => ValueRef::Null, + } + } +} + +#[cfg(any(feature = "functions", feature = "session", feature = "vtab"))] +impl<'a> ValueRef<'a> { + pub(crate) unsafe fn from_value(value: *mut crate::ffi::sqlite3_value) -> ValueRef<'a> { + use crate::ffi; + use std::slice::from_raw_parts; + + match ffi::sqlite3_value_type(value) { + ffi::SQLITE_NULL => ValueRef::Null, + ffi::SQLITE_INTEGER => ValueRef::Integer(ffi::sqlite3_value_int64(value)), + ffi::SQLITE_FLOAT => ValueRef::Real(ffi::sqlite3_value_double(value)), + 
ffi::SQLITE_TEXT => { + let text = ffi::sqlite3_value_text(value); + let len = ffi::sqlite3_value_bytes(value); + assert!( + !text.is_null(), + "unexpected SQLITE_TEXT value type with NULL data" + ); + let s = from_raw_parts(text.cast::(), len as usize); + ValueRef::Text(s) + } + ffi::SQLITE_BLOB => { + let (blob, len) = ( + ffi::sqlite3_value_blob(value), + ffi::sqlite3_value_bytes(value), + ); + + assert!( + len >= 0, + "unexpected negative return from sqlite3_value_bytes" + ); + if len > 0 { + assert!( + !blob.is_null(), + "unexpected SQLITE_BLOB value type with NULL data" + ); + ValueRef::Blob(from_raw_parts(blob.cast::(), len as usize)) + } else { + // The return value from sqlite3_value_blob() for a zero-length BLOB + // is a NULL pointer. + ValueRef::Blob(&[]) + } + } + _ => unreachable!("sqlite3_value_type returned invalid value"), + } + } + + // TODO sqlite3_value_nochange // 3.22.0 & VTab xUpdate + // TODO sqlite3_value_frombind // 3.28.0 +} diff --git a/src/unlock_notify.rs b/src/unlock_notify.rs new file mode 100644 index 0000000..8fba6b3 --- /dev/null +++ b/src/unlock_notify.rs @@ -0,0 +1,117 @@ +//! 
[Unlock Notification](http://sqlite.org/unlock_notify.html) + +use std::os::raw::c_int; +use std::os::raw::c_void; +use std::panic::catch_unwind; +use std::sync::{Condvar, Mutex}; + +use crate::ffi; + +struct UnlockNotification { + cond: Condvar, // Condition variable to wait on + mutex: Mutex, // Mutex to protect structure +} + +#[allow(clippy::mutex_atomic)] +impl UnlockNotification { + fn new() -> UnlockNotification { + UnlockNotification { + cond: Condvar::new(), + mutex: Mutex::new(false), + } + } + + fn fired(&self) { + let mut flag = unpoison(self.mutex.lock()); + *flag = true; + self.cond.notify_one(); + } + + fn wait(&self) { + let mut fired = unpoison(self.mutex.lock()); + while !*fired { + fired = unpoison(self.cond.wait(fired)); + } + } +} + +#[inline] +fn unpoison(r: Result>) -> T { + r.unwrap_or_else(std::sync::PoisonError::into_inner) +} + +/// This function is an unlock-notify callback +unsafe extern "C" fn unlock_notify_cb(ap_arg: *mut *mut c_void, n_arg: c_int) { + use std::slice::from_raw_parts; + let args = from_raw_parts(ap_arg as *const &UnlockNotification, n_arg as usize); + for un in args { + drop(catch_unwind(std::panic::AssertUnwindSafe(|| un.fired()))); + } +} + +pub unsafe fn is_locked(db: *mut ffi::sqlite3, rc: c_int) -> bool { + rc == ffi::SQLITE_LOCKED_SHAREDCACHE + || (rc & 0xFF) == ffi::SQLITE_LOCKED + && ffi::sqlite3_extended_errcode(db) == ffi::SQLITE_LOCKED_SHAREDCACHE +} + +/// This function assumes that an SQLite API call (either `sqlite3_prepare_v2()` +/// or `sqlite3_step()`) has just returned `SQLITE_LOCKED`. The argument is the +/// associated database connection. +/// +/// This function calls `sqlite3_unlock_notify()` to register for an +/// unlock-notify callback, then blocks until that callback is delivered +/// and returns `SQLITE_OK`. The caller should then retry the failed operation. 
+/// +/// Or, if `sqlite3_unlock_notify()` indicates that to block would deadlock +/// the system, then this function returns `SQLITE_LOCKED` immediately. In +/// this case the caller should not retry the operation and should roll +/// back the current transaction (if any). +#[cfg(feature = "unlock_notify")] +pub unsafe fn wait_for_unlock_notify(db: *mut ffi::sqlite3) -> c_int { + let un = UnlockNotification::new(); + /* Register for an unlock-notify callback. */ + let rc = ffi::sqlite3_unlock_notify( + db, + Some(unlock_notify_cb), + &un as *const UnlockNotification as *mut c_void, + ); + debug_assert!( + rc == ffi::SQLITE_LOCKED || rc == ffi::SQLITE_LOCKED_SHAREDCACHE || rc == ffi::SQLITE_OK + ); + if rc == ffi::SQLITE_OK { + un.wait(); + } + rc +} + +#[cfg(test)] +mod test { + use crate::{Connection, OpenFlags, Result, Transaction, TransactionBehavior}; + use std::sync::mpsc::sync_channel; + use std::thread; + use std::time; + + #[test] + fn test_unlock_notify() -> Result<()> { + let url = "file::memory:?cache=shared"; + let flags = OpenFlags::SQLITE_OPEN_READ_WRITE | OpenFlags::SQLITE_OPEN_URI; + let db1 = Connection::open_with_flags(url, flags)?; + db1.execute_batch("CREATE TABLE foo (x)")?; + let (rx, tx) = sync_channel(0); + let child = thread::spawn(move || { + let mut db2 = Connection::open_with_flags(url, flags).unwrap(); + let tx2 = Transaction::new(&mut db2, TransactionBehavior::Immediate).unwrap(); + tx2.execute_batch("INSERT INTO foo VALUES (42)").unwrap(); + rx.send(1).unwrap(); + let ten_millis = time::Duration::from_millis(10); + thread::sleep(ten_millis); + tx2.commit().unwrap(); + }); + assert_eq!(tx.recv().unwrap(), 1); + let the_answer: Result = db1.query_row("SELECT x FROM foo", [], |r| r.get(0)); + assert_eq!(42i64, the_answer?); + child.join().unwrap(); + Ok(()) + } +} diff --git a/src/util/mod.rs b/src/util/mod.rs new file mode 100644 index 0000000..2b8dcfd --- /dev/null +++ b/src/util/mod.rs @@ -0,0 +1,11 @@ +// Internal utilities 
+pub(crate) mod param_cache; +mod small_cstr; +pub(crate) use param_cache::ParamIndexCache; +pub(crate) use small_cstr::SmallCString; + +// Doesn't use any modern features or vtab stuff, but is only used by them. +#[cfg(any(feature = "modern_sqlite", feature = "vtab"))] +mod sqlite_string; +#[cfg(any(feature = "modern_sqlite", feature = "vtab"))] +pub(crate) use sqlite_string::SqliteMallocString; diff --git a/src/util/param_cache.rs b/src/util/param_cache.rs new file mode 100644 index 0000000..6faced9 --- /dev/null +++ b/src/util/param_cache.rs @@ -0,0 +1,60 @@ +use super::SmallCString; +use std::cell::RefCell; +use std::collections::BTreeMap; + +/// Maps parameter names to parameter indices. +#[derive(Default, Clone, Debug)] +// BTreeMap seems to do better here unless we want to pull in a custom hash +// function. +pub(crate) struct ParamIndexCache(RefCell>); + +impl ParamIndexCache { + pub fn get_or_insert_with(&self, s: &str, func: F) -> Option + where + F: FnOnce(&std::ffi::CStr) -> Option, + { + let mut cache = self.0.borrow_mut(); + // Avoid entry API, needs allocation to test membership. + if let Some(v) = cache.get(s) { + return Some(*v); + } + // If there's an internal nul in the name it couldn't have been a + // parameter, so early return here is ok. 
+ let name = SmallCString::new(s).ok()?; + let val = func(&name)?; + cache.insert(name, val); + Some(val) + } +} + +#[cfg(test)] +mod test { + use super::*; + #[test] + fn test_cache() { + let p = ParamIndexCache::default(); + let v = p.get_or_insert_with("foo", |cstr| { + assert_eq!(cstr.to_str().unwrap(), "foo"); + Some(3) + }); + assert_eq!(v, Some(3)); + let v = p.get_or_insert_with("foo", |_| { + panic!("shouldn't be called this time"); + }); + assert_eq!(v, Some(3)); + let v = p.get_or_insert_with("gar\0bage", |_| { + panic!("shouldn't be called here either"); + }); + assert_eq!(v, None); + let v = p.get_or_insert_with("bar", |cstr| { + assert_eq!(cstr.to_str().unwrap(), "bar"); + None + }); + assert_eq!(v, None); + let v = p.get_or_insert_with("bar", |cstr| { + assert_eq!(cstr.to_str().unwrap(), "bar"); + Some(30) + }); + assert_eq!(v, Some(30)); + } +} diff --git a/src/util/small_cstr.rs b/src/util/small_cstr.rs new file mode 100644 index 0000000..78e43bd --- /dev/null +++ b/src/util/small_cstr.rs @@ -0,0 +1,170 @@ +use smallvec::{smallvec, SmallVec}; +use std::ffi::{CStr, CString, NulError}; + +/// Similar to `std::ffi::CString`, but avoids heap allocating if the string is +/// small enough. Also guarantees it's input is UTF-8 -- used for cases where we +/// need to pass a NUL-terminated string to SQLite, and we have a `&str`. +#[derive(Clone, PartialEq, Eq, PartialOrd, Ord)] +pub(crate) struct SmallCString(SmallVec<[u8; 16]>); + +impl SmallCString { + #[inline] + pub fn new(s: &str) -> Result { + if s.as_bytes().contains(&0_u8) { + return Err(Self::fabricate_nul_error(s)); + } + let mut buf = SmallVec::with_capacity(s.len() + 1); + buf.extend_from_slice(s.as_bytes()); + buf.push(0); + let res = Self(buf); + res.debug_checks(); + Ok(res) + } + + #[inline] + pub fn as_str(&self) -> &str { + self.debug_checks(); + // Constructor takes a &str so this is safe. 
+ unsafe { std::str::from_utf8_unchecked(self.as_bytes_without_nul()) } + } + + /// Get the bytes not including the NUL terminator. E.g. the bytes which + /// make up our `str`: + /// - `SmallCString::new("foo").as_bytes_without_nul() == b"foo"` + /// - `SmallCString::new("foo").as_bytes_with_nul() == b"foo\0"` + #[inline] + pub fn as_bytes_without_nul(&self) -> &[u8] { + self.debug_checks(); + &self.0[..self.len()] + } + + /// Get the bytes behind this str *including* the NUL terminator. This + /// should never return an empty slice. + #[inline] + pub fn as_bytes_with_nul(&self) -> &[u8] { + self.debug_checks(); + &self.0 + } + + #[inline] + #[cfg(debug_assertions)] + fn debug_checks(&self) { + debug_assert_ne!(self.0.len(), 0); + debug_assert_eq!(self.0[self.0.len() - 1], 0); + let strbytes = &self.0[..(self.0.len() - 1)]; + debug_assert!(!strbytes.contains(&0)); + debug_assert!(std::str::from_utf8(strbytes).is_ok()); + } + + #[inline] + #[cfg(not(debug_assertions))] + fn debug_checks(&self) {} + + #[inline] + pub fn len(&self) -> usize { + debug_assert_ne!(self.0.len(), 0); + self.0.len() - 1 + } + + #[inline] + #[allow(unused)] // clippy wants this function. 
+ pub fn is_empty(&self) -> bool { + self.len() == 0 + } + + #[inline] + pub fn as_cstr(&self) -> &CStr { + let bytes = self.as_bytes_with_nul(); + debug_assert!(CStr::from_bytes_with_nul(bytes).is_ok()); + unsafe { CStr::from_bytes_with_nul_unchecked(bytes) } + } + + #[cold] + fn fabricate_nul_error(b: &str) -> NulError { + CString::new(b).unwrap_err() + } +} + +impl Default for SmallCString { + #[inline] + fn default() -> Self { + Self(smallvec![0]) + } +} + +impl std::fmt::Debug for SmallCString { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_tuple("SmallCString").field(&self.as_str()).finish() + } +} + +impl std::ops::Deref for SmallCString { + type Target = CStr; + + #[inline] + fn deref(&self) -> &CStr { + self.as_cstr() + } +} + +impl PartialEq for str { + #[inline] + fn eq(&self, s: &SmallCString) -> bool { + s.as_bytes_without_nul() == self.as_bytes() + } +} + +impl PartialEq for SmallCString { + #[inline] + fn eq(&self, s: &str) -> bool { + self.as_bytes_without_nul() == s.as_bytes() + } +} + +impl std::borrow::Borrow for SmallCString { + #[inline] + fn borrow(&self) -> &str { + self.as_str() + } +} + +#[cfg(test)] +mod test { + use super::*; + + #[test] + fn test_small_cstring() { + // We don't go through the normal machinery for default, so make sure + // things work. 
+ assert_eq!(SmallCString::default().0, SmallCString::new("").unwrap().0); + assert_eq!(SmallCString::new("foo").unwrap().len(), 3); + assert_eq!( + SmallCString::new("foo").unwrap().as_bytes_with_nul(), + b"foo\0" + ); + assert_eq!( + SmallCString::new("foo").unwrap().as_bytes_without_nul(), + b"foo", + ); + + assert_eq!(SmallCString::new("😀").unwrap().len(), 4); + assert_eq!( + SmallCString::new("😀").unwrap().0.as_slice(), + b"\xf0\x9f\x98\x80\0", + ); + assert_eq!( + SmallCString::new("😀").unwrap().as_bytes_without_nul(), + b"\xf0\x9f\x98\x80", + ); + + assert_eq!(SmallCString::new("").unwrap().len(), 0); + assert!(SmallCString::new("").unwrap().is_empty()); + + assert_eq!(SmallCString::new("").unwrap().0.as_slice(), b"\0"); + assert_eq!(SmallCString::new("").unwrap().as_bytes_without_nul(), b""); + + assert!(SmallCString::new("\0").is_err()); + assert!(SmallCString::new("\0abc").is_err()); + assert!(SmallCString::new("abc\0").is_err()); + } +} diff --git a/src/util/sqlite_string.rs b/src/util/sqlite_string.rs new file mode 100644 index 0000000..da261ba --- /dev/null +++ b/src/util/sqlite_string.rs @@ -0,0 +1,236 @@ +// This is used when either vtab or modern-sqlite is on. Different methods are +// used in each feature. Avoid having to track this for each function. We will +// still warn for anything that's not used by either, though. +#![cfg_attr( + not(all(feature = "vtab", feature = "modern-sqlite")), + allow(dead_code) +)] +use crate::ffi; +use std::marker::PhantomData; +use std::os::raw::{c_char, c_int}; +use std::ptr::NonNull; + +/// A string we own that's allocated on the SQLite heap. Automatically calls +/// `sqlite3_free` when dropped, unless `into_raw` (or `into_inner`) is called +/// on it. If constructed from a rust string, `sqlite3_malloc` is used. +/// +/// It has identical representation to a nonnull `*mut c_char`, so you can use +/// it transparently as one. It's nonnull, so Option can be +/// used for nullable ones (it's still just one pointer). 
+/// +/// Most strings shouldn't use this! Only places where the string needs to be +/// freed with `sqlite3_free`. This includes `sqlite3_extended_sql` results, +/// some error message pointers... Note that misuse is extremely dangerous! +/// +/// Note that this is *not* a lossless interface. Incoming strings with internal +/// NULs are modified, and outgoing strings which are non-UTF8 are modified. +/// This seems unavoidable -- it tries very hard to not panic. +#[repr(transparent)] +pub(crate) struct SqliteMallocString { + ptr: NonNull, + _boo: PhantomData>, +} +// This is owned data for a primitive type, and thus it's safe to implement +// these. That said, nothing needs them, and they make things easier to misuse. + +// unsafe impl Send for SqliteMallocString {} +// unsafe impl Sync for SqliteMallocString {} + +impl SqliteMallocString { + /// SAFETY: Caller must be certain that `m` a nul-terminated c string + /// allocated by `sqlite3_malloc`, and that SQLite expects us to free it! + #[inline] + pub(crate) unsafe fn from_raw_nonnull(ptr: NonNull) -> Self { + Self { + ptr, + _boo: PhantomData, + } + } + + /// SAFETY: Caller must be certain that `m` a nul-terminated c string + /// allocated by `sqlite3_malloc`, and that SQLite expects us to free it! + #[inline] + pub(crate) unsafe fn from_raw(ptr: *mut c_char) -> Option { + NonNull::new(ptr).map(|p| Self::from_raw_nonnull(p)) + } + + /// Get the pointer behind `self`. After this is called, we no longer manage + /// it. + #[inline] + pub(crate) fn into_inner(self) -> NonNull { + let p = self.ptr; + std::mem::forget(self); + p + } + + /// Get the pointer behind `self`. After this is called, we no longer manage + /// it. + #[inline] + pub(crate) fn into_raw(self) -> *mut c_char { + self.into_inner().as_ptr() + } + + /// Borrow the pointer behind `self`. We still manage it when this function + /// returns. If you want to relinquish ownership, use `into_raw`. 
+ #[inline] + pub(crate) fn as_ptr(&self) -> *const c_char { + self.ptr.as_ptr() + } + + #[inline] + pub(crate) fn as_cstr(&self) -> &std::ffi::CStr { + unsafe { std::ffi::CStr::from_ptr(self.as_ptr()) } + } + + #[inline] + pub(crate) fn to_string_lossy(&self) -> std::borrow::Cow<'_, str> { + self.as_cstr().to_string_lossy() + } + + /// Convert `s` into a SQLite string. + /// + /// This should almost never be done except for cases like error messages or + /// other strings that SQLite frees. + /// + /// If `s` contains internal NULs, we'll replace them with + /// `NUL_REPLACE_CHAR`. + /// + /// Except for `debug_assert`s which may trigger during testing, this + /// function never panics. If we hit integer overflow or the allocation + /// fails, we call `handle_alloc_error` which aborts the program after + /// calling a global hook. + /// + /// This means it's safe to use in extern "C" functions even outside of + /// `catch_unwind`. + pub(crate) fn from_str(s: &str) -> Self { + use std::convert::TryFrom; + let s = if s.as_bytes().contains(&0) { + std::borrow::Cow::Owned(make_nonnull(s)) + } else { + std::borrow::Cow::Borrowed(s) + }; + debug_assert!(!s.as_bytes().contains(&0)); + let bytes: &[u8] = s.as_ref().as_bytes(); + let src_ptr: *const c_char = bytes.as_ptr().cast(); + let src_len = bytes.len(); + let maybe_len_plus_1 = s.len().checked_add(1).and_then(|v| c_int::try_from(v).ok()); + unsafe { + let res_ptr = maybe_len_plus_1 + .and_then(|len_to_alloc| { + // `>` because we added 1. + debug_assert!(len_to_alloc > 0); + debug_assert_eq!((len_to_alloc - 1) as usize, src_len); + NonNull::new(ffi::sqlite3_malloc(len_to_alloc).cast::()) + }) + .unwrap_or_else(|| { + use std::alloc::{handle_alloc_error, Layout}; + // Report via handle_alloc_error so that it can be handled with any + // other allocation errors and properly diagnosed. + // + // This is safe: + // - `align` is never 0 + // - `align` is always a power of 2. 
+ // - `size` needs no realignment because it's guaranteed to be aligned + // (everything is aligned to 1) + // - `size` is also never zero, although this function doesn't actually require + // it now. + let layout = Layout::from_size_align_unchecked(s.len().saturating_add(1), 1); + // Note: This call does not return. + handle_alloc_error(layout); + }); + let buf: *mut c_char = res_ptr.as_ptr().cast::(); + src_ptr.copy_to_nonoverlapping(buf, src_len); + buf.add(src_len).write(0); + debug_assert_eq!(std::ffi::CStr::from_ptr(res_ptr.as_ptr()).to_bytes(), bytes); + Self::from_raw_nonnull(res_ptr) + } + } +} + +const NUL_REPLACE: &str = "␀"; + +#[cold] +fn make_nonnull(v: &str) -> String { + v.replace('\0', NUL_REPLACE) +} + +impl Drop for SqliteMallocString { + #[inline] + fn drop(&mut self) { + unsafe { ffi::sqlite3_free(self.ptr.as_ptr().cast()) }; + } +} + +impl std::fmt::Debug for SqliteMallocString { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + self.to_string_lossy().fmt(f) + } +} + +impl std::fmt::Display for SqliteMallocString { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + self.to_string_lossy().fmt(f) + } +} + +#[cfg(test)] +mod test { + use super::*; + #[test] + fn test_from_str() { + let to_check = [ + ("", ""), + ("\0", "␀"), + ("␀", "␀"), + ("\0bar", "␀bar"), + ("foo\0bar", "foo␀bar"), + ("foo\0", "foo␀"), + ("a\0b\0c\0\0d", "a␀b␀c␀␀d"), + ("foobar0123", "foobar0123"), + ]; + + for &(input, output) in &to_check { + let s = SqliteMallocString::from_str(input); + assert_eq!(s.to_string_lossy(), output); + assert_eq!(s.as_cstr().to_str().unwrap(), output); + } + } + + // This will trigger an asan error if into_raw still freed the ptr. 
+ #[test] + fn test_lossy() { + let p = SqliteMallocString::from_str("abcd").into_raw(); + // Make invalid + let s = unsafe { + p.cast::().write(b'\xff'); + SqliteMallocString::from_raw(p).unwrap() + }; + assert_eq!(s.to_string_lossy().as_ref(), "\u{FFFD}bcd"); + } + + // This will trigger an asan error if into_raw still freed the ptr. + #[test] + fn test_into_raw() { + let mut v = vec![]; + for i in 0..1000 { + v.push(SqliteMallocString::from_str(&i.to_string()).into_raw()); + v.push(SqliteMallocString::from_str(&format!("abc {} 😀", i)).into_raw()); + } + unsafe { + for (i, s) in v.chunks_mut(2).enumerate() { + let s0 = std::mem::replace(&mut s[0], std::ptr::null_mut()); + let s1 = std::mem::replace(&mut s[1], std::ptr::null_mut()); + assert_eq!( + std::ffi::CStr::from_ptr(s0).to_str().unwrap(), + &i.to_string() + ); + assert_eq!( + std::ffi::CStr::from_ptr(s1).to_str().unwrap(), + &format!("abc {} 😀", i) + ); + let _ = SqliteMallocString::from_raw(s0).unwrap(); + let _ = SqliteMallocString::from_raw(s1).unwrap(); + } + } + } +} diff --git a/src/version.rs b/src/version.rs new file mode 100644 index 0000000..d70af7e --- /dev/null +++ b/src/version.rs @@ -0,0 +1,23 @@ +use crate::ffi; +use std::ffi::CStr; + +/// Returns the SQLite version as an integer; e.g., `3016002` for version +/// 3.16.2. +/// +/// See [`sqlite3_libversion_number()`](https://www.sqlite.org/c3ref/libversion.html). +#[inline] +#[must_use] +pub fn version_number() -> i32 { + unsafe { ffi::sqlite3_libversion_number() } +} + +/// Returns the SQLite version as a string; e.g., `"3.16.2"` for version 3.16.2. +/// +/// See [`sqlite3_libversion()`](https://www.sqlite.org/c3ref/libversion.html). 
+#[inline] +#[must_use] +pub fn version() -> &'static str { + let cstr = unsafe { CStr::from_ptr(ffi::sqlite3_libversion()) }; + cstr.to_str() + .expect("SQLite version string is not valid UTF8 ?!") +} diff --git a/src/vtab/array.rs b/src/vtab/array.rs new file mode 100644 index 0000000..f09ac1a --- /dev/null +++ b/src/vtab/array.rs @@ -0,0 +1,223 @@ +//! Array Virtual Table. +//! +//! Note: `rarray`, not `carray` is the name of the table valued function we +//! define. +//! +//! Port of [carray](http://www.sqlite.org/cgi/src/finfo?name=ext/misc/carray.c) +//! C extension: `https://www.sqlite.org/carray.html` +//! +//! # Example +//! +//! ```rust,no_run +//! # use rusqlite::{types::Value, Connection, Result, params}; +//! # use std::rc::Rc; +//! fn example(db: &Connection) -> Result<()> { +//! // Note: This should be done once (usually when opening the DB). +//! rusqlite::vtab::array::load_module(&db)?; +//! let v = [1i64, 2, 3, 4]; +//! // Note: A `Rc>` must be used as the parameter. +//! let values = Rc::new(v.iter().copied().map(Value::from).collect::>()); +//! let mut stmt = db.prepare("SELECT value from rarray(?);")?; +//! let rows = stmt.query_map([values], |row| row.get::<_, i64>(0))?; +//! for value in rows { +//! println!("{}", value?); +//! } +//! Ok(()) +//! } +//! 
``` + +use std::default::Default; +use std::marker::PhantomData; +use std::os::raw::{c_char, c_int, c_void}; +use std::rc::Rc; + +use crate::ffi; +use crate::types::{ToSql, ToSqlOutput, Value}; +use crate::vtab::{ + eponymous_only_module, Context, IndexConstraintOp, IndexInfo, VTab, VTabConnection, VTabCursor, + Values, +}; +use crate::{Connection, Result}; + +// http://sqlite.org/bindptr.html + +pub(crate) const ARRAY_TYPE: *const c_char = (b"rarray\0" as *const u8).cast::(); + +pub(crate) unsafe extern "C" fn free_array(p: *mut c_void) { + drop(Rc::from_raw(p as *const Vec)); +} + +/// Array parameter / pointer +pub type Array = Rc>; + +impl ToSql for Array { + #[inline] + fn to_sql(&self) -> Result> { + Ok(ToSqlOutput::Array(self.clone())) + } +} + +/// Register the "rarray" module. +pub fn load_module(conn: &Connection) -> Result<()> { + let aux: Option<()> = None; + conn.create_module("rarray", eponymous_only_module::(), aux) +} + +// Column numbers +// const CARRAY_COLUMN_VALUE : c_int = 0; +const CARRAY_COLUMN_POINTER: c_int = 1; + +/// An instance of the Array virtual table +#[repr(C)] +struct ArrayTab { + /// Base class. 
Must be first + base: ffi::sqlite3_vtab, +} + +unsafe impl<'vtab> VTab<'vtab> for ArrayTab { + type Aux = (); + type Cursor = ArrayTabCursor<'vtab>; + + fn connect( + _: &mut VTabConnection, + _aux: Option<&()>, + _args: &[&[u8]], + ) -> Result<(String, ArrayTab)> { + let vtab = ArrayTab { + base: ffi::sqlite3_vtab::default(), + }; + Ok(("CREATE TABLE x(value,pointer hidden)".to_owned(), vtab)) + } + + fn best_index(&self, info: &mut IndexInfo) -> Result<()> { + // Index of the pointer= constraint + let mut ptr_idx = false; + for (constraint, mut constraint_usage) in info.constraints_and_usages() { + if !constraint.is_usable() { + continue; + } + if constraint.operator() != IndexConstraintOp::SQLITE_INDEX_CONSTRAINT_EQ { + continue; + } + if let CARRAY_COLUMN_POINTER = constraint.column() { + ptr_idx = true; + constraint_usage.set_argv_index(1); + constraint_usage.set_omit(true); + } + } + if ptr_idx { + info.set_estimated_cost(1_f64); + info.set_estimated_rows(100); + info.set_idx_num(1); + } else { + info.set_estimated_cost(2_147_483_647_f64); + info.set_estimated_rows(2_147_483_647); + info.set_idx_num(0); + } + Ok(()) + } + + fn open(&mut self) -> Result> { + Ok(ArrayTabCursor::new()) + } +} + +/// A cursor for the Array virtual table +#[repr(C)] +struct ArrayTabCursor<'vtab> { + /// Base class. 
Must be first + base: ffi::sqlite3_vtab_cursor, + /// The rowid + row_id: i64, + /// Pointer to the array of values ("pointer") + ptr: Option, + phantom: PhantomData<&'vtab ArrayTab>, +} + +impl ArrayTabCursor<'_> { + fn new<'vtab>() -> ArrayTabCursor<'vtab> { + ArrayTabCursor { + base: ffi::sqlite3_vtab_cursor::default(), + row_id: 0, + ptr: None, + phantom: PhantomData, + } + } + + fn len(&self) -> i64 { + match self.ptr { + Some(ref a) => a.len() as i64, + _ => 0, + } + } +} +unsafe impl VTabCursor for ArrayTabCursor<'_> { + fn filter(&mut self, idx_num: c_int, _idx_str: Option<&str>, args: &Values<'_>) -> Result<()> { + if idx_num > 0 { + self.ptr = args.get_array(0); + } else { + self.ptr = None; + } + self.row_id = 1; + Ok(()) + } + + fn next(&mut self) -> Result<()> { + self.row_id += 1; + Ok(()) + } + + fn eof(&self) -> bool { + self.row_id > self.len() + } + + fn column(&self, ctx: &mut Context, i: c_int) -> Result<()> { + match i { + CARRAY_COLUMN_POINTER => Ok(()), + _ => { + if let Some(ref array) = self.ptr { + let value = &array[(self.row_id - 1) as usize]; + ctx.set_result(&value) + } else { + Ok(()) + } + } + } + } + + fn rowid(&self) -> Result { + Ok(self.row_id) + } +} + +#[cfg(test)] +mod test { + use crate::types::Value; + use crate::vtab::array; + use crate::{Connection, Result}; + use std::rc::Rc; + + #[test] + fn test_array_module() -> Result<()> { + let db = Connection::open_in_memory()?; + array::load_module(&db)?; + + let v = vec![1i64, 2, 3, 4]; + let values: Vec = v.into_iter().map(Value::from).collect(); + let ptr = Rc::new(values); + { + let mut stmt = db.prepare("SELECT value from rarray(?);")?; + + let rows = stmt.query_map(&[&ptr], |row| row.get::<_, i64>(0))?; + assert_eq!(2, Rc::strong_count(&ptr)); + let mut count = 0; + for (i, value) in rows.enumerate() { + assert_eq!(i as i64, value? 
- 1); + count += 1; + } + assert_eq!(4, count); + } + assert_eq!(1, Rc::strong_count(&ptr)); + Ok(()) + } +} diff --git a/src/vtab/csvtab.rs b/src/vtab/csvtab.rs new file mode 100644 index 0000000..a65db05 --- /dev/null +++ b/src/vtab/csvtab.rs @@ -0,0 +1,396 @@ +//! CSV Virtual Table. +//! +//! Port of [csv](http://www.sqlite.org/cgi/src/finfo?name=ext/misc/csv.c) C +//! extension: `https://www.sqlite.org/csv.html` +//! +//! # Example +//! +//! ```rust,no_run +//! # use rusqlite::{Connection, Result}; +//! fn example() -> Result<()> { +//! // Note: This should be done once (usually when opening the DB). +//! let db = Connection::open_in_memory()?; +//! rusqlite::vtab::csvtab::load_module(&db)?; +//! // Assume my_csv.csv +//! let schema = " +//! CREATE VIRTUAL TABLE my_csv_data +//! USING csv(filename = 'my_csv.csv') +//! "; +//! db.execute_batch(schema)?; +//! // Now the `my_csv_data` (virtual) table can be queried as normal... +//! Ok(()) +//! } +//! ``` +use std::fs::File; +use std::marker::PhantomData; +use std::os::raw::c_int; +use std::path::Path; +use std::str; + +use crate::ffi; +use crate::types::Null; +use crate::vtab::{ + escape_double_quote, parse_boolean, read_only_module, Context, CreateVTab, IndexInfo, VTab, + VTabConfig, VTabConnection, VTabCursor, VTabKind, Values, +}; +use crate::{Connection, Error, Result}; + +/// Register the "csv" module. +/// ```sql +/// CREATE VIRTUAL TABLE vtab USING csv( +/// filename=FILENAME -- Name of file containing CSV content +/// [, schema=SCHEMA] -- Alternative CSV schema. 'CREATE TABLE x(col1 TEXT NOT NULL, col2 INT, ...);' +/// [, header=YES|NO] -- First row of CSV defines the names of columns if "yes". Default "no". +/// [, columns=N] -- Assume the CSV file contains N columns. +/// [, delimiter=C] -- CSV delimiter. Default ','. +/// [, quote=C] -- CSV quote. Default '"'. 0 means no quote. 
+/// ); +/// ``` +pub fn load_module(conn: &Connection) -> Result<()> { + let aux: Option<()> = None; + conn.create_module("csv", read_only_module::(), aux) +} + +/// An instance of the CSV virtual table +#[repr(C)] +struct CsvTab { + /// Base class. Must be first + base: ffi::sqlite3_vtab, + /// Name of the CSV file + filename: String, + has_headers: bool, + delimiter: u8, + quote: u8, + /// Offset to start of data + offset_first_row: csv::Position, +} + +impl CsvTab { + fn reader(&self) -> Result, csv::Error> { + csv::ReaderBuilder::new() + .has_headers(self.has_headers) + .delimiter(self.delimiter) + .quote(self.quote) + .from_path(&self.filename) + } + + fn parse_byte(arg: &str) -> Option { + if arg.len() == 1 { + arg.bytes().next() + } else { + None + } + } +} + +unsafe impl<'vtab> VTab<'vtab> for CsvTab { + type Aux = (); + type Cursor = CsvTabCursor<'vtab>; + + fn connect( + db: &mut VTabConnection, + _aux: Option<&()>, + args: &[&[u8]], + ) -> Result<(String, CsvTab)> { + if args.len() < 4 { + return Err(Error::ModuleError("no CSV file specified".to_owned())); + } + + let mut vtab = CsvTab { + base: ffi::sqlite3_vtab::default(), + filename: "".to_owned(), + has_headers: false, + delimiter: b',', + quote: b'"', + offset_first_row: csv::Position::new(), + }; + let mut schema = None; + let mut n_col = None; + + let args = &args[3..]; + for c_slice in args { + let (param, value) = super::parameter(c_slice)?; + match param { + "filename" => { + if !Path::new(value).exists() { + return Err(Error::ModuleError(format!( + "file '{}' does not exist", + value + ))); + } + vtab.filename = value.to_owned(); + } + "schema" => { + schema = Some(value.to_owned()); + } + "columns" => { + if let Ok(n) = value.parse::() { + if n_col.is_some() { + return Err(Error::ModuleError( + "more than one 'columns' parameter".to_owned(), + )); + } else if n == 0 { + return Err(Error::ModuleError( + "must have at least one column".to_owned(), + )); + } + n_col = Some(n); + } else { + 
return Err(Error::ModuleError(format!( + "unrecognized argument to 'columns': {}", + value + ))); + } + } + "header" => { + if let Some(b) = parse_boolean(value) { + vtab.has_headers = b; + } else { + return Err(Error::ModuleError(format!( + "unrecognized argument to 'header': {}", + value + ))); + } + } + "delimiter" => { + if let Some(b) = CsvTab::parse_byte(value) { + vtab.delimiter = b; + } else { + return Err(Error::ModuleError(format!( + "unrecognized argument to 'delimiter': {}", + value + ))); + } + } + "quote" => { + if let Some(b) = CsvTab::parse_byte(value) { + if b == b'0' { + vtab.quote = 0; + } else { + vtab.quote = b; + } + } else { + return Err(Error::ModuleError(format!( + "unrecognized argument to 'quote': {}", + value + ))); + } + } + _ => { + return Err(Error::ModuleError(format!( + "unrecognized parameter '{}'", + param + ))); + } + } + } + + if vtab.filename.is_empty() { + return Err(Error::ModuleError("no CSV file specified".to_owned())); + } + + let mut cols: Vec = Vec::new(); + if vtab.has_headers || (n_col.is_none() && schema.is_none()) { + let mut reader = vtab.reader()?; + if vtab.has_headers { + { + let headers = reader.headers()?; + // headers ignored if cols is not empty + if n_col.is_none() && schema.is_none() { + cols = headers + .into_iter() + .map(|header| escape_double_quote(header).into_owned()) + .collect(); + } + } + vtab.offset_first_row = reader.position().clone(); + } else { + let mut record = csv::ByteRecord::new(); + if reader.read_byte_record(&mut record)? 
{
+                    for (i, _) in record.iter().enumerate() {
+                        cols.push(format!("c{}", i));
+                    }
+                }
+            }
+        } else if let Some(n_col) = n_col {
+            for i in 0..n_col {
+                cols.push(format!("c{}", i));
+            }
+        }
+
+        if cols.is_empty() && schema.is_none() {
+            return Err(Error::ModuleError("no column specified".to_owned()));
+        }
+
+        if schema.is_none() {
+            let mut sql = String::from("CREATE TABLE x(");
+            for (i, col) in cols.iter().enumerate() {
+                sql.push('"');
+                sql.push_str(col);
+                sql.push_str("\" TEXT");
+                if i == cols.len() - 1 {
+                    sql.push_str(");");
+                } else {
+                    sql.push_str(", ");
+                }
+            }
+            schema = Some(sql);
+        }
+        db.config(VTabConfig::DirectOnly)?;
+        Ok((schema.unwrap(), vtab))
+    }
+
+    // Only a forward full table scan is supported.
+    fn best_index(&self, info: &mut IndexInfo) -> Result<()> {
+        info.set_estimated_cost(1_000_000.);
+        Ok(())
+    }
+
+    fn open(&mut self) -> Result<CsvTabCursor<'_>> {
+        Ok(CsvTabCursor::new(self.reader()?))
+    }
+}
+
+impl CreateVTab<'_> for CsvTab {
+    const KIND: VTabKind = VTabKind::Default;
+}
+
+/// A cursor for the CSV virtual table
+#[repr(C)]
+struct CsvTabCursor<'vtab> {
+    /// Base class. Must be first
+    base: ffi::sqlite3_vtab_cursor,
+    /// The CSV reader object
+    reader: csv::Reader<File>,
+    /// Current cursor position used as rowid
+    row_number: usize,
+    /// Values of the current row
+    cols: csv::StringRecord,
+    eof: bool,
+    phantom: PhantomData<&'vtab CsvTab>,
+}
+
+impl CsvTabCursor<'_> {
+    fn new<'vtab>(reader: csv::Reader<File>) -> CsvTabCursor<'vtab> {
+        CsvTabCursor {
+            base: ffi::sqlite3_vtab_cursor::default(),
+            reader,
+            row_number: 0,
+            cols: csv::StringRecord::new(),
+            eof: false,
+            phantom: PhantomData,
+        }
+    }
+
+    /// Accessor to the associated virtual table.
+    fn vtab(&self) -> &CsvTab {
+        unsafe { &*(self.base.pVtab as *const CsvTab) }
+    }
+}
+
+unsafe impl VTabCursor for CsvTabCursor<'_> {
+    // Only a full table scan is supported.  So `filter` simply rewinds to
+    // the beginning.
+    fn filter(
+        &mut self,
+        _idx_num: c_int,
+        _idx_str: Option<&str>,
+        _args: &Values<'_>,
+    ) -> Result<()> {
+        {
+            let offset_first_row = self.vtab().offset_first_row.clone();
+            self.reader.seek(offset_first_row)?;
+        }
+        self.row_number = 0;
+        self.next()
+    }
+
+    fn next(&mut self) -> Result<()> {
+        {
+            self.eof = self.reader.is_done();
+            if self.eof {
+                return Ok(());
+            }
+
+            self.eof = !self.reader.read_record(&mut self.cols)?;
+        }
+
+        self.row_number += 1;
+        Ok(())
+    }
+
+    fn eof(&self) -> bool {
+        self.eof
+    }
+
+    fn column(&self, ctx: &mut Context, col: c_int) -> Result<()> {
+        if self.cols.is_empty() {
+            return ctx.set_result(&Null); // checked first: an empty row fails any bounds test
+        }
+        if col < 0 || col as usize >= self.cols.len() {
+            return Err(Error::ModuleError(format!(
+                "column index out of bounds: {}",
+                col
+            )));
+        }
+        // TODO Affinity
+        ctx.set_result(&self.cols[col as usize].to_owned())
+    }
+
+    fn rowid(&self) -> Result<i64> {
+        Ok(self.row_number as i64)
+    }
+}
+
+impl From<csv::Error> for Error {
+    #[cold]
+    fn from(err: csv::Error) -> Error {
+        Error::ModuleError(err.to_string())
+    }
+}
+
+#[cfg(test)]
+mod test {
+    use crate::vtab::csvtab;
+    use crate::{Connection, Result};
+    use fallible_iterator::FallibleIterator;
+
+    #[test]
+    fn test_csv_module() -> Result<()> {
+        let db = Connection::open_in_memory()?;
+        csvtab::load_module(&db)?;
+        db.execute_batch("CREATE VIRTUAL TABLE vtab USING csv(filename='test.csv', header=yes)")?;
+
+        {
+            let mut s = db.prepare("SELECT rowid, * FROM vtab")?;
+            {
+                let headers = s.column_names();
+                assert_eq!(vec!["rowid", "colA", "colB", "colC"], headers);
+            }
+
+            let ids: Result<Vec<i32>> = s.query([])?.map(|row| row.get::<_, i32>(0)).collect();
+            let sum = ids?.iter().sum::<i32>();
+            assert_eq!(sum, 15);
+        }
+        db.execute_batch("DROP TABLE vtab")
+    }
+
+    #[test]
+    fn test_csv_cursor() -> Result<()> {
+        let db = Connection::open_in_memory()?;
+        csvtab::load_module(&db)?;
+        db.execute_batch("CREATE VIRTUAL TABLE vtab USING csv(filename='test.csv', header=yes)")?;
+
+        {
+            let mut s = db.prepare(
+                "SELECT v1.rowid, v1.* FROM vtab v1 NATURAL JOIN vtab v2 WHERE \
+                     v1.rowid < v2.rowid",
+            )?;
+
+            let mut rows = s.query([])?;
+            let row = rows.next()?.unwrap();
+            assert_eq!(row.get_unwrap::<_, i32>(0), 2);
+        }
+        db.execute_batch("DROP TABLE vtab")
+    }
+}
diff --git a/src/vtab/mod.rs b/src/vtab/mod.rs
new file mode 100644
index 0000000..07008f3
--- /dev/null
+++ b/src/vtab/mod.rs
@@ -0,0 +1,1366 @@
+//! Create virtual tables.
+//!
+//! Follow these steps to create your own virtual table:
+//! 1. Write implementation of [`VTab`] and [`VTabCursor`] traits.
+//! 2. Create an instance of the [`Module`] structure specialized for [`VTab`]
+//! impl. from step 1.
+//! 3. Register your [`Module`] structure using [`Connection::create_module`].
+//! 4. Run a `CREATE VIRTUAL TABLE` command that specifies the new module in the
+//! `USING` clause.
+//!
+//! (See [SQLite doc](http://sqlite.org/vtab.html))
+use std::borrow::Cow::{self, Borrowed, Owned};
+use std::marker::PhantomData;
+use std::marker::Sync;
+use std::os::raw::{c_char, c_int, c_void};
+use std::ptr;
+use std::slice;
+
+use crate::context::set_result;
+use crate::error::error_from_sqlite_code;
+use crate::ffi;
+pub use crate::ffi::{sqlite3_vtab, sqlite3_vtab_cursor};
+use crate::types::{FromSql, FromSqlError, ToSql, ValueRef};
+use crate::{str_to_cstring, Connection, Error, InnerConnection, Result};
+
+// let conn: Connection = ...;
+// let mod: Module = ...; // VTab builder
+// conn.create_module("module", mod);
+//
+// conn.execute("CREATE VIRTUAL TABLE foo USING module(...)");
+//  \-> Module::xcreate
+//   |-> let vtab: VTab = ...; // on the heap
+//   \-> conn.declare_vtab("CREATE TABLE foo (...)");
+// conn = Connection::open(...);
+//  \-> Module::xconnect
+//   |-> let vtab: VTab = ...; // on the heap
+//   \-> conn.declare_vtab("CREATE TABLE foo (...)");
+//
+// conn.close();
+//  \-> vtab.xdisconnect
+// conn.execute("DROP TABLE foo");
+//  \-> vtab.xDestroy
+//
+// let stmt = conn.prepare("SELECT
... FROM foo WHERE ...");
+//  \-> vtab.xbestindex
+// stmt.query().next();
+//  \-> vtab.xopen
+//   |-> let cursor: VTabCursor = ...; // on the heap
+//   |-> cursor.xfilter or xnext
+//   |-> cursor.xeof
+//   \-> if not eof { cursor.column or xrowid } else { cursor.xclose }
+//
+
+// db: *mut ffi::sqlite3 => VTabConnection
+// module: *const ffi::sqlite3_module => Module
+// aux: *mut c_void => Module::Aux
+// ffi::sqlite3_vtab => VTab
+// ffi::sqlite3_vtab_cursor => VTabCursor
+
+/// Virtual table kind
+pub enum VTabKind {
+    /// Non-eponymous
+    Default,
+    /// [`create`](CreateVTab::create) == [`connect`](VTab::connect)
+    ///
+    /// See [SQLite doc](https://sqlite.org/vtab.html#eponymous_virtual_tables)
+    Eponymous,
+    /// No [`create`](CreateVTab::create) / [`destroy`](CreateVTab::destroy) or
+    /// not used
+    ///
+    /// SQLite >= 3.9.0
+    ///
+    /// See [SQLite doc](https://sqlite.org/vtab.html#eponymous_only_virtual_tables)
+    EponymousOnly,
+}
+
+/// Virtual table module
+///
+/// (See [SQLite doc](https://sqlite.org/c3ref/module.html))
+#[repr(transparent)]
+pub struct Module<'vtab, T: VTab<'vtab>> {
+    base: ffi::sqlite3_module,
+    phantom: PhantomData<&'vtab T>,
+}
+
+unsafe impl<'vtab, T: VTab<'vtab>> Send for Module<'vtab, T> {}
+unsafe impl<'vtab, T: VTab<'vtab>> Sync for Module<'vtab, T> {}
+
+union ModuleZeroHack {
+    bytes: [u8; std::mem::size_of::<ffi::sqlite3_module>()],
+    module: ffi::sqlite3_module,
+}
+
+// Used as a trailing initializer for sqlite3_module -- this way we avoid having
+// the build fail if buildtime_bindgen is on. This is safe, as bindgen-generated
+// structs are allowed to be zeroed.
+const ZERO_MODULE: ffi::sqlite3_module = unsafe {
+    ModuleZeroHack {
+        bytes: [0_u8; std::mem::size_of::<ffi::sqlite3_module>()],
+    }
+    .module
+};
+
+macro_rules! module {
+    ($lt:lifetime, $vt:ty, $ct:ty, $xc:expr, $xd:expr, $xu:expr) => {
+        #[allow(clippy::needless_update)]
+        &Module {
+            base: ffi::sqlite3_module {
+                // We don't use V3
+                iVersion: 2,
+                xCreate: $xc,
+                xConnect: Some(rust_connect::<$vt>),
+                xBestIndex: Some(rust_best_index::<$vt>),
+                xDisconnect: Some(rust_disconnect::<$vt>),
+                xDestroy: $xd,
+                xOpen: Some(rust_open::<$vt>),
+                xClose: Some(rust_close::<$ct>),
+                xFilter: Some(rust_filter::<$ct>),
+                xNext: Some(rust_next::<$ct>),
+                xEof: Some(rust_eof::<$ct>),
+                xColumn: Some(rust_column::<$ct>),
+                xRowid: Some(rust_rowid::<$ct>), // FIXME optional
+                xUpdate: $xu,
+                xBegin: None,
+                xSync: None,
+                xCommit: None,
+                xRollback: None,
+                xFindFunction: None,
+                xRename: None,
+                xSavepoint: None,
+                xRelease: None,
+                xRollbackTo: None,
+                ..ZERO_MODULE
+            },
+            phantom: PhantomData::<&$lt $vt>,
+        }
+    };
+}
+
+/// Create a modifiable virtual table implementation.
+///
+/// Step 2 of [Creating New Virtual Table Implementations](https://sqlite.org/vtab.html#creating_new_virtual_table_implementations).
+#[must_use]
+pub fn update_module<'vtab, T: UpdateVTab<'vtab>>() -> &'static Module<'vtab, T> {
+    match T::KIND {
+        VTabKind::EponymousOnly => {
+            module!('vtab, T, T::Cursor, None, None, Some(rust_update::<T>))
+        }
+        VTabKind::Eponymous => {
+            module!('vtab, T, T::Cursor, Some(rust_connect::<T>), Some(rust_disconnect::<T>), Some(rust_update::<T>))
+        }
+        VTabKind::Default => {
+            module!('vtab, T, T::Cursor, Some(rust_create::<T>), Some(rust_destroy::<T>), Some(rust_update::<T>))
+        }
+    }
+}
+
+/// Create a read-only virtual table implementation.
+///
+/// Step 2 of [Creating New Virtual Table Implementations](https://sqlite.org/vtab.html#creating_new_virtual_table_implementations).
+#[must_use]
+pub fn read_only_module<'vtab, T: CreateVTab<'vtab>>() -> &'static Module<'vtab, T> {
+    match T::KIND {
+        VTabKind::EponymousOnly => eponymous_only_module(),
+        VTabKind::Eponymous => {
+            // A virtual table is eponymous if its xCreate method is the exact same function
+            // as the xConnect method
+            module!('vtab, T, T::Cursor, Some(rust_connect::<T>), Some(rust_disconnect::<T>), None)
+        }
+        VTabKind::Default => {
+            // The xConnect and xCreate methods may do the same thing, but they must be
+            // different so that the virtual table is not an eponymous virtual table.
+            module!('vtab, T, T::Cursor, Some(rust_create::<T>), Some(rust_destroy::<T>), None)
+        }
+    }
+}
+
+/// Create an eponymous only virtual table implementation.
+///
+/// Step 2 of [Creating New Virtual Table Implementations](https://sqlite.org/vtab.html#creating_new_virtual_table_implementations).
+#[must_use]
+pub fn eponymous_only_module<'vtab, T: VTab<'vtab>>() -> &'static Module<'vtab, T> {
+    // For eponymous-only virtual tables, the xCreate method is NULL
+    module!('vtab, T, T::Cursor, None, None, None)
+}
+
+/// Virtual table configuration options
+#[repr(i32)]
+#[non_exhaustive]
+#[cfg(feature = "modern_sqlite")] // 3.7.7
+#[derive(Debug, Clone, Copy, Eq, PartialEq)]
+pub enum VTabConfig {
+    /// Equivalent to SQLITE_VTAB_CONSTRAINT_SUPPORT
+    ConstraintSupport = 1,
+    /// Equivalent to SQLITE_VTAB_INNOCUOUS (SQLite >= 3.31.0)
+    Innocuous = 2,
+    /// Equivalent to SQLITE_VTAB_DIRECTONLY (SQLite >= 3.31.0)
+    DirectOnly = 3,
+}
+
+/// Handle to the database connection, passed to [`VTab::connect`] and [`CreateVTab::create`].
+pub struct VTabConnection(*mut ffi::sqlite3);
+
+impl VTabConnection {
+    /// Configure various facets of the virtual table interface
+    #[cfg(feature = "modern_sqlite")] // 3.7.7
+    #[cfg_attr(docsrs, doc(cfg(feature = "modern_sqlite")))]
+    pub fn config(&mut self, config: VTabConfig) -> Result<()> {
+        crate::error::check(unsafe { ffi::sqlite3_vtab_config(self.0, config as c_int) })
+    }
+
+    // TODO sqlite3_vtab_on_conflict (http://sqlite.org/c3ref/vtab_on_conflict.html) & xUpdate
+
+    /// Get access to the underlying SQLite database connection handle.
+    ///
+    /// # Warning
+    ///
+    /// You should not need to use this function. If you do need to, please
+    /// [open an issue on the rusqlite repository](https://github.com/rusqlite/rusqlite/issues) and describe
+    /// your use case.
+    ///
+    /// # Safety
+    ///
+    /// This function is unsafe because it gives you raw access
+    /// to the SQLite connection, and what you do with it could impact the
+    /// safety of this `Connection`.
+    pub unsafe fn handle(&mut self) -> *mut ffi::sqlite3 {
+        self.0
+    }
+}
+
+/// Eponymous-only virtual table instance trait.
+///
+/// # Safety
+///
+/// The first item in a struct implementing `VTab` must be
+/// `rusqlite::vtab::sqlite3_vtab`, and the struct must be `#[repr(C)]`.
+///
+/// ```rust,ignore
+/// #[repr(C)]
+/// struct MyTab {
+///    /// Base class. Must be first
+///    base: rusqlite::vtab::sqlite3_vtab,
+///    /* Virtual table implementations will typically add additional fields */
+/// }
+/// ```
+///
+/// (See [SQLite doc](https://sqlite.org/c3ref/vtab.html))
+pub unsafe trait VTab<'vtab>: Sized {
+    /// Client data passed to [`Connection::create_module`].
+    type Aux;
+    /// Specific cursor implementation
+    type Cursor: VTabCursor;
+
+    /// Establish a new connection to an existing virtual table.
+    ///
+    /// (See [SQLite doc](https://sqlite.org/vtab.html#the_xconnect_method))
+    fn connect(
+        db: &mut VTabConnection,
+        aux: Option<&Self::Aux>,
+        args: &[&[u8]],
+    ) -> Result<(String, Self)>;
+
+    /// Determine the best way to access the virtual table.
+    /// (See [SQLite doc](https://sqlite.org/vtab.html#the_xbestindex_method))
+    fn best_index(&self, info: &mut IndexInfo) -> Result<()>;
+
+    /// Create a new cursor used for accessing a virtual table.
+    /// (See [SQLite doc](https://sqlite.org/vtab.html#the_xopen_method))
+    fn open(&'vtab mut self) -> Result<Self::Cursor>;
+}
+
+/// Read-only virtual table instance trait.
+///
+/// (See [SQLite doc](https://sqlite.org/c3ref/vtab.html))
+pub trait CreateVTab<'vtab>: VTab<'vtab> {
+    /// For [`EponymousOnly`](VTabKind::EponymousOnly),
+    /// [`create`](CreateVTab::create) and [`destroy`](CreateVTab::destroy) are
+    /// not called
+    const KIND: VTabKind;
+    /// Create a new instance of a virtual table in response to a CREATE VIRTUAL
+    /// TABLE statement. The `db` parameter wraps the handle of the SQLite
+    /// database connection that is executing the CREATE VIRTUAL TABLE
+    /// statement.
+    ///
+    /// Call [`connect`](VTab::connect) by default.
+    /// (See [SQLite doc](https://sqlite.org/vtab.html#the_xcreate_method))
+    fn create(
+        db: &mut VTabConnection,
+        aux: Option<&Self::Aux>,
+        args: &[&[u8]],
+    ) -> Result<(String, Self)> {
+        Self::connect(db, aux, args)
+    }
+
+    /// Destroy the underlying table implementation. This method undoes the work
+    /// of [`create`](CreateVTab::create).
+    ///
+    /// Do nothing by default.
+    /// (See [SQLite doc](https://sqlite.org/vtab.html#the_xdestroy_method))
+    fn destroy(&self) -> Result<()> {
+        Ok(())
+    }
+}
+
+/// Writable virtual table instance trait.
+///
+/// (See [SQLite doc](https://sqlite.org/vtab.html#xupdate))
+pub trait UpdateVTab<'vtab>: CreateVTab<'vtab> {
+    /// Delete rowid or PK
+    fn delete(&mut self, arg: ValueRef<'_>) -> Result<()>;
+    /// Insert: `args[0] == NULL`, `args[1]`: new rowid or PK (may be NULL),
+    /// `args[2..]`: the new column values.
+    ///
+    /// Return the new rowid.
+    // TODO Make the distinction between argv[1] == NULL and argv[1] != NULL ?
+    fn insert(&mut self, args: &Values<'_>) -> Result<i64>;
+    /// Update: `args[0] != NULL`: old rowid or PK, `args[1]`: new rowid or PK,
+    /// `args[2..]`: the new column values.
+    fn update(&mut self, args: &Values<'_>) -> Result<()>;
+}
+
+/// Index constraint operator.
+/// See [Virtual Table Constraint Operator Codes](https://sqlite.org/c3ref/c_index_constraint_eq.html) for details.
+#[derive(Debug, Eq, PartialEq)]
+#[allow(non_snake_case, non_camel_case_types, missing_docs)]
+#[allow(clippy::upper_case_acronyms)]
+pub enum IndexConstraintOp {
+    SQLITE_INDEX_CONSTRAINT_EQ,
+    SQLITE_INDEX_CONSTRAINT_GT,
+    SQLITE_INDEX_CONSTRAINT_LE,
+    SQLITE_INDEX_CONSTRAINT_LT,
+    SQLITE_INDEX_CONSTRAINT_GE,
+    SQLITE_INDEX_CONSTRAINT_MATCH,
+    SQLITE_INDEX_CONSTRAINT_LIKE,         // 3.10.0
+    SQLITE_INDEX_CONSTRAINT_GLOB,         // 3.10.0
+    SQLITE_INDEX_CONSTRAINT_REGEXP,       // 3.10.0
+    SQLITE_INDEX_CONSTRAINT_NE,           // 3.21.0
+    SQLITE_INDEX_CONSTRAINT_ISNOT,        // 3.21.0
+    SQLITE_INDEX_CONSTRAINT_ISNOTNULL,    // 3.21.0
+    SQLITE_INDEX_CONSTRAINT_ISNULL,       // 3.21.0
+    SQLITE_INDEX_CONSTRAINT_IS,           // 3.21.0
+    SQLITE_INDEX_CONSTRAINT_LIMIT,        // 3.38.0
+    SQLITE_INDEX_CONSTRAINT_OFFSET,       // 3.38.0
+    SQLITE_INDEX_CONSTRAINT_FUNCTION(u8), // 3.25.0
+}
+
+impl From<u8> for IndexConstraintOp {
+    fn from(code: u8) -> IndexConstraintOp {
+        match code {
+            2 => IndexConstraintOp::SQLITE_INDEX_CONSTRAINT_EQ,
+            4 => IndexConstraintOp::SQLITE_INDEX_CONSTRAINT_GT,
+            8 => IndexConstraintOp::SQLITE_INDEX_CONSTRAINT_LE,
+            16 => IndexConstraintOp::SQLITE_INDEX_CONSTRAINT_LT,
+            32 => IndexConstraintOp::SQLITE_INDEX_CONSTRAINT_GE,
+            64 => IndexConstraintOp::SQLITE_INDEX_CONSTRAINT_MATCH,
+            65 => IndexConstraintOp::SQLITE_INDEX_CONSTRAINT_LIKE,
+            66 => IndexConstraintOp::SQLITE_INDEX_CONSTRAINT_GLOB,
+            67 => IndexConstraintOp::SQLITE_INDEX_CONSTRAINT_REGEXP,
+            68 => IndexConstraintOp::SQLITE_INDEX_CONSTRAINT_NE,
+            69 => IndexConstraintOp::SQLITE_INDEX_CONSTRAINT_ISNOT,
+            70 => IndexConstraintOp::SQLITE_INDEX_CONSTRAINT_ISNOTNULL,
+            71 => IndexConstraintOp::SQLITE_INDEX_CONSTRAINT_ISNULL,
+            72 => IndexConstraintOp::SQLITE_INDEX_CONSTRAINT_IS,
+            73 => IndexConstraintOp::SQLITE_INDEX_CONSTRAINT_LIMIT,
+            74 => IndexConstraintOp::SQLITE_INDEX_CONSTRAINT_OFFSET,
+            v => IndexConstraintOp::SQLITE_INDEX_CONSTRAINT_FUNCTION(v),
+        }
+    }
+}
+
+#[cfg(feature = "modern_sqlite")] // 3.9.0
+bitflags::bitflags! {
+    /// Virtual table scan flags
+    /// See [Virtual Table Scan Flags](https://sqlite.org/c3ref/c_index_scan_unique.html) for details.
+    #[repr(C)]
+    pub struct IndexFlags: ::std::os::raw::c_int {
+        /// Default
+        const NONE = 0;
+        /// Scan visits at most 1 row.
+        const SQLITE_INDEX_SCAN_UNIQUE = ffi::SQLITE_INDEX_SCAN_UNIQUE;
+    }
+}
+
+/// Pass information into and receive the reply from the
+/// [`VTab::best_index`] method.
+///
+/// (See [SQLite doc](http://sqlite.org/c3ref/index_info.html))
+#[derive(Debug)]
+pub struct IndexInfo(*mut ffi::sqlite3_index_info);
+
+impl IndexInfo {
+    /// Iterate on index constraint and its associated usage.
+    #[inline]
+    pub fn constraints_and_usages(&mut self) -> IndexConstraintAndUsageIter<'_> {
+        let constraints =
+            unsafe { slice::from_raw_parts((*self.0).aConstraint, (*self.0).nConstraint as usize) };
+        let constraint_usages = unsafe {
+            slice::from_raw_parts_mut((*self.0).aConstraintUsage, (*self.0).nConstraint as usize)
+        };
+        IndexConstraintAndUsageIter {
+            iter: constraints.iter().zip(constraint_usages.iter_mut()),
+        }
+    }
+
+    /// Record WHERE clause constraints.
+    #[inline]
+    #[must_use]
+    pub fn constraints(&self) -> IndexConstraintIter<'_> {
+        let constraints =
+            unsafe { slice::from_raw_parts((*self.0).aConstraint, (*self.0).nConstraint as usize) };
+        IndexConstraintIter {
+            iter: constraints.iter(),
+        }
+    }
+
+    /// Information about the ORDER BY clause.
+    #[inline]
+    #[must_use]
+    pub fn order_bys(&self) -> OrderByIter<'_> {
+        let order_bys =
+            unsafe { slice::from_raw_parts((*self.0).aOrderBy, (*self.0).nOrderBy as usize) };
+        OrderByIter {
+            iter: order_bys.iter(),
+        }
+    }
+
+    /// Number of terms in the ORDER BY clause
+    #[inline]
+    #[must_use]
+    pub fn num_of_order_by(&self) -> usize {
+        unsafe { (*self.0).nOrderBy as usize }
+    }
+
+    /// Information about what parameters to pass to [`VTabCursor::filter`].
+    #[inline]
+    pub fn constraint_usage(&mut self, constraint_idx: usize) -> IndexConstraintUsage<'_> {
+        let constraint_usages = unsafe {
+            slice::from_raw_parts_mut((*self.0).aConstraintUsage, (*self.0).nConstraint as usize)
+        };
+        IndexConstraintUsage(&mut constraint_usages[constraint_idx])
+    }
+
+    /// Number used to identify the index
+    #[inline]
+    pub fn set_idx_num(&mut self, idx_num: c_int) {
+        unsafe {
+            (*self.0).idxNum = idx_num;
+        }
+    }
+
+    /// String used to identify the index (frees any string set by an earlier call)
+    pub fn set_idx_str(&mut self, idx_str: &str) {
+        unsafe {
+            let prev = std::mem::replace(&mut (*self.0).idxStr, alloc(idx_str));
+            if std::mem::replace(&mut (*self.0).needToFreeIdxStr, 1) != 0 { ffi::sqlite3_free(prev.cast()) }
+        }
+    }
+
+    /// True if output is already ordered
+    #[inline]
+    pub fn set_order_by_consumed(&mut self, order_by_consumed: bool) {
+        unsafe {
+            (*self.0).orderByConsumed = c_int::from(order_by_consumed);
+        }
+    }
+
+    /// Estimated cost of using this index
+    #[inline]
+    pub fn set_estimated_cost(&mut self, estimated_cost: f64) {
+        unsafe {
+            (*self.0).estimatedCost = estimated_cost;
+        }
+    }
+
+    /// Estimated number of rows returned.
+    #[cfg(feature = "modern_sqlite")] // SQLite >= 3.8.2
+    #[cfg_attr(docsrs, doc(cfg(feature = "modern_sqlite")))]
+    #[inline]
+    pub fn set_estimated_rows(&mut self, estimated_rows: i64) {
+        unsafe {
+            (*self.0).estimatedRows = estimated_rows;
+        }
+    }
+
+    /// Mask of SQLITE_INDEX_SCAN_* flags.
+    #[cfg(feature = "modern_sqlite")] // SQLite >= 3.9.0
+    #[cfg_attr(docsrs, doc(cfg(feature = "modern_sqlite")))]
+    #[inline]
+    pub fn set_idx_flags(&mut self, flags: IndexFlags) {
+        unsafe { (*self.0).idxFlags = flags.bits() };
+    }
+
+    /// Mask of columns used by statement
+    #[cfg(feature = "modern_sqlite")] // SQLite >= 3.10.0
+    #[cfg_attr(docsrs, doc(cfg(feature = "modern_sqlite")))]
+    #[inline]
+    pub fn col_used(&self) -> u64 {
+        unsafe { (*self.0).colUsed }
+    }
+
+    /// Determine the collation for a virtual table constraint
+    #[cfg(feature = "modern_sqlite")] // SQLite >= 3.22.0
+    #[cfg_attr(docsrs, doc(cfg(feature = "modern_sqlite")))]
+    pub fn collation(&self, constraint_idx: usize) -> Result<&str> {
+        use std::ffi::CStr;
+        let idx = constraint_idx as c_int;
+        let collation = unsafe { ffi::sqlite3_vtab_collation(self.0, idx) };
+        if collation.is_null() {
+            return Err(Error::SqliteFailure(
+                ffi::Error::new(ffi::SQLITE_MISUSE),
+                Some(format!("{} is out of range", constraint_idx)),
+            ));
+        }
+        Ok(unsafe { CStr::from_ptr(collation) }.to_str()?)
+    }
+
+    /*/// Determine if a virtual table query is DISTINCT
+    #[cfg(feature = "modern_sqlite")] // SQLite >= 3.38.0
+    #[cfg_attr(docsrs, doc(cfg(feature = "modern_sqlite")))]
+    pub fn distinct(&self) -> c_int {
+        unsafe { ffi::sqlite3_vtab_distinct(self.0) }
+    }
+
+    /// Constraint values
+    #[cfg(feature = "modern_sqlite")] // SQLite >= 3.38.0
+    #[cfg_attr(docsrs, doc(cfg(feature = "modern_sqlite")))]
+    pub fn set_rhs_value(&mut self, constraint_idx: c_int, value: ValueRef) -> Result<()> {
+        // TODO ValueRef to sqlite3_value
+        crate::error::check(unsafe { ffi::sqlite3_vtab_rhs_value(self.O, constraint_idx, value) })
+    }
+
+    /// Identify and handle IN constraints
+    #[cfg(feature = "modern_sqlite")] // SQLite >= 3.38.0
+    #[cfg_attr(docsrs, doc(cfg(feature = "modern_sqlite")))]
+    pub fn set_in_constraint(&mut self, constraint_idx: c_int, b_handle: c_int) -> bool {
+        unsafe { ffi::sqlite3_vtab_in(self.0, constraint_idx, b_handle) != 0 }
+    } // TODO sqlite3_vtab_in_first / sqlite3_vtab_in_next https://sqlite.org/c3ref/vtab_in_first.html
+    */
+}
+
+/// Iterate on index constraint and its associated usage.
+pub struct IndexConstraintAndUsageIter<'a> {
+    iter: std::iter::Zip<
+        slice::Iter<'a, ffi::sqlite3_index_constraint>,
+        slice::IterMut<'a, ffi::sqlite3_index_constraint_usage>,
+    >,
+}
+
+impl<'a> Iterator for IndexConstraintAndUsageIter<'a> {
+    type Item = (IndexConstraint<'a>, IndexConstraintUsage<'a>);
+
+    #[inline]
+    fn next(&mut self) -> Option<(IndexConstraint<'a>, IndexConstraintUsage<'a>)> {
+        self.iter
+            .next()
+            .map(|(constraint, usage)| (IndexConstraint(constraint), IndexConstraintUsage(usage)))
+    }
+
+    #[inline]
+    fn size_hint(&self) -> (usize, Option<usize>) {
+        self.iter.size_hint()
+    }
+}
+
+/// Iterate on index constraints.
+pub struct IndexConstraintIter<'a> {
+    iter: slice::Iter<'a, ffi::sqlite3_index_constraint>,
+}
+
+impl<'a> Iterator for IndexConstraintIter<'a> {
+    type Item = IndexConstraint<'a>;
+
+    #[inline]
+    fn next(&mut self) -> Option<IndexConstraint<'a>> {
+        self.iter.next().map(IndexConstraint)
+    }
+
+    #[inline]
+    fn size_hint(&self) -> (usize, Option<usize>) {
+        self.iter.size_hint()
+    }
+}
+
+/// WHERE clause constraint.
+pub struct IndexConstraint<'a>(&'a ffi::sqlite3_index_constraint);
+
+impl IndexConstraint<'_> {
+    /// Column constrained.  -1 for ROWID
+    #[inline]
+    #[must_use]
+    pub fn column(&self) -> c_int {
+        self.0.iColumn
+    }
+
+    /// Constraint operator
+    #[inline]
+    #[must_use]
+    pub fn operator(&self) -> IndexConstraintOp {
+        IndexConstraintOp::from(self.0.op)
+    }
+
+    /// True if this constraint is usable
+    #[inline]
+    #[must_use]
+    pub fn is_usable(&self) -> bool {
+        self.0.usable != 0
+    }
+}
+
+/// Information about what parameters to pass to
+/// [`VTabCursor::filter`].
+pub struct IndexConstraintUsage<'a>(&'a mut ffi::sqlite3_index_constraint_usage);
+
+impl IndexConstraintUsage<'_> {
+    /// if `argv_index` > 0, constraint is part of argv to
+    /// [`VTabCursor::filter`]
+    #[inline]
+    pub fn set_argv_index(&mut self, argv_index: c_int) {
+        self.0.argvIndex = argv_index;
+    }
+
+    /// if `omit`, do not code a test for this constraint
+    #[inline]
+    pub fn set_omit(&mut self, omit: bool) {
+        self.0.omit = omit.into();
+    }
+}
+
+/// Iterate on the terms of the ORDER BY clause.
+pub struct OrderByIter<'a> {
+    iter: slice::Iter<'a, ffi::sqlite3_index_info_sqlite3_index_orderby>,
+}
+
+impl<'a> Iterator for OrderByIter<'a> {
+    type Item = OrderBy<'a>;
+
+    #[inline]
+    fn next(&mut self) -> Option<OrderBy<'a>> {
+        self.iter.next().map(OrderBy)
+    }
+
+    #[inline]
+    fn size_hint(&self) -> (usize, Option<usize>) {
+        self.iter.size_hint()
+    }
+}
+
+/// A column of the ORDER BY clause.
+pub struct OrderBy<'a>(&'a ffi::sqlite3_index_info_sqlite3_index_orderby);
+
+impl OrderBy<'_> {
+    /// Column number
+    #[inline]
+    #[must_use]
+    pub fn column(&self) -> c_int {
+        self.0.iColumn
+    }
+
+    /// True for DESC.  False for ASC.
+    #[inline]
+    #[must_use]
+    pub fn is_order_by_desc(&self) -> bool {
+        self.0.desc != 0
+    }
+}
+
+/// Virtual table cursor trait.
+///
+/// # Safety
+///
+/// Implementations must be like:
+/// ```rust,ignore
+/// #[repr(C)]
+/// struct MyTabCursor {
+///    /// Base class. Must be first
+///    base: rusqlite::vtab::sqlite3_vtab_cursor,
+///    /* Virtual table implementations will typically add additional fields */
+/// }
+/// ```
+///
+/// (See [SQLite doc](https://sqlite.org/c3ref/vtab_cursor.html))
+pub unsafe trait VTabCursor: Sized {
+    /// Begin a search of a virtual table.
+    /// (See [SQLite doc](https://sqlite.org/vtab.html#the_xfilter_method))
+    fn filter(&mut self, idx_num: c_int, idx_str: Option<&str>, args: &Values<'_>) -> Result<()>;
+    /// Advance cursor to the next row of a result set initiated by
+    /// [`filter`](VTabCursor::filter). (See [SQLite doc](https://sqlite.org/vtab.html#the_xnext_method))
+    fn next(&mut self) -> Result<()>;
+    /// Must return `false` if the cursor currently points to a valid row of
+    /// data, or `true` otherwise.
+    /// (See [SQLite doc](https://sqlite.org/vtab.html#the_xeof_method))
+    fn eof(&self) -> bool;
+    /// Find the value for the `i`-th column of the current row.
+    /// `i` is zero-based so the first column is numbered 0.
+    /// May return its result back to SQLite using one of the specified `ctx`.
+    /// (See [SQLite doc](https://sqlite.org/vtab.html#the_xcolumn_method))
+    fn column(&self, ctx: &mut Context, i: c_int) -> Result<()>;
+    /// Return the rowid of row that the cursor is currently pointing at.
+    /// (See [SQLite doc](https://sqlite.org/vtab.html#the_xrowid_method))
+    fn rowid(&self) -> Result<i64>;
+}
+
+/// Context is used by [`VTabCursor::column`] to specify the
+/// cell value.
+pub struct Context(*mut ffi::sqlite3_context);
+
+impl Context {
+    /// Set current cell value
+    #[inline]
+    pub fn set_result<T: ToSql>(&mut self, value: &T) -> Result<()> {
+        let t = value.to_sql()?;
+        unsafe { set_result(self.0, &t) };
+        Ok(())
+    }
+
+    // TODO sqlite3_vtab_nochange (http://sqlite.org/c3ref/vtab_nochange.html) // 3.22.0 & xColumn
+}
+
+/// Wrapper to [`VTabCursor::filter`] arguments, the values
+/// requested by [`VTab::best_index`].
+pub struct Values<'a> {
+    args: &'a [*mut ffi::sqlite3_value],
+}
+
+impl Values<'_> {
+    /// Returns the number of values.
+    #[inline]
+    #[must_use]
+    pub fn len(&self) -> usize {
+        self.args.len()
+    }
+
+    /// Returns `true` if there is no value.
+    #[inline]
+    #[must_use]
+    pub fn is_empty(&self) -> bool {
+        self.args.is_empty()
+    }
+
+    /// Returns the value at `idx` converted to `T`; panics if `idx >= self.len()`.
+    pub fn get<T: FromSql>(&self, idx: usize) -> Result<T> {
+        let arg = self.args[idx];
+        let value = unsafe { ValueRef::from_value(arg) };
+        FromSql::column_result(value).map_err(|err| match err {
+            FromSqlError::InvalidType => Error::InvalidFilterParameterType(idx, value.data_type()),
+            FromSqlError::Other(err) => {
+                Error::FromSqlConversionFailure(idx, value.data_type(), err)
+            }
+            FromSqlError::InvalidBlobSize { .. } => {
+                Error::FromSqlConversionFailure(idx, value.data_type(), Box::new(err))
+            }
+            FromSqlError::OutOfRange(i) => Error::IntegralValueOutOfRange(idx, i),
+        })
+    }
+
+    // `sqlite3_value_type` returns `SQLITE_NULL` for pointer.
+    // So it seems not possible to enhance `ValueRef::from_value`.
+    #[cfg(feature = "array")]
+    #[cfg_attr(docsrs, doc(cfg(feature = "array")))]
+    fn get_array(&self, idx: usize) -> Option<array::Array> {
+        use crate::types::Value;
+        let arg = self.args[idx];
+        let ptr = unsafe { ffi::sqlite3_value_pointer(arg, array::ARRAY_TYPE) };
+        if ptr.is_null() {
+            None
+        } else {
+            Some(unsafe {
+                let rc = array::Array::from_raw(ptr as *const Vec<Value>);
+                let array = rc.clone();
+                array::Array::into_raw(rc); // don't consume it
+                array
+            })
+        }
+    }
+
+    /// Turns `Values` into an iterator.
+    #[inline]
+    #[must_use]
+    pub fn iter(&self) -> ValueIter<'_> {
+        ValueIter {
+            iter: self.args.iter(),
+        }
+    }
+    // TODO sqlite3_vtab_in_first / sqlite3_vtab_in_next https://sqlite.org/c3ref/vtab_in_first.html & 3.38.0
+}
+
+impl<'a> IntoIterator for &'a Values<'a> {
+    type IntoIter = ValueIter<'a>;
+    type Item = ValueRef<'a>;
+
+    #[inline]
+    fn into_iter(self) -> ValueIter<'a> {
+        self.iter()
+    }
+}
+
+/// [`Values`] iterator.
+pub struct ValueIter<'a> {
+    iter: slice::Iter<'a, *mut ffi::sqlite3_value>,
+}
+
+impl<'a> Iterator for ValueIter<'a> {
+    type Item = ValueRef<'a>;
+
+    #[inline]
+    fn next(&mut self) -> Option<ValueRef<'a>> {
+        self.iter
+            .next()
+            .map(|&raw| unsafe { ValueRef::from_value(raw) })
+    }
+
+    #[inline]
+    fn size_hint(&self) -> (usize, Option<usize>) {
+        self.iter.size_hint()
+    }
+}
+
+impl Connection {
+    /// Register a virtual table implementation.
+    ///
+    /// Step 3 of [Creating New Virtual Table
+    /// Implementations](https://sqlite.org/vtab.html#creating_new_virtual_table_implementations).
+    #[inline]
+    pub fn create_module<'vtab, T: VTab<'vtab>>(
+        &self,
+        module_name: &str,
+        module: &'static Module<'vtab, T>,
+        aux: Option<T::Aux>,
+    ) -> Result<()> {
+        self.db.borrow_mut().create_module(module_name, module, aux)
+    }
+}
+
+impl InnerConnection {
+    fn create_module<'vtab, T: VTab<'vtab>>(
+        &mut self,
+        module_name: &str,
+        module: &'static Module<'vtab, T>,
+        aux: Option<T::Aux>,
+    ) -> Result<()> {
+        use crate::version;
+        if version::version_number() < 3_009_000 && module.base.xCreate.is_none() {
+            return Err(Error::ModuleError(format!(
+                "Eponymous-only virtual table not supported by SQLite version {}",
+                version::version()
+            )));
+        }
+        let c_name = str_to_cstring(module_name)?;
+        let r = match aux {
+            Some(aux) => {
+                let boxed_aux: *mut T::Aux = Box::into_raw(Box::new(aux));
+                unsafe {
+                    ffi::sqlite3_create_module_v2(
+                        self.db(),
+                        c_name.as_ptr(),
+                        &module.base,
+                        boxed_aux.cast::<c_void>(),
+                        Some(free_boxed_value::<T::Aux>),
+                    )
+                }
+            }
+            None => unsafe {
+                ffi::sqlite3_create_module_v2(
+                    self.db(),
+                    c_name.as_ptr(),
+                    &module.base,
+                    ptr::null_mut(),
+                    None,
+                )
+            },
+        };
+        self.decode_result(r)
+    }
+}
+
+/// Escape double-quote (`"`) character occurrences by
+/// doubling them (`""`).
+#[must_use]
+pub fn escape_double_quote(identifier: &str) -> Cow<'_, str> {
+    if identifier.contains('"') {
+        // escape quote by doubling them
+        Owned(identifier.replace('"', "\"\""))
+    } else {
+        Borrowed(identifier)
+    }
+}
+/// Dequote string
+#[must_use]
+pub fn dequote(s: &str) -> &str {
+    if s.len() < 2 {
+        return s;
+    }
+    match s.bytes().next() {
+        Some(b) if b == b'"' || b == b'\'' => match s.bytes().next_back() {
+            Some(e) if e == b => &s[1..s.len() - 1], // FIXME handle inner escaped quote(s)
+            _ => s,
+        },
+        _ => s,
+    }
+}
+/// The boolean can be one of:
+/// ```text
+/// 1 yes true on
+/// 0 no false off
+/// ```
+#[must_use]
+pub fn parse_boolean(s: &str) -> Option<bool> {
+    if s.eq_ignore_ascii_case("yes")
+        || s.eq_ignore_ascii_case("on")
+        || s.eq_ignore_ascii_case("true")
+        || s.eq("1")
+    {
+        Some(true)
+    } else if s.eq_ignore_ascii_case("no")
+        || s.eq_ignore_ascii_case("off")
+        || s.eq_ignore_ascii_case("false")
+        || s.eq("0")
+    {
+        Some(false)
+    } else {
+        None
+    }
+}
+
+/// `<param>=['"]?<value>['"]?` => `(<param>, <value>)`
+pub fn parameter(c_slice: &[u8]) -> Result<(&str, &str)> {
+    let arg = std::str::from_utf8(c_slice)?.trim();
+    // Split on the first `=` only, so a value may itself contain `=`.
+    if let Some((key, value)) = arg.split_once('=') {
+        // Trim around `=` too, so `key = 'value'` is read like `key='value'`.
+        let param = key.trim();
+        let value = dequote(value.trim());
+        return Ok((param, value));
+    }
+    // No `=` at all: not a key/value argument.
+    Err(Error::ModuleError(format!("illegal argument: '{}'", arg)))
+}
+
+// FIXME copy/paste from function.rs
+unsafe extern "C" fn free_boxed_value<T>(p: *mut c_void) {
+    drop(Box::from_raw(p.cast::<T>()));
+}
+
+unsafe extern "C" fn rust_create<'vtab, T>(
+    db: *mut ffi::sqlite3,
+    aux: *mut c_void,
+    argc: c_int,
+    argv: *const *const c_char,
+    pp_vtab: *mut *mut ffi::sqlite3_vtab,
+    err_msg: *mut *mut c_char,
+) -> c_int
+where
+    T: CreateVTab<'vtab>,
+{
+    use std::ffi::CStr;
+
+    let mut conn = VTabConnection(db);
+    let aux = aux.cast::<T::Aux>();
+    let args = slice::from_raw_parts(argv, argc as usize);
+    let vec = args
+        .iter()
+        .map(|&cs| CStr::from_ptr(cs).to_bytes()) // FIXME .to_str() -> Result<&str, Utf8Error>
+        .collect::<Vec<_>>();
+    match T::create(&mut conn, aux.as_ref(), &vec[..]) {
+        Ok((sql, vtab)) => match std::ffi::CString::new(sql) {
+            Ok(c_sql) => {
+                let rc = ffi::sqlite3_declare_vtab(db, c_sql.as_ptr());
+                if rc == ffi::SQLITE_OK {
+                    let boxed_vtab: *mut T = Box::into_raw(Box::new(vtab));
+                    *pp_vtab = boxed_vtab.cast::<ffi::sqlite3_vtab>();
+                    ffi::SQLITE_OK
+                } else {
+                    let err = error_from_sqlite_code(rc, None);
+                    *err_msg = alloc(&err.to_string());
+                    rc
+                }
+            }
+            Err(err) => {
+                *err_msg = alloc(&err.to_string());
+                ffi::SQLITE_ERROR
+            }
+        },
+        Err(Error::SqliteFailure(err, s)) => {
+            if let Some(s) = s {
+                *err_msg = alloc(&s);
+            }
+            err.extended_code
+        }
+        Err(err) => {
+            *err_msg = alloc(&err.to_string());
+            ffi::SQLITE_ERROR
+        }
+    }
+}
+
+unsafe extern "C" fn rust_connect<'vtab, T>(
+    db: *mut ffi::sqlite3,
+    aux: *mut c_void,
+    argc: c_int,
+    argv: *const *const c_char,
+    pp_vtab: *mut *mut ffi::sqlite3_vtab,
+    err_msg: *mut *mut c_char,
+) -> c_int
+where
+    T: VTab<'vtab>,
+{
+    use std::ffi::CStr;
+
+    let mut conn = VTabConnection(db);
+    let aux = aux.cast::<T::Aux>();
+    let args = slice::from_raw_parts(argv, argc as usize);
+    let vec = args
+        .iter()
+        .map(|&cs| CStr::from_ptr(cs).to_bytes()) // FIXME .to_str() -> Result<&str, Utf8Error>
+        .collect::<Vec<_>>();
+    match T::connect(&mut conn, aux.as_ref(), &vec[..]) {
+        Ok((sql, vtab)) => match std::ffi::CString::new(sql) {
+            Ok(c_sql) => {
+                let rc = ffi::sqlite3_declare_vtab(db, c_sql.as_ptr());
+                if rc == ffi::SQLITE_OK {
+                    let boxed_vtab: *mut T = Box::into_raw(Box::new(vtab));
+                    *pp_vtab = boxed_vtab.cast::<ffi::sqlite3_vtab>();
+                    ffi::SQLITE_OK
+                } else {
+                    let err = error_from_sqlite_code(rc, None);
+                    *err_msg = alloc(&err.to_string());
+                    rc
+                }
+            }
+            Err(err) => {
+                *err_msg = alloc(&err.to_string());
+                ffi::SQLITE_ERROR
+            }
+        },
+        Err(Error::SqliteFailure(err, s)) => {
+            if let Some(s) = s {
+                *err_msg = alloc(&s);
+            }
+            err.extended_code
+        }
+        Err(err) => {
+            *err_msg = alloc(&err.to_string());
+            ffi::SQLITE_ERROR
+        }
+    }
+}
+
+unsafe extern "C" fn rust_best_index<'vtab, T>(
+    vtab: *mut ffi::sqlite3_vtab,
+    info: *mut ffi::sqlite3_index_info,
+) -> c_int
+where
+    T: VTab<'vtab>,
+{
+    let vt = vtab.cast::<T>();
+    let mut idx_info = IndexInfo(info);
+    match (*vt).best_index(&mut idx_info) {
+        Ok(_) => ffi::SQLITE_OK,
+        Err(Error::SqliteFailure(err, s)) => {
+            if let Some(err_msg) = s {
+                set_err_msg(vtab, &err_msg);
+            }
+            err.extended_code
+        }
+        Err(err) => {
+            set_err_msg(vtab, &err.to_string());
+            ffi::SQLITE_ERROR
+        }
+    }
+}
+
+unsafe extern "C" fn rust_disconnect<'vtab, T>(vtab: *mut ffi::sqlite3_vtab) -> c_int
+where
+    T: VTab<'vtab>,
+{
+    if vtab.is_null() {
+        return ffi::SQLITE_OK;
+    }
+    let vtab = vtab.cast::<T>();
+    drop(Box::from_raw(vtab));
+    ffi::SQLITE_OK
+}
+
+unsafe extern "C" fn rust_destroy<'vtab, T>(vtab: *mut ffi::sqlite3_vtab) -> c_int
+where
+    T: CreateVTab<'vtab>,
+{
+    if vtab.is_null() {
+        return ffi::SQLITE_OK;
+    }
+    let vt = vtab.cast::<T>();
+    match (*vt).destroy() {
+        Ok(_) => {
+            drop(Box::from_raw(vt));
+            ffi::SQLITE_OK
+        }
+        Err(Error::SqliteFailure(err, s)) => {
+            if let Some(err_msg) = s {
+                set_err_msg(vtab, &err_msg);
+            }
+            err.extended_code
+        }
+        Err(err) => {
+            set_err_msg(vtab, &err.to_string());
+            ffi::SQLITE_ERROR
+        }
+    }
+}
+
+unsafe extern "C" fn rust_open<'vtab, T: 'vtab>(
+    vtab: *mut ffi::sqlite3_vtab,
+    pp_cursor: *mut *mut ffi::sqlite3_vtab_cursor,
+) -> c_int
+where
+    T: VTab<'vtab>,
+{
+    let vt = vtab.cast::<T>();
+    match (*vt).open() {
+        Ok(cursor) => {
+            let boxed_cursor: *mut T::Cursor = Box::into_raw(Box::new(cursor));
+            *pp_cursor = boxed_cursor.cast::<ffi::sqlite3_vtab_cursor>();
+            ffi::SQLITE_OK
+        }
+        Err(Error::SqliteFailure(err, s)) => {
+            if let Some(err_msg) = s {
+                set_err_msg(vtab, &err_msg);
+            }
+            err.extended_code
+        }
+        Err(err) => {
+            set_err_msg(vtab, &err.to_string());
+            ffi::SQLITE_ERROR
+        }
+    }
+}
+
+unsafe extern "C" fn rust_close<C>(cursor: *mut ffi::sqlite3_vtab_cursor) -> c_int
+where
+    C: VTabCursor,
+{
+    let cr = cursor.cast::<C>();
+    drop(Box::from_raw(cr));
+    ffi::SQLITE_OK
+}
+
+unsafe extern "C" fn rust_filter<C>(
+    cursor: *mut ffi::sqlite3_vtab_cursor,
+    idx_num: c_int,
+    idx_str: *const c_char,
+    argc: c_int,
+    argv: *mut *mut ffi::sqlite3_value,
+) -> c_int
+where
+    C: VTabCursor,
+{
+    use std::ffi::CStr;
+    use std::str;
+    let idx_name = if idx_str.is_null() {
+        None
+    } else {
+        let c_slice = CStr::from_ptr(idx_str).to_bytes();
+        Some(str::from_utf8_unchecked(c_slice))
+    };
+    let args = slice::from_raw_parts_mut(argv, argc as usize);
+    let values = Values { args };
+    let cr = cursor as *mut C;
+    cursor_error(cursor, (*cr).filter(idx_num, idx_name, &values))
+}
+
+unsafe extern "C" fn rust_next<C>(cursor: *mut ffi::sqlite3_vtab_cursor) -> c_int
+where
+    C: VTabCursor,
+{
+    let cr = cursor as *mut C;
+    cursor_error(cursor, (*cr).next())
+}
+
+unsafe extern "C" fn rust_eof<C>(cursor: *mut ffi::sqlite3_vtab_cursor) -> c_int
+where
+    C: VTabCursor,
+{
+    let cr = cursor.cast::<C>();
+    c_int::from((*cr).eof())
+}
+
+unsafe extern "C" fn rust_column<C>(
+    cursor: *mut ffi::sqlite3_vtab_cursor,
+    ctx: *mut ffi::sqlite3_context,
+    i: c_int,
+) -> c_int
+where
+    C: VTabCursor,
+{
+    let cr = cursor.cast::<C>();
+    let mut ctxt = Context(ctx);
+    result_error(ctx, (*cr).column(&mut ctxt, i))
+}
+
+unsafe extern "C" fn rust_rowid<C>(
+    cursor: *mut ffi::sqlite3_vtab_cursor,
+    p_rowid: *mut ffi::sqlite3_int64,
+) -> c_int
+where
+    C: VTabCursor,
+{
+    let cr = cursor.cast::<C>();
+    match (*cr).rowid() {
+        Ok(rowid) => {
+            *p_rowid = rowid;
+            ffi::SQLITE_OK
+        }
+        err => cursor_error(cursor, err),
+    }
+}
+
+unsafe extern "C" fn rust_update<'vtab, T: 'vtab>(
+    vtab: *mut ffi::sqlite3_vtab,
+    argc: c_int,
+    argv: *mut *mut ffi::sqlite3_value,
+    p_rowid: *mut ffi::sqlite3_int64,
+) -> c_int
+where
+    T: UpdateVTab<'vtab>,
+{
+    assert!(argc >= 1);
+    let args = slice::from_raw_parts_mut(argv, argc as usize);
+    let vt = vtab.cast::<T>();
+    let r = if args.len() == 1 {
+        (*vt).delete(ValueRef::from_value(args[0]))
+    } else if ffi::sqlite3_value_type(args[0]) == ffi::SQLITE_NULL {
+        // TODO Make the distinction between argv[1] == NULL and argv[1] != NULL ?
+        let values = Values { args };
+        match (*vt).insert(&values) {
+            Ok(rowid) => {
+                *p_rowid = rowid;
+                Ok(())
+            }
+            Err(e) => Err(e),
+        }
+    } else {
+        let values = Values { args };
+        (*vt).update(&values)
+    };
+    match r {
+        Ok(_) => ffi::SQLITE_OK,
+        Err(Error::SqliteFailure(err, s)) => {
+            if let Some(err_msg) = s {
+                set_err_msg(vtab, &err_msg);
+            }
+            err.extended_code
+        }
+        Err(err) => {
+            set_err_msg(vtab, &err.to_string());
+            ffi::SQLITE_ERROR
+        }
+    }
+}
+
+/// Virtual table cursors can set an error message by assigning a string to
+/// `zErrMsg`.
+#[cold]
+unsafe fn cursor_error<T>(cursor: *mut ffi::sqlite3_vtab_cursor, result: Result<T>) -> c_int {
+    match result {
+        Ok(_) => ffi::SQLITE_OK,
+        Err(Error::SqliteFailure(err, s)) => {
+            if let Some(err_msg) = s {
+                set_err_msg((*cursor).pVtab, &err_msg);
+            }
+            err.extended_code
+        }
+        Err(err) => {
+            set_err_msg((*cursor).pVtab, &err.to_string());
+            ffi::SQLITE_ERROR
+        }
+    }
+}
+
+/// Virtual tables methods can set an error message by assigning a string to
+/// `zErrMsg`.
+#[cold]
+unsafe fn set_err_msg(vtab: *mut ffi::sqlite3_vtab, err_msg: &str) {
+    if !(*vtab).zErrMsg.is_null() {
+        ffi::sqlite3_free((*vtab).zErrMsg.cast::<c_void>());
+    }
+    (*vtab).zErrMsg = alloc(err_msg);
+}
+
+/// To raise an error, the `column` method should use this method to set the
+/// error message and return the error code.
+#[cold] +unsafe fn result_error(ctx: *mut ffi::sqlite3_context, result: Result) -> c_int { + match result { + Ok(_) => ffi::SQLITE_OK, + Err(Error::SqliteFailure(err, s)) => { + match err.extended_code { + ffi::SQLITE_TOOBIG => { + ffi::sqlite3_result_error_toobig(ctx); + } + ffi::SQLITE_NOMEM => { + ffi::sqlite3_result_error_nomem(ctx); + } + code => { + ffi::sqlite3_result_error_code(ctx, code); + if let Some(Ok(cstr)) = s.map(|s| str_to_cstring(&s)) { + ffi::sqlite3_result_error(ctx, cstr.as_ptr(), -1); + } + } + }; + err.extended_code + } + Err(err) => { + ffi::sqlite3_result_error_code(ctx, ffi::SQLITE_ERROR); + if let Ok(cstr) = str_to_cstring(&err.to_string()) { + ffi::sqlite3_result_error(ctx, cstr.as_ptr(), -1); + } + ffi::SQLITE_ERROR + } + } +} + +// Space to hold this string must be obtained +// from an SQLite memory allocation function +fn alloc(s: &str) -> *mut c_char { + crate::util::SqliteMallocString::from_str(s).into_raw() +} + +#[cfg(feature = "array")] +#[cfg_attr(docsrs, doc(cfg(feature = "array")))] +pub mod array; +#[cfg(feature = "csvtab")] +#[cfg_attr(docsrs, doc(cfg(feature = "csvtab")))] +pub mod csvtab; +#[cfg(feature = "series")] +#[cfg_attr(docsrs, doc(cfg(feature = "series")))] +pub mod series; // SQLite >= 3.9.0 +#[cfg(test)] +mod vtablog; + +#[cfg(test)] +mod test { + #[test] + fn test_dequote() { + assert_eq!("", super::dequote("")); + assert_eq!("'", super::dequote("'")); + assert_eq!("\"", super::dequote("\"")); + assert_eq!("'\"", super::dequote("'\"")); + assert_eq!("", super::dequote("''")); + assert_eq!("", super::dequote("\"\"")); + assert_eq!("x", super::dequote("'x'")); + assert_eq!("x", super::dequote("\"x\"")); + assert_eq!("x", super::dequote("x")); + } + #[test] + fn test_parse_boolean() { + assert_eq!(None, super::parse_boolean("")); + assert_eq!(Some(true), super::parse_boolean("1")); + assert_eq!(Some(true), super::parse_boolean("yes")); + assert_eq!(Some(true), super::parse_boolean("on")); + assert_eq!(Some(true), 
super::parse_boolean("true")); + assert_eq!(Some(false), super::parse_boolean("0")); + assert_eq!(Some(false), super::parse_boolean("no")); + assert_eq!(Some(false), super::parse_boolean("off")); + assert_eq!(Some(false), super::parse_boolean("false")); + } +} diff --git a/src/vtab/series.rs b/src/vtab/series.rs new file mode 100644 index 0000000..fffbd4d --- /dev/null +++ b/src/vtab/series.rs @@ -0,0 +1,319 @@ +//! Generate series virtual table. +//! +//! Port of C [generate series +//! "function"](http://www.sqlite.org/cgi/src/finfo?name=ext/misc/series.c): +//! `https://www.sqlite.org/series.html` +use std::default::Default; +use std::marker::PhantomData; +use std::os::raw::c_int; + +use crate::ffi; +use crate::types::Type; +use crate::vtab::{ + eponymous_only_module, Context, IndexConstraintOp, IndexInfo, VTab, VTabConfig, VTabConnection, + VTabCursor, Values, +}; +use crate::{Connection, Error, Result}; + +/// Register the "generate_series" module. +pub fn load_module(conn: &Connection) -> Result<()> { + let aux: Option<()> = None; + conn.create_module("generate_series", eponymous_only_module::(), aux) +} + +// Column numbers +// const SERIES_COLUMN_VALUE : c_int = 0; +const SERIES_COLUMN_START: c_int = 1; +const SERIES_COLUMN_STOP: c_int = 2; +const SERIES_COLUMN_STEP: c_int = 3; + +bitflags::bitflags! { + #[repr(C)] + struct QueryPlanFlags: ::std::os::raw::c_int { + // start = $value -- constraint exists + const START = 1; + // stop = $value -- constraint exists + const STOP = 2; + // step = $value -- constraint exists + const STEP = 4; + // output in descending order + const DESC = 8; + // output in ascending order + const ASC = 16; + // Both start and stop + const BOTH = QueryPlanFlags::START.bits | QueryPlanFlags::STOP.bits; + } +} + +/// An instance of the Series virtual table +#[repr(C)] +struct SeriesTab { + /// Base class. 
Must be first + base: ffi::sqlite3_vtab, +} + +unsafe impl<'vtab> VTab<'vtab> for SeriesTab { + type Aux = (); + type Cursor = SeriesTabCursor<'vtab>; + + fn connect( + db: &mut VTabConnection, + _aux: Option<&()>, + _args: &[&[u8]], + ) -> Result<(String, SeriesTab)> { + let vtab = SeriesTab { + base: ffi::sqlite3_vtab::default(), + }; + db.config(VTabConfig::Innocuous)?; + Ok(( + "CREATE TABLE x(value,start hidden,stop hidden,step hidden)".to_owned(), + vtab, + )) + } + + fn best_index(&self, info: &mut IndexInfo) -> Result<()> { + // The query plan bitmask + let mut idx_num: QueryPlanFlags = QueryPlanFlags::empty(); + // Mask of unusable constraints + let mut unusable_mask: QueryPlanFlags = QueryPlanFlags::empty(); + // Constraints on start, stop, and step + let mut a_idx: [Option; 3] = [None, None, None]; + for (i, constraint) in info.constraints().enumerate() { + if constraint.column() < SERIES_COLUMN_START { + continue; + } + let (i_col, i_mask) = match constraint.column() { + SERIES_COLUMN_START => (0, QueryPlanFlags::START), + SERIES_COLUMN_STOP => (1, QueryPlanFlags::STOP), + SERIES_COLUMN_STEP => (2, QueryPlanFlags::STEP), + _ => { + unreachable!() + } + }; + if !constraint.is_usable() { + unusable_mask |= i_mask; + } else if constraint.operator() == IndexConstraintOp::SQLITE_INDEX_CONSTRAINT_EQ { + idx_num |= i_mask; + a_idx[i_col] = Some(i); + } + } + // Number of arguments that SeriesTabCursor::filter expects + let mut n_arg = 0; + for j in a_idx.iter().flatten() { + n_arg += 1; + let mut constraint_usage = info.constraint_usage(*j); + constraint_usage.set_argv_index(n_arg); + constraint_usage.set_omit(true); + #[cfg(all(test, feature = "modern_sqlite"))] + debug_assert_eq!(Ok("BINARY"), info.collation(*j)); + } + if !(unusable_mask & !idx_num).is_empty() { + return Err(Error::SqliteFailure( + ffi::Error::new(ffi::SQLITE_CONSTRAINT), + None, + )); + } + if idx_num.contains(QueryPlanFlags::BOTH) { + // Both start= and stop= boundaries are available. 
+ info.set_estimated_cost(f64::from( + 2 - if idx_num.contains(QueryPlanFlags::STEP) { + 1 + } else { + 0 + }, + )); + info.set_estimated_rows(1000); + let order_by_consumed = { + let mut order_bys = info.order_bys(); + if let Some(order_by) = order_bys.next() { + if order_by.column() == 0 { + if order_by.is_order_by_desc() { + idx_num |= QueryPlanFlags::DESC; + } else { + idx_num |= QueryPlanFlags::ASC; + } + true + } else { + false + } + } else { + false + } + }; + if order_by_consumed { + info.set_order_by_consumed(true); + } + } else { + // If either boundary is missing, we have to generate a huge span + // of numbers. Make this case very expensive so that the query + // planner will work hard to avoid it. + info.set_estimated_rows(2_147_483_647); + } + info.set_idx_num(idx_num.bits()); + Ok(()) + } + + fn open(&mut self) -> Result> { + Ok(SeriesTabCursor::new()) + } +} + +/// A cursor for the Series virtual table +#[repr(C)] +struct SeriesTabCursor<'vtab> { + /// Base class. Must be first + base: ffi::sqlite3_vtab_cursor, + /// True to count down rather than up + is_desc: bool, + /// The rowid + row_id: i64, + /// Current value ("value") + value: i64, + /// Minimum value ("start") + min_value: i64, + /// Maximum value ("stop") + max_value: i64, + /// Increment ("step") + step: i64, + phantom: PhantomData<&'vtab SeriesTab>, +} + +impl SeriesTabCursor<'_> { + fn new<'vtab>() -> SeriesTabCursor<'vtab> { + SeriesTabCursor { + base: ffi::sqlite3_vtab_cursor::default(), + is_desc: false, + row_id: 0, + value: 0, + min_value: 0, + max_value: 0, + step: 0, + phantom: PhantomData, + } + } +} +#[allow(clippy::comparison_chain)] +unsafe impl VTabCursor for SeriesTabCursor<'_> { + fn filter(&mut self, idx_num: c_int, _idx_str: Option<&str>, args: &Values<'_>) -> Result<()> { + let mut idx_num = QueryPlanFlags::from_bits_truncate(idx_num); + let mut i = 0; + if idx_num.contains(QueryPlanFlags::START) { + self.min_value = args.get(i)?; + i += 1; + } else { + self.min_value 
= 0; + } + if idx_num.contains(QueryPlanFlags::STOP) { + self.max_value = args.get(i)?; + i += 1; + } else { + self.max_value = 0xffff_ffff; + } + if idx_num.contains(QueryPlanFlags::STEP) { + self.step = args.get(i)?; + if self.step == 0 { + self.step = 1; + } else if self.step < 0 { + self.step = -self.step; + if !idx_num.contains(QueryPlanFlags::ASC) { + idx_num |= QueryPlanFlags::DESC; + } + } + } else { + self.step = 1; + }; + for arg in args.iter() { + if arg.data_type() == Type::Null { + // If any of the constraints have a NULL value, then return no rows. + self.min_value = 1; + self.max_value = 0; + break; + } + } + self.is_desc = idx_num.contains(QueryPlanFlags::DESC); + if self.is_desc { + self.value = self.max_value; + if self.step > 0 { + self.value -= (self.max_value - self.min_value) % self.step; + } + } else { + self.value = self.min_value; + } + self.row_id = 1; + Ok(()) + } + + fn next(&mut self) -> Result<()> { + if self.is_desc { + self.value -= self.step; + } else { + self.value += self.step; + } + self.row_id += 1; + Ok(()) + } + + fn eof(&self) -> bool { + if self.is_desc { + self.value < self.min_value + } else { + self.value > self.max_value + } + } + + fn column(&self, ctx: &mut Context, i: c_int) -> Result<()> { + let x = match i { + SERIES_COLUMN_START => self.min_value, + SERIES_COLUMN_STOP => self.max_value, + SERIES_COLUMN_STEP => self.step, + _ => self.value, + }; + ctx.set_result(&x) + } + + fn rowid(&self) -> Result { + Ok(self.row_id) + } +} + +#[cfg(test)] +mod test { + use crate::ffi; + use crate::vtab::series; + use crate::{Connection, Result}; + use fallible_iterator::FallibleIterator; + + #[test] + fn test_series_module() -> Result<()> { + let version = unsafe { ffi::sqlite3_libversion_number() }; + if version < 3_008_012 { + return Ok(()); + } + + let db = Connection::open_in_memory()?; + series::load_module(&db)?; + + let mut s = db.prepare("SELECT * FROM generate_series(0,20,5)")?; + + let series = s.query_map([], |row| 
row.get::<_, i32>(0))?; + + let mut expected = 0; + for value in series { + assert_eq!(expected, value?); + expected += 5; + } + + let mut s = + db.prepare("SELECT * FROM generate_series WHERE start=1 AND stop=9 AND step=2")?; + let series: Vec = s.query([])?.map(|r| r.get(0)).collect()?; + assert_eq!(vec![1, 3, 5, 7, 9], series); + let mut s = db.prepare("SELECT * FROM generate_series LIMIT 5")?; + let series: Vec = s.query([])?.map(|r| r.get(0)).collect()?; + assert_eq!(vec![0, 1, 2, 3, 4], series); + let mut s = db.prepare("SELECT * FROM generate_series(0,32,5) ORDER BY value DESC")?; + let series: Vec = s.query([])?.map(|r| r.get(0)).collect()?; + assert_eq!(vec![30, 25, 20, 15, 10, 5, 0], series); + + Ok(()) + } +} diff --git a/src/vtab/vtablog.rs b/src/vtab/vtablog.rs new file mode 100644 index 0000000..bc2e01f --- /dev/null +++ b/src/vtab/vtablog.rs @@ -0,0 +1,300 @@ +///! Port of C [vtablog](http://www.sqlite.org/cgi/src/finfo?name=ext/misc/vtablog.c) +use std::default::Default; +use std::marker::PhantomData; +use std::os::raw::c_int; +use std::str::FromStr; +use std::sync::atomic::{AtomicUsize, Ordering}; + +use crate::vtab::{ + update_module, Context, CreateVTab, IndexInfo, UpdateVTab, VTab, VTabConnection, VTabCursor, + VTabKind, Values, +}; +use crate::{ffi, ValueRef}; +use crate::{Connection, Error, Result}; + +/// Register the "vtablog" module. +pub fn load_module(conn: &Connection) -> Result<()> { + let aux: Option<()> = None; + conn.create_module("vtablog", update_module::(), aux) +} + +/// An instance of the vtablog virtual table +#[repr(C)] +struct VTabLog { + /// Base class. 
Must be first + base: ffi::sqlite3_vtab, + /// Number of rows in the table + n_row: i64, + /// Instance number for this vtablog table + i_inst: usize, + /// Number of cursors created + n_cursor: usize, +} + +impl VTabLog { + fn connect_create( + _: &mut VTabConnection, + _: Option<&()>, + args: &[&[u8]], + is_create: bool, + ) -> Result<(String, VTabLog)> { + static N_INST: AtomicUsize = AtomicUsize::new(1); + let i_inst = N_INST.fetch_add(1, Ordering::SeqCst); + println!( + "VTabLog::{}(tab={}, args={:?}):", + if is_create { "create" } else { "connect" }, + i_inst, + args, + ); + let mut schema = None; + let mut n_row = None; + + let args = &args[3..]; + for c_slice in args { + let (param, value) = super::parameter(c_slice)?; + match param { + "schema" => { + if schema.is_some() { + return Err(Error::ModuleError(format!( + "more than one '{}' parameter", + param + ))); + } + schema = Some(value.to_owned()) + } + "rows" => { + if n_row.is_some() { + return Err(Error::ModuleError(format!( + "more than one '{}' parameter", + param + ))); + } + if let Ok(n) = i64::from_str(value) { + n_row = Some(n) + } + } + _ => { + return Err(Error::ModuleError(format!( + "unrecognized parameter '{}'", + param + ))); + } + } + } + if schema.is_none() { + return Err(Error::ModuleError("no schema defined".to_owned())); + } + let vtab = VTabLog { + base: ffi::sqlite3_vtab::default(), + n_row: n_row.unwrap_or(10), + i_inst, + n_cursor: 0, + }; + Ok((schema.unwrap(), vtab)) + } +} + +impl Drop for VTabLog { + fn drop(&mut self) { + println!("VTabLog::drop({})", self.i_inst); + } +} + +unsafe impl<'vtab> VTab<'vtab> for VTabLog { + type Aux = (); + type Cursor = VTabLogCursor<'vtab>; + + fn connect( + db: &mut VTabConnection, + aux: Option<&Self::Aux>, + args: &[&[u8]], + ) -> Result<(String, Self)> { + VTabLog::connect_create(db, aux, args, false) + } + + fn best_index(&self, info: &mut IndexInfo) -> Result<()> { + println!("VTabLog::best_index({})", self.i_inst); + 
info.set_estimated_cost(500.); + info.set_estimated_rows(500); + Ok(()) + } + + fn open(&'vtab mut self) -> Result { + self.n_cursor += 1; + println!( + "VTabLog::open(tab={}, cursor={})", + self.i_inst, self.n_cursor + ); + Ok(VTabLogCursor { + base: ffi::sqlite3_vtab_cursor::default(), + i_cursor: self.n_cursor, + row_id: 0, + phantom: PhantomData, + }) + } +} + +impl<'vtab> CreateVTab<'vtab> for VTabLog { + const KIND: VTabKind = VTabKind::Default; + + fn create( + db: &mut VTabConnection, + aux: Option<&Self::Aux>, + args: &[&[u8]], + ) -> Result<(String, Self)> { + VTabLog::connect_create(db, aux, args, true) + } + + fn destroy(&self) -> Result<()> { + println!("VTabLog::destroy({})", self.i_inst); + Ok(()) + } +} + +impl<'vtab> UpdateVTab<'vtab> for VTabLog { + fn delete(&mut self, arg: ValueRef<'_>) -> Result<()> { + println!("VTabLog::delete({}, {:?})", self.i_inst, arg); + Ok(()) + } + + fn insert(&mut self, args: &Values<'_>) -> Result { + println!( + "VTabLog::insert({}, {:?})", + self.i_inst, + args.iter().collect::>>() + ); + Ok(self.n_row as i64) + } + + fn update(&mut self, args: &Values<'_>) -> Result<()> { + println!( + "VTabLog::update({}, {:?})", + self.i_inst, + args.iter().collect::>>() + ); + Ok(()) + } +} + +/// A cursor for the Series virtual table +#[repr(C)] +struct VTabLogCursor<'vtab> { + /// Base class. 
Must be first + base: ffi::sqlite3_vtab_cursor, + /// Cursor number + i_cursor: usize, + /// The rowid + row_id: i64, + phantom: PhantomData<&'vtab VTabLog>, +} + +impl VTabLogCursor<'_> { + fn vtab(&self) -> &VTabLog { + unsafe { &*(self.base.pVtab as *const VTabLog) } + } +} + +impl Drop for VTabLogCursor<'_> { + fn drop(&mut self) { + println!( + "VTabLogCursor::drop(tab={}, cursor={})", + self.vtab().i_inst, + self.i_cursor + ); + } +} + +unsafe impl VTabCursor for VTabLogCursor<'_> { + fn filter(&mut self, _: c_int, _: Option<&str>, _: &Values<'_>) -> Result<()> { + println!( + "VTabLogCursor::filter(tab={}, cursor={})", + self.vtab().i_inst, + self.i_cursor + ); + self.row_id = 0; + Ok(()) + } + + fn next(&mut self) -> Result<()> { + println!( + "VTabLogCursor::next(tab={}, cursor={}): rowid {} -> {}", + self.vtab().i_inst, + self.i_cursor, + self.row_id, + self.row_id + 1 + ); + self.row_id += 1; + Ok(()) + } + + fn eof(&self) -> bool { + let eof = self.row_id >= self.vtab().n_row; + println!( + "VTabLogCursor::eof(tab={}, cursor={}): {}", + self.vtab().i_inst, + self.i_cursor, + eof, + ); + eof + } + + fn column(&self, ctx: &mut Context, i: c_int) -> Result<()> { + let value = if i < 26 { + format!( + "{}{}", + "abcdefghijklmnopqrstuvwyz".chars().nth(i as usize).unwrap(), + self.row_id + ) + } else { + format!("{}{}", i, self.row_id) + }; + println!( + "VTabLogCursor::column(tab={}, cursor={}, i={}): {}", + self.vtab().i_inst, + self.i_cursor, + i, + value, + ); + ctx.set_result(&value) + } + + fn rowid(&self) -> Result { + println!( + "VTabLogCursor::rowid(tab={}, cursor={}): {}", + self.vtab().i_inst, + self.i_cursor, + self.row_id, + ); + Ok(self.row_id) + } +} + +#[cfg(test)] +mod test { + use crate::{Connection, Result}; + #[test] + fn test_module() -> Result<()> { + let db = Connection::open_in_memory()?; + super::load_module(&db)?; + + db.execute_batch( + "CREATE VIRTUAL TABLE temp.log USING vtablog( + schema='CREATE TABLE x(a,b,c)', + rows=25 + );", 
+ )?; + let mut stmt = db.prepare("SELECT * FROM log;")?; + let mut rows = stmt.query([])?; + while rows.next()?.is_some() {} + db.execute("DELETE FROM log WHERE a = ?", ["a1"])?; + db.execute( + "INSERT INTO log (a, b, c) VALUES (?, ?, ?)", + ["a", "b", "c"], + )?; + db.execute( + "UPDATE log SET b = ?, c = ? WHERE a = ?", + ["bn", "cn", "a1"], + )?; + Ok(()) + } +} diff --git a/test.csv b/test.csv new file mode 100644 index 0000000..708f93f --- /dev/null +++ b/test.csv @@ -0,0 +1,6 @@ +"colA","colB","colC" +1,2,3 +a,b,c +a,"b",c +"a","b","c .. z" +"a","b","c,d" diff --git a/tests/config_log.rs b/tests/config_log.rs new file mode 100644 index 0000000..0c28bdf --- /dev/null +++ b/tests/config_log.rs @@ -0,0 +1,34 @@ +//! This file contains unit tests for `rusqlite::trace::config_log`. This +//! function affects SQLite process-wide and so is not safe to run as a normal +//! #[test] in the library. + +#[cfg(feature = "trace")] +fn main() { + use lazy_static::lazy_static; + use std::os::raw::c_int; + use std::sync::Mutex; + + lazy_static! { + static ref LOGS_RECEIVED: Mutex> = Mutex::new(Vec::new()); + } + + fn log_handler(err: c_int, message: &str) { + let mut logs_received = LOGS_RECEIVED.lock().unwrap(); + logs_received.push((err, message.to_owned())); + } + + use rusqlite::trace; + + unsafe { trace::config_log(Some(log_handler)) }.unwrap(); + trace::log(10, "First message from rusqlite"); + unsafe { trace::config_log(None) }.unwrap(); + trace::log(11, "Second message from rusqlite"); + + let logs_received = LOGS_RECEIVED.lock().unwrap(); + assert_eq!(logs_received.len(), 1); + assert_eq!(logs_received[0].0, 10); + assert_eq!(logs_received[0].1, "First message from rusqlite"); +} + +#[cfg(not(feature = "trace"))] +fn main() {} diff --git a/tests/deny_single_threaded_sqlite_config.rs b/tests/deny_single_threaded_sqlite_config.rs new file mode 100644 index 0000000..adfc8e5 --- /dev/null +++ b/tests/deny_single_threaded_sqlite_config.rs @@ -0,0 +1,20 @@ +//! 
Ensure we reject connections when SQLite is in single-threaded mode, as it +//! would violate safety if multiple Rust threads tried to use connections. + +use rusqlite::ffi; +use rusqlite::Connection; + +#[test] +fn test_error_when_singlethread_mode() { + // put SQLite into single-threaded mode + unsafe { + // Note: macOS system SQLite seems to return an error if you attempt to + // reconfigure to single-threaded mode. + if ffi::sqlite3_config(ffi::SQLITE_CONFIG_SINGLETHREAD) != ffi::SQLITE_OK { + return; + } + assert_eq!(ffi::sqlite3_initialize(), ffi::SQLITE_OK); + } + let res = Connection::open_in_memory(); + assert!(res.is_err()); +} diff --git a/tests/vtab.rs b/tests/vtab.rs new file mode 100644 index 0000000..6558206 --- /dev/null +++ b/tests/vtab.rs @@ -0,0 +1,100 @@ +//! Ensure Virtual tables can be declared outside `rusqlite` crate. + +#[cfg(feature = "vtab")] +#[test] +fn test_dummy_module() -> rusqlite::Result<()> { + use rusqlite::vtab::{ + eponymous_only_module, sqlite3_vtab, sqlite3_vtab_cursor, Context, IndexInfo, VTab, + VTabConnection, VTabCursor, Values, + }; + use rusqlite::{version_number, Connection, Result}; + use std::marker::PhantomData; + use std::os::raw::c_int; + + let module = eponymous_only_module::(); + + #[repr(C)] + struct DummyTab { + /// Base class. Must be first + base: sqlite3_vtab, + } + + unsafe impl<'vtab> VTab<'vtab> for DummyTab { + type Aux = (); + type Cursor = DummyTabCursor<'vtab>; + + fn connect( + _: &mut VTabConnection, + _aux: Option<&()>, + _args: &[&[u8]], + ) -> Result<(String, DummyTab)> { + let vtab = DummyTab { + base: sqlite3_vtab::default(), + }; + Ok(("CREATE TABLE x(value)".to_owned(), vtab)) + } + + fn best_index(&self, info: &mut IndexInfo) -> Result<()> { + info.set_estimated_cost(1.); + Ok(()) + } + + fn open(&'vtab mut self) -> Result> { + Ok(DummyTabCursor::default()) + } + } + + #[derive(Default)] + #[repr(C)] + struct DummyTabCursor<'vtab> { + /// Base class. 
Must be first + base: sqlite3_vtab_cursor, + /// The rowid + row_id: i64, + phantom: PhantomData<&'vtab DummyTab>, + } + + unsafe impl VTabCursor for DummyTabCursor<'_> { + fn filter( + &mut self, + _idx_num: c_int, + _idx_str: Option<&str>, + _args: &Values<'_>, + ) -> Result<()> { + self.row_id = 1; + Ok(()) + } + + fn next(&mut self) -> Result<()> { + self.row_id += 1; + Ok(()) + } + + fn eof(&self) -> bool { + self.row_id > 1 + } + + fn column(&self, ctx: &mut Context, _: c_int) -> Result<()> { + ctx.set_result(&self.row_id) + } + + fn rowid(&self) -> Result { + Ok(self.row_id) + } + } + + let db = Connection::open_in_memory()?; + + db.create_module::("dummy", module, None)?; + + let version = version_number(); + if version < 3_009_000 { + return Ok(()); + } + + let mut s = db.prepare("SELECT * FROM dummy()")?; + + let dummy = s.query_row([], |row| row.get::<_, i32>(0))?; + assert_eq!(1, dummy); + Ok(()) +} -- 2.7.4