// Copyright (c) 2023-2025 ParadeDB, Inc.
//
// This file is part of ParadeDB - Postgres for Search and Analytics
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// This program is distributed in the hope that it will be useful
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Affero General Public License for more details.
//
// You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see <http://www.gnu.org/licenses/>.
use anyhow::Result;
use async_std::prelude::Stream;
use async_std::stream::StreamExt;
use async_std::task::block_on;
use bytes::Bytes;
use rand::Rng;
use sqlx::{
postgres::PgRow,
testing::{TestArgs, TestContext, TestSupport},
ConnectOptions, Connection, Decode, Error, Executor, FromRow, PgConnection, Postgres, Type,
};
use std::time::{Duration, SystemTime, UNIX_EPOCH};
pub struct Db {
context: TestContext,
}
impl Db {
    /// Creates a new test database and returns a handle to it.
    ///
    /// The test "path" handed to sqlx is the current time in microseconds
    /// since the Unix epoch, followed by the current thread's name (typically
    /// the test name) or, when the thread is unnamed, seven random
    /// alphanumeric characters.
    ///
    /// # Panics
    ///
    /// Panics if the system clock reads earlier than the Unix epoch or if
    /// the test database cannot be created.
    pub async fn new() -> Self {
        // Timestamp component.
        let micros = SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .expect("current time should be retrievable")
            .as_micros();

        // Thread-name component, with a random 7-letter "word" as fallback.
        let suffix = match std::thread::current().name() {
            Some(thread_name) => thread_name.to_owned(),
            None => rand::rng()
                .sample_iter(&rand::distr::Alphanumeric)
                .take(7)
                .map(char::from)
                .collect::<String>(),
        };

        let path = format!("{micros}{suffix}");

        // The path is intentionally leaked to obtain a string slice that
        // outlives this function (presumably `TestArgs::new` needs `'static`).
        let args = TestArgs::new(Box::leak(path.into_boxed_str()));

        let context = match Postgres::test_context(&args).await {
            Ok(context) => context,
            Err(err) => panic!("could not create test database: {err:#?}"),
        };

        Self { context }
    }

    /// Opens a new connection to the test database.
    ///
    /// # Panics
    ///
    /// Panics if the connection attempt fails.
    pub async fn connection(&self) -> PgConnection {
        match self.context.connect_opts.connect().await {
            Ok(conn) => conn,
            Err(err) => panic!("failed to connect to test database: {err:#?}"),
        }
    }
}
impl Drop for Db {
    /// Schedules removal of the test database.
    ///
    /// The cleanup runs as a spawned async-std task whose handle is discarded,
    /// so it is not awaited here and may still be pending when `drop` returns.
    fn drop(&mut self) {
        let db_name = self.context.db_name.to_string();

        let cleanup = async move {
            // Best effort: ignore errors as there's nothing we can do about them.
            let _ = Postgres::cleanup_test(&db_name).await;
        };

        async_std::task::spawn(cleanup);
    }
}
pub trait ConnExt {
fn deallocate_all(&mut self) -> Result<(), sqlx::Error>;
}
impl ConnExt for PgConnection {
/// Deallocate all cached prepared statements. Akin to Postgres' `DEALLOCATE ALL` command
/// but also does the right thing for the sql [`PgConnection`] internals.
fn deallocate_all(&mut self) -> Result<(), Error> {
async_std::task::block_on(async { self.clear_cached_statements().await })
}
}
#[allow(dead_code)]
pub trait Query
where
Self: AsRef + Sized,
{
fn execute(self, connection: &mut PgConnection) {
block_on(async { self.execute_async(connection).await })
}
#[allow(async_fn_in_trait)]
async fn execute_async(self, connection: &mut PgConnection) {
connection
.execute(self.as_ref())
.await
.expect("query execution should succeed");
}
fn execute_result(self, connection: &mut PgConnection) -> Result<(), sqlx::Error> {
block_on(async { connection.execute(self.as_ref()).await })?;
Ok(())
}
fn fetch(self, connection: &mut PgConnection) -> Vec
where
T: for<'r> FromRow<'r, ::Row> + Send + Unpin,
{
block_on(async {
sqlx::query_as::<_, T>(self.as_ref())
.fetch_all(connection)
.await
.unwrap_or_else(|e| panic!("{e}: error in query '{}'", self.as_ref()))
})
}
fn fetch_retry(
self,
connection: &mut PgConnection,
retries: u32,
delay_ms: u64,
validate: fn(&[T]) -> bool,
) -> Vec
where
T: for<'r> FromRow<'r, ::Row> + Send + Unpin,
{
for attempt in 0..retries {
match block_on(async {
sqlx::query_as::<_, T>(self.as_ref())
.fetch_all(&mut *connection)
.await
.map_err(anyhow::Error::from)
}) {
Ok(result) => {
if validate(&result) {
return result;
} else if attempt < retries - 1 {
block_on(async_std::task::sleep(Duration::from_millis(delay_ms)));
} else {
return vec![];
}
}
Err(_) if attempt < retries - 1 => {
block_on(async_std::task::sleep(Duration::from_millis(delay_ms)));
}
Err(e) => panic!("Fetch attempt {}/{} failed: {}", attempt + 1, retries, e),
}
}
panic!("Exhausted retries for query '{}'", self.as_ref());
}
fn fetch_dynamic(self, connection: &mut PgConnection) -> Vec {
block_on(async {
sqlx::query(self.as_ref())
.fetch_all(connection)
.await
.unwrap_or_else(|e| panic!("{e}: error in query '{}'", self.as_ref()))
})
}
fn fetch_scalar(self, connection: &mut PgConnection) -> Vec
where
T: Type + for<'a> Decode<'a, sqlx::Postgres> + Send + Unpin,
{
block_on(async {
sqlx::query_scalar(self.as_ref())
.fetch_all(connection)
.await
.unwrap_or_else(|e| panic!("{e}: error in query '{}'", self.as_ref()))
})
}
fn fetch_one(self, connection: &mut PgConnection) -> T
where
T: for<'r> FromRow<'r, ::Row> + Send + Unpin,
{
block_on(async {
sqlx::query_as::<_, T>(self.as_ref())
.fetch_one(connection)
.await
.unwrap_or_else(|e| panic!("{e}: error in query '{}'", self.as_ref()))
})
}
fn fetch_result(self, connection: &mut PgConnection) -> Result, sqlx::Error>
where
T: for<'r> FromRow<'r, ::Row> + Send + Unpin,
{
block_on(async {
sqlx::query_as::<_, T>(self.as_ref())
.fetch_all(connection)
.await
})
}
fn fetch_collect(self, connection: &mut PgConnection) -> B
where
T: for<'r> FromRow<'r, ::Row> + Send + Unpin,
B: FromIterator,
{
self.fetch(connection).into_iter().collect::()
}
}
// `Query` only supplies default methods, so these impls are empty: they just
// opt the common string types into the helpers.

/// Borrowed string slices, e.g. SQL literals.
impl Query for &str {}

/// Borrowed owned strings.
impl Query for &String {}

/// Owned strings, e.g. the output of `format!`.
impl Query for String {}
pub trait DisplayAsync: Stream- > + Sized {
fn to_csv(self) -> String {
let mut csv_str = String::new();
let mut stream = Box::pin(self);
while let Some(chunk) = block_on(stream.as_mut().next()) {
let chunk = chunk.expect("chunk should be valid for DisplayAsync");
csv_str.push_str(&String::from_utf8_lossy(&chunk));
}
csv_str
}
}
impl DisplayAsync for T where T: Stream
- > + Send + Sized {}