// Copyright (c) 2023-2026 ParadeDB, Inc. // // This file is part of ParadeDB - Postgres for Search and Analytics // // This program is free software: you can redistribute it and/or modify // it under the terms of the GNU Affero General Public License as published by // the Free Software Foundation, either version 3 of the License, or // (at your option) any later version. // // This program is distributed in the hope that it will be useful // but WITHOUT ANY WARRANTY; without even the implied warranty of // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the // GNU Affero General Public License for more details. // // You should have received a copy of the GNU Affero General Public License // along with this program. If not, see . use anyhow::{bail, Context, Result}; use clap::Parser; use duckdb::Connection; use std::time::Instant; use crate::config::{load_dataset_config, topological_order}; use crate::utils::{open_duckdb_conn, validate_input_output}; #[derive(Parser)] pub struct SampleArgs { /// Input path to the dataset (S3 or local). /// Each table is a subdirectory containing parquet files. #[arg(long)] pub input: String, /// Output path for the sampled parquet files (S3 or local). /// Parquet files will be written to subdirectories matching the table names. #[arg(long)] pub output: String, /// Path to the TOML config file describing table relationships. #[arg(long)] pub config: String, /// Target number of rows for the root table. #[arg(long)] pub rows: u64, /// Validate and report planned row counts without writing. #[arg(long, default_value_t = false)] pub dry_run: bool, } fn parquet_glob_pattern(base: &str, table_name: &str) -> String { format!("{base}/{table_name}/*.parquet") } fn count_rows(conn: &Connection, glob: &str) -> Result { let sql = format!("SELECT count(*) FROM read_parquet('{glob}')"); let count: u64 = conn .query_row(&sql, [], |row| row.get(0)) .with_context(|| format!("Failed to count rows for '{glob}'"))?; Ok(count) } pub fn run_sample(args: SampleArgs) -> Result<()> { let config = load_dataset_config(&args.config)?; let conn = open_duckdb_conn()?; let input = args.input.trim_end_matches('/'); let output = args.output.trim_end_matches('/'); let table_names: Vec = config.tables.iter().map(|t| t.name.clone()).collect(); validate_input_output(&table_names, &conn, input, output)?; // Determine processing order. let order = topological_order(&config)?; // Sample the root table. let root = &config.tables[order[0]]; let root_glob = parquet_glob_pattern(input, &root.name); let total_rows = count_rows(&conn, &root_glob)?; if total_rows == 0 { bail!("Root table '{}' has no rows", root.name); } let target = args.rows; if target > total_rows { bail!( "Target rows ({target}) exceeds total rows ({total_rows}) in root table '{}'", root.name ); } let local_root_data_path = format!("/tmp/local_source/{}", root.name); let local_glob = format!("{local_root_data_path}/*.parquet"); // copy root table locally to speed up sampling. std::fs::create_dir_all(&local_root_data_path) .with_context(|| "Failed to make root table data directory")?; println!("Copying root table data to local disk..."); let sql = format!( "COPY (SELECT * FROM read_parquet('{}')) \ TO '{}' (FORMAT PARQUET, OVERWRITE true, PER_THREAD_OUTPUT true)", root_glob, local_root_data_path ); conn.execute_batch(&sql) .with_context(|| "Failed to copy root table data locally")?; // disable multi-threading, required for deterministic output // See: https://duckdb.org/docs/current/sql/samples#syntax conn.execute("SET threads = 1;", []) .with_context(|| "Failed to set thread count")?; let percentage = (target as f64 / total_rows as f64) * 100.0; // We use reservoir for small sample sizes, since it allows us to be exact. However, it // requires materializing the entire sample in memory, so we use the system method for larger // counts, which gives us an approximate count (usually within 3-5%). let sample_arg = if target <= 100_000 { format!("reservoir({target} ROWS)") } else { format!("system({percentage:.5} PERCENT)") }; let sql = format!( "CREATE TABLE sampled_{name} AS \ SELECT * FROM read_parquet('{local_path}') \ USING SAMPLE {sample_arg} REPEATABLE({seed})", name = root.name, local_path = local_glob, sample_arg = sample_arg, seed = config.sampling_seed, ); println!( "Sampling root table {} for ~{} rows ({:.5} percent of the input)...", root.name, target, percentage ); let start_time = Instant::now(); conn.execute_batch(&sql) .with_context(|| format!("Failed to sample root table '{}'", root.name))?; println!("Sampling took: {:?}", start_time.elapsed()); println!("Removing root table data from local disk..."); std::fs::remove_dir_all(&local_root_data_path) .with_context(|| format!("Failed to remove dir: '{}'", &local_root_data_path))?; // re-enable multi-threading conn.execute("RESET threads;", []) .with_context(|| "Failed to reset thread count to default")?; let sampled_root_count: u64 = conn .query_row( &format!("SELECT count(*) FROM sampled_{}", root.name), [], |row| row.get(0), ) .context("Failed to count sampled root rows")?; println!(" {} sampled: {sampled_root_count} rows", root.name); // Sample child tables by joining against their sampled parent. for &idx in &order[1..] { let table = &config.tables[idx]; let parent = table.parent.as_ref().unwrap(); let parent_join_key = table.parent_join_col.as_ref().unwrap(); let join_key = table.join_col.as_ref().unwrap(); let glob = parquet_glob_pattern(input, &table.name); println!( "Sampling child table '{}' (parent: '{parent}')...", table.name ); let sql = format!( "CREATE TABLE sampled_{name} AS \ SELECT DISTINCT c.* \ FROM read_parquet('{glob}') c \ WHERE c.\"{jk}\" IN (SELECT {pk} from sampled_{parent})", name = table.name, parent = parent, jk = join_key, pk = parent_join_key, ); conn.execute_batch(&sql) .with_context(|| format!("Failed to sample child table '{}'", table.name))?; let child_count: u64 = conn .query_row( &format!("SELECT count(*) FROM sampled_{}", table.name), [], |row| row.get(0), ) .with_context(|| format!("Failed to count sampled rows for '{}'", table.name))?; println!(" {} sampled: {child_count} rows", table.name); } if args.dry_run { println!("\nDry run complete. No files were written."); return Ok(()); } // Write output. println!("\nWriting sampled parquet files..."); for &idx in &order { let table = &config.tables[idx]; println!(" Writing '{}'...", table.name); let sql = format!( "COPY sampled_{name} TO '{output}/{name}' (FORMAT PARQUET, PER_THREAD_OUTPUT true)", name = table.name, output = output, ); conn.execute_batch(&sql) .with_context(|| format!("Failed to write sampled table '{}'", table.name))?; println!(" {}: done", table.name); } println!("\nSampling complete."); Ok(()) }