// Copyright (c) 2023-2025 ParadeDB, Inc.
//
// This file is part of ParadeDB - Postgres for Search and Analytics
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Affero General Public License for more details.
//
// You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see <http://www.gnu.org/licenses/>.
use crate::sqlscanner::{ScannedStatement, SqlStatementScanner, StatementDestination};
use pgrx_pg_config::{PgConfig, Pgrx};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::fmt;
use std::path::PathBuf;
use std::str::FromStr;
use std::sync::atomic::{AtomicU16, Ordering};
use std::sync::Arc;
// See https://users.rust-lang.org/t/concatenate-two-static-str/33993/4
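/// Name of the physical replication slot shared between a primary and a `WalReceiver`
/// server. Kept as a macro (rather than a `const`) so it can be spliced into other
/// string literals with `concat!` below.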
#[macro_export]
macro_rules! physical_replication_slot_name {
() => {
"physical_wal_receiver_1"
};
}
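/// How to locate the `pg_config` binary for a server: via the local `pgrx`
/// installation for a given Postgres version, via the `pg_config` found on the
/// `PATH`, or via an explicit path.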
#[derive(Serialize, Debug, Clone, Deserialize)]
pub enum PgConfigStyle {
Pgrx(PgVersion),
Env,
Path(PathBuf),
}
impl Default for PgConfigStyle {
fn default() -> Self {
PgConfigStyle::Pgrx(PgVersion::default())
}
}
impl PgConfigStyle {
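/// Build a [`PgConfig`] for this style, listening on `port` if provided or on an
/// automatically assigned port otherwise.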
pub fn pg_config(&self, port: Option<u16>) -> PgConfig {
match self {
PgConfigStyle::Pgrx(version) => {
let pgrx = Pgrx::from_config().expect("is pgrx configured?");
let base_pg_config = pgrx
.get(&version.to_string())
.expect("is pgrx configured with Postgres v17?");
PgConfig::new(
base_pg_config.path().unwrap(),
port.unwrap_or_else(default_port),
0,
)
}
PgConfigStyle::Env => {
let base_pg_config = PgConfig::from_path();
PgConfig::new(
base_pg_config.path().unwrap(),
port.unwrap_or_else(default_port),
0,
)
}
PgConfigStyle::Path(path) => {
PgConfig::new(path.clone(), port.unwrap_or_else(default_port), 0)
}
}
}
}
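/// The flavor of `postgresql.conf` settings applied to a managed server: a plain
/// standalone instance, a replication publisher or subscriber, a hot-standby WAL
/// receiver, or fully custom configuration lines.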
#[derive(Serialize, Default, Debug, Clone, Deserialize)]
pub enum PostgresqlConf {
#[default]
Normal,
Publisher,
Subscriber,
WalReceiver,
Custom(String),
}
impl PostgresqlConf {
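/// Yield the `postgresql.conf` lines for this configuration: the variant-specific
/// settings followed by the settings common to every server (preloading `pg_search`,
/// verbose logging, and a larger `max_wal_size`).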
pub fn lines(&self) -> impl Iterator<Item = &str> + '_ {
match self {
PostgresqlConf::Normal => vec![],
PostgresqlConf::Publisher => {
vec!["wal_level=logical"]
}
PostgresqlConf::Subscriber => {
vec!["wal_level=replica", "max_wal_senders=4"]
}
PostgresqlConf::WalReceiver => {
vec![
"hot_standby=on",
"hot_standby_feedback=on",
concat!("primary_slot_name=", physical_replication_slot_name!()),
]
}
PostgresqlConf::Custom(s) => s.lines().collect::<Vec<_>>(),
}
.into_iter()
.chain(vec![
"shared_preload_libraries=pg_search",
"log_line_prefix=%m [%p] [%x] [%a] ",
"log_error_verbosity=verbose",
"max_wal_size=8GB",
])
}
}
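/// Supported PostgreSQL major versions. Formats as, and parses from, the `pgN`
/// names used by `pgrx` (e.g. `pg17`).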
#[derive(Default, Debug, Clone, Serialize, Deserialize)]
pub enum PgVersion {
V15,
V16,
#[default]
V17,
V18,
}
impl fmt::Display for PgVersion {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
PgVersion::V15 => write!(f, "pg15"),
PgVersion::V16 => write!(f, "pg16"),
PgVersion::V17 => write!(f, "pg17"),
PgVersion::V18 => write!(f, "pg18"),
}
}
}
impl FromStr for PgVersion {
type Err = String;
fn from_str(s: &str) -> Result<Self, Self::Err> {
match s.to_lowercase().as_str() {
"pg15" | "15" => Ok(PgVersion::V15),
"pg16" | "16" => Ok(PgVersion::V16),
"pg17" | "17" => Ok(PgVersion::V17),
"pg18" | "18" => Ok(PgVersion::V18),
_ => Err(format!(
"Invalid PostgreSQL version: {}. Expected one of 'pg15', 'pg16', 'pg17', or 'pg18'",
s
)),
}
}
}
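/// How a server is provisioned for a suite: reuse the `pgrx`-managed installation
/// for a version, use the `pg_config` found on the `PATH`, have an instance created
/// and managed automatically (`Automatic`), or connect to an already-running server
/// via a connection string (`With`).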
#[derive(Serialize, Deserialize, Debug, Clone)]
pub enum ServerStyle {
Pgrx(PgVersion),
FromPath,
Automatic {
#[serde(default)]
pg_config: PgConfigStyle,
#[serde(default = "default_port")]
port: u16,
log_path: Option<PathBuf>,
pgdata: Option<PathBuf>,
#[serde(default)]
postgresql_conf: PostgresqlConf,
},
With {
connection_string: String,
},
}
impl Default for ServerStyle {
fn default() -> Self {
ServerStyle::Pgrx(PgVersion::default())
}
}
impl ServerStyle {
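/// The port the server listens on, resolved from the underlying `pg_config` or, for
/// `With`, parsed out of the connection string.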
pub fn port(&self) -> u16 {
match self {
ServerStyle::Pgrx(version) => PgConfigStyle::Pgrx(version.clone())
.pg_config(None)
.port()
.expect("`pgrx` should be installed"),
ServerStyle::FromPath => PgConfigStyle::Env
.pg_config(None)
.port()
.expect("`pg_config` not found"),
ServerStyle::Automatic {
port, pg_config, ..
} => pg_config
.pg_config(Some(*port))
.port()
.expect("`pg_config` not found"),
ServerStyle::With { connection_string } => {
let url = url::Url::parse(connection_string).expect("invalid connection string");
url.port_or_known_default()
.expect("no port found in connection string")
}
}
}
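/// The connection string used to reach this server. Managed styles connect to
/// `localhost` on the resolved port using the `stressgres` database; `With` returns
/// the user-supplied connection string verbatim.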
pub fn connstr(&self) -> String {
match self {
ServerStyle::Pgrx(_) => {
format!("host=localhost port={} dbname=stressgres", self.port())
}
ServerStyle::FromPath => {
format!("host=localhost port={} dbname=stressgres", self.port())
}
ServerStyle::Automatic { .. } => {
format!("host=localhost port={} dbname=stressgres", self.port())
}
ServerStyle::With { connection_string } => connection_string.clone(),
}
}
}
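/// A single Postgres server participating in a suite, along with the jobs that set
/// it up, tear it down, and monitor it while the suite runs.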
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct Server {
#[serde(default)]
pub default: bool,
#[serde(deserialize_with = "validate_server_name")]
pub name: String,
#[serde(default)]
pub style: ServerStyle,
pub setup: Job,
pub teardown: Job,
pub monitor: Job,
}
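/// Serde helper: reject server names containing anything other than ASCII
/// alphanumerics and underscores.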
fn validate_server_name<'de, D>(d: D) -> Result<String, D::Error>
where
D: serde::Deserializer<'de>,
{
let s = String::deserialize(d)?;
if s.chars().any(|c| !c.is_ascii_alphanumeric() && c != '_') {
Err(serde::de::Error::custom(format!(
"invalid server name `{s}`. Only `[a-zA-Z0-9_]` are supported"
)))
} else {
Ok(s)
}
}
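/// Hand out a unique default port per call, counting up from 55500, so multiple
/// automatically managed servers in one process don't collide.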
fn default_port() -> u16 {
static LAST_PORT: AtomicU16 = AtomicU16::new(55500);
LAST_PORT.fetch_add(1, Ordering::Relaxed)
}
/// A single job in the suite.
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct Job {
pub title: Option<String>,
pub on_connect: Option<String>,
pub sql: String,
pub assert: Option<String>,
pub window_height: Option<u16>,
pub cancel_keycode: Option<char>,
pub pause_keycode: Option<char>,
pub cancel_every: Option<u64>,
#[serde(default)]
pub atomic_connection: bool,
/// measured in milliseconds
#[serde(default = "default_refresh")]
pub refresh_ms: usize,
/// If true, log `tps=...`.
#[serde(default = "default_log_tps")]
pub log_tps: bool,
/// Arbitrary column names (e.g. block_count, segment_count) to include in the logs
#[serde(default)]
pub log_columns: Vec<String>,
#[serde(default)]
#[serde(deserialize_with = "deserialize_destinations")]
pub destinations: Vec<StatementDestination>,
}
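/// Serde helper: map the optional list of destination names onto
/// [`StatementDestination`]s. A missing list means "the default server";
/// `"default"` and `"all"` are recognized keywords, and any other name is treated as
/// a specific server name.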
fn deserialize_destinations<'de, D>(d: D) -> Result<Vec<StatementDestination>, D::Error>
where
D: serde::Deserializer<'de>,
{
let names = Option::<Vec<String>>::deserialize(d)?;
if names.is_none() {
return Ok(vec![StatementDestination::DefaultServer]);
}
let mut destinations: Vec<StatementDestination> = Vec::new();
for name in names.unwrap() {
destinations.push(match name.to_lowercase().as_str() {
"default" => StatementDestination::DefaultServer,
"all" => StatementDestination::AllServers,
_ => StatementDestination::SpecificServers(vec![name]),
});
}
Ok(destinations)
}
impl Default for Job {
fn default() -> Self {
Self {
title: None,
on_connect: None,
sql: "".to_string(),
assert: None,
window_height: None,
cancel_keycode: None,
pause_keycode: None,
cancel_every: None,
atomic_connection: false,
refresh_ms: 0,
log_tps: false,
log_columns: vec![],
destinations: vec![],
}
}
}
impl Job {
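/// The statement destinations for this job, falling back to the default server when
/// none were configured.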
pub fn destinations(&self) -> Vec<StatementDestination> {
if self.destinations.is_empty() {
vec![StatementDestination::DefaultServer]
} else {
self.destinations.clone()
}
}
}
fn default_refresh() -> usize {
1000
}
fn default_log_tps() -> bool {
true
}
/// A full suite of jobs, plus optional name, setup, teardown, monitor.
#[derive(Deserialize, Debug)]
pub struct SuiteDefinition {
/// The file path to the suite definition.
#[serde(skip_serializing)]
pub path: Option<PathBuf>,
/// The display name of the suite.
pub name: Option<String>,
/// The list of jobs to run as part of the suite.
pub jobs: Vec<Job>,
/// The list of servers (Postgres instances) involved in the suite.
#[serde(deserialize_with = "validate_server_list")]
#[serde(rename = "server")]
pub servers: Vec<Server>,
/// A list of error message substrings that should be ignored during execution and termination.
#[serde(default)]
pub ignore_errors: Vec<String>,
}
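/// A [`SuiteDefinition`] plus a by-name lookup table for its servers, shared across
/// threads via an [`Arc`].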
pub struct Suite {
definition: SuiteDefinition,
server_lookup: Arc<HashMap<String, Server>>,
}
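/// Serde helper: ensure at most one server is marked `default = true`, and promote
/// the first server to default if none is marked.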
#[rustfmt::skip]
fn validate_server_list<'de, D>(d: D) -> Result<Vec<Server>, D::Error>
where
D: serde::Deserializer<'de>,
{
let mut servers = Vec::<Server>::deserialize(d)?;
if !servers.is_empty() {
let mut found_default = false;
for server in &servers {
if server.default {
if found_default {
return Err(serde::de::Error::custom("cannot have multiple default servers"));
}
found_default = true;
}
}
if !found_default {
servers[0].default = true;
}
}
Ok(servers)
}
impl Server {
pub fn connstr(&self) -> String {
self.style.connstr()
}
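/// Whether this server uses the `Subscriber` flavor of `postgresql.conf`.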
pub fn is_subscriber(&self) -> bool {
matches!(
self.style,
ServerStyle::Automatic {
postgresql_conf: PostgresqlConf::Subscriber,
..
}
)
}
pub fn port(&self) -> u16 {
self.style.port()
}
}
impl Suite {
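/// Build a suite from its definition, indexing the servers by name.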
pub fn new(definition: SuiteDefinition) -> Self {
let server_lookup = definition
.servers
.iter()
.map(|server| (server.name.clone(), server.clone()))
.collect();
Self {
definition,
server_lookup: Arc::new(server_lookup),
}
}
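/// The suite's display name, falling back to the path of its definition file.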
pub fn name(&self) -> String {
self.definition.name.clone().unwrap_or_else(|| {
self.definition
.path
.as_ref()
.map(|p| p.display().to_string())
.unwrap_or_else(|| "".to_string())
})
}
pub fn ignore_errors(&self) -> &[String] {
&self.definition.ignore_errors
}
pub fn jobs(&self) -> impl Iterator<Item = &Job> {
self.definition.jobs.iter()
}
pub fn server(&self, name: &str) -> Option<&Server> {
self.server_lookup.get(name)
}
pub fn all_servers(&self) -> impl Iterator<Item = &Server> {
self.definition.servers.iter()
}
pub fn server_lookup(&self) -> Arc<HashMap<String, Server>> {
self.server_lookup.clone()
}
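/// The server marked `default = true` (the first server is promoted to default
/// during deserialization if none is marked).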
pub fn default_server(&self) -> &Server {
for server in &self.definition.servers {
if server.default {
return server;
}
}
unreachable!("there should be a `[[server]]` configuration with `default = true`")
}
}
impl Job {
/// Return the user-provided or derived job title.
pub fn title(&self) -> String {
if let Some(t) = &self.title {
return t.trim().to_string();
}
// If no title was given, derive from the first statement
let statements = self.sql();
if statements.is_empty() {
"".to_string()
} else {
statements[0].sql.trim().to_string()
}
}
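/// Whether this job's final statement starts with `SELECT` or `EXPLAIN`.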
pub fn is_select(&self) -> bool {
self.sql()
.last()
.map(|stmt| {
stmt.sql.to_ascii_uppercase().starts_with("SELECT")
|| stmt.sql.to_ascii_uppercase().starts_with("EXPLAIN")
})
.unwrap_or_default()
}
/// Return the parsed statements for this job to run when the connection is first opened
pub fn on_connect(&self) -> Vec<ScannedStatement<'_>> {
if let Some(on_connect) = &self.on_connect {
SqlStatementScanner::new(on_connect)
.into_iter()
.map(|mut st| {
st.sql = st.sql.trim();
st
})
.filter(|st| !st.sql.is_empty())
.collect()
} else {
Vec::new()
}
}
/// Return the parsed statements for this job.
pub fn sql(&self) -> Vec<ScannedStatement<'_>> {
SqlStatementScanner::new(&self.sql)
.into_iter()
.map(|mut st| {
st.sql = st.sql.trim();
st
})
.filter(|st| !st.sql.is_empty())
.collect()
}
}