// Copyright (c) 2023-2026 ParadeDB, Inc. // // This file is part of ParadeDB - Postgres for Search and Analytics // // This program is free software: you can redistribute it and/or modify // it under the terms of the GNU Affero General Public License as published by // the Free Software Foundation, either version 3 of the License, or // (at your option) any later version. // // This program is distributed in the hope that it will be useful // but WITHOUT ANY WARRANTY; without even the implied warranty of // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the // GNU Affero General Public License for more details. // // You should have received a copy of the GNU Affero General Public License // along with this program. If not, see . // Tests for ParadeDB's Aggregate Custom Scan implementation mod fixtures; use fixtures::*; use pretty_assertions::assert_eq; use rstest::*; use serde_json::Value; use sqlx::PgConnection; fn assert_uses_custom_scan(conn: &mut PgConnection, enabled: bool, query: impl AsRef) { let (plan,) = format!(" EXPLAIN (FORMAT JSON) {}", query.as_ref()).fetch_one::<(Value,)>(conn); eprintln!("{plan:#?}"); assert_eq!( enabled, plan.to_string().contains("ParadeDB Aggregate Scan") ); } #[rstest] fn test_count(mut conn: PgConnection) { SimpleProductsTable::setup().execute(&mut conn); // Use the aggregate custom scan only if it is enabled. for enabled in [true, false] { format!("SET paradedb.enable_aggregate_custom_scan TO {enabled};").execute(&mut conn); let query = "SELECT COUNT(*) FROM paradedb.bm25_search WHERE description @@@ 'keyboard'"; assert_uses_custom_scan(&mut conn, enabled, query); let (count,) = query.fetch_one::<(i64,)>(&mut conn); assert_eq!(count, 2, "With custom scan: {enabled}"); } } #[rstest] fn test_count_with_group_by(mut conn: PgConnection) { SimpleProductsTable::setup().execute(&mut conn); "SET paradedb.enable_aggregate_custom_scan TO on;".execute(&mut conn); "SET client_min_messages TO warning;".execute(&mut conn); // First test simple COUNT(*) without GROUP BY let simple_count = "SELECT COUNT(*) FROM paradedb.bm25_search"; eprintln!("Testing simple COUNT(*)"); let (plan,) = format!("EXPLAIN (FORMAT JSON) {simple_count}").fetch_one::<(Value,)>(&mut conn); eprintln!("Simple COUNT(*) plan: {plan:#?}"); eprintln!( "Uses aggregate scan: {}", plan.to_string().contains("ParadeDB Aggregate Scan") ); // Test COUNT(*) with WHERE clause (like the working test) let count_with_where = "SELECT COUNT(*) FROM paradedb.bm25_search WHERE description @@@ 'keyboard'"; eprintln!("\nTesting COUNT(*) with WHERE clause"); let (plan,) = format!("EXPLAIN (FORMAT JSON) {count_with_where}").fetch_one::<(Value,)>(&mut conn); eprintln!( "COUNT(*) with WHERE plan uses aggregate scan: {}", plan.to_string().contains("ParadeDB Aggregate Scan") ); // Then test WITHOUT WHERE clause but WITH GROUP BY let query_no_where = r#" SELECT rating, COUNT(*) FROM paradedb.bm25_search GROUP BY rating ORDER BY rating "#; eprintln!("Testing query without WHERE clause"); let (plan,) = format!("EXPLAIN (FORMAT JSON) {query_no_where}").fetch_one::<(Value,)>(&mut conn); eprintln!("Plan without WHERE: {plan:#?}"); eprintln!( "Uses aggregate scan: {}", plan.to_string().contains("ParadeDB Aggregate Scan") ); // Then test WITH WHERE clause let query = r#" SELECT rating, COUNT(*) FROM paradedb.bm25_search WHERE description @@@ 'shoes' GROUP BY rating ORDER BY rating "#; // Verify it uses the aggregate custom scan assert_uses_custom_scan(&mut conn, true, query); // Execute and verify results let results: Vec<(i32, i64)> = query.fetch(&mut conn); assert_eq!(results.len(), 3); // We should have 3 distinct ratings for shoes assert_eq!(results[0], (3, 1)); // rating 3, count 1 assert_eq!(results[1], (4, 1)); // rating 4, count 1 assert_eq!(results[2], (5, 1)); // rating 5, count 1 } #[rstest] fn test_group_by(mut conn: PgConnection) { SimpleProductsTable::setup().execute(&mut conn); "SET paradedb.enable_aggregate_custom_scan TO on;".execute(&mut conn); // Supports GROUP BY with aggregate scan assert_uses_custom_scan( &mut conn, true, r#" SELECT rating, COUNT(*) FROM paradedb.bm25_search WHERE description @@@ 'keyboard' GROUP BY rating ORDER BY rating "#, ); } #[rstest] fn test_group_by_null_bucket(mut conn: PgConnection) { SimpleProductsTable::setup().execute(&mut conn); "SET paradedb.enable_aggregate_custom_scan TO on;".execute(&mut conn); assert_uses_custom_scan( &mut conn, true, r#" SELECT rating, COUNT(*) FROM paradedb.bm25_search WHERE description @@@ 'keyboard' GROUP BY rating ORDER BY rating NULLS FIRST "#, ); } #[rstest] fn test_no_bm25_index(mut conn: PgConnection) { "CALL paradedb.create_bm25_test_table(table_name => 'no_bm25', schema_name => 'paradedb');" .execute(&mut conn); "SET paradedb.enable_aggregate_custom_scan TO on;".execute(&mut conn); // Do not use the aggregate custom scan on non-bm25 indexed tables. assert_uses_custom_scan(&mut conn, false, "SELECT COUNT(*) FROM paradedb.no_bm25"); } #[rstest] fn test_other_aggregates(mut conn: PgConnection) { SimpleProductsTable::setup().execute(&mut conn); "SET paradedb.enable_aggregate_custom_scan TO on;".execute(&mut conn); for aggregate_func in ["SUM(rating)", "AVG(rating)", "MIN(rating)", "MAX(rating)"] { assert_uses_custom_scan( &mut conn, true, format!( r#" SELECT {aggregate_func} FROM paradedb.bm25_search WHERE description @@@ 'keyboard' "# ), ); } }