use serde::{Deserialize, Serialize};
use sqlx::{postgres::PgPoolOptions, Pool, Postgres};
use std::fs;
use std::fs::File;
use std::io::{BufRead, BufReader};
use std::path::Path;

/// One record from the chat JSONL file. Only the `output` field is deserialized;
/// serde ignores the other fields present on each line.
#[derive(Debug, Deserialize, Serialize)]
struct ChatData {
    output: String,
}

// benchmark to evaluate latency overhead from the insert trigger
async fn bench_insert_triggers() {
    let database_url = std::env::var("DATABASE_URL").expect("DATABASE_URL must be set");

    let pool = PgPoolOptions::new()
        .max_connections(5)
        .connect(&database_url)
        .await
        .unwrap();

    sqlx::query("DROP TABLE IF EXISTS nemotron_chat CASCADE;")
        .execute(&pool)
        .await
        .unwrap();

    // Create table if it doesn't exist
    sqlx::query(
        "CREATE TABLE IF NOT EXISTS nemotron_chat (
            id SERIAL PRIMARY KEY,
            output TEXT NOT NULL
        )",
    )
    .execute(&pool)
    .await
    .unwrap();

    // initialize vectorize job on the empty table
    sqlx::query(
        "SELECT vectorize.table(
        job_name => 'nemotron_chat',
        relation => 'nemotron_chat',
        primary_key => 'id',
        columns => ARRAY['output'],
        transformer => 'sentence-transformers/all-MiniLM-L6-v2',
        schedule => 'realtime'
    );",
    )
    .execute(&pool)
    .await
    .expect("failed to init job");

    let bench_data = download_dataset().await;

    insert_data(&pool, bench_data).await;

    println!("Data loaded successfully!");
}

/// Parse a JSONL file into `ChatData` records, logging and skipping lines that fail to deserialize.
fn read_jsonl_file(path: &str) -> Vec<ChatData> {
    let file = File::open(path).unwrap();
    let reader = BufReader::new(file);
    let mut items = Vec::new();

    for line in reader.lines() {
        let line = line.unwrap();
        match serde_json::from_str::<ChatData>(&line) {
            Ok(item) => items.push(item),
            Err(e) => eprintln!("Failed to parse line: {}", e),
        }
    }
    items
}

/// Download the dataset from Hugging Face if it is not already on disk, then parse it.
async fn download_dataset() -> Vec<ChatData> {
    let file_path = "chat.jsonl";
    let url = "https://huggingface.co/datasets/nvidia/Llama-Nemotron-Post-Training-Dataset-v1/resolve/main/SFT/chat/chat.jsonl";

    // check if the file exists locally, download if not found
    if !Path::new(file_path).exists() {
        println!("File not found locally. Downloading from {url}");
        let client = reqwest::Client::new();
        // fail loudly on an HTTP error status instead of writing an error page to disk
        let response = client.get(url).send().await.unwrap();
        let response = response.error_for_status().expect("download failed");
        let content = response.text().await.unwrap();

        // Save the downloaded content to local file
        fs::write(file_path, &content).unwrap();
    } else {
        println!("File found locally.");
    }
    read_jsonl_file(file_path)
}

/// Insert all records with a single multi-row INSERT and report the elapsed time.
async fn insert_data(pool: &Pool<Postgres>, data: Vec<ChatData>) {
    let start = std::time::Instant::now();
    let num_rows = data.len();

    // Build one multi-row INSERT with a bind parameter per record. This keeps the
    // benchmark to a single round trip, but relies on staying under Postgres's
    // 65,535 bind-parameter limit.
    let placeholders: Vec<String> = (1..=num_rows).map(|i| format!("(${i})")).collect();
    let query_string = format!(
        "INSERT INTO nemotron_chat (output) VALUES {};",
        placeholders.join(", ")
    );
    let mut query = sqlx::query(&query_string);

    // Bind all parameters
    for example in data.iter() {
        query = query.bind(&example.output);
    }
    query.execute(pool).await.unwrap();

    let duration = start.elapsed();
    println!("Time elapsed: {:?}, num records: {}", duration, num_rows);
}
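
// To run against a Postgres instance that has the vectorize (pg_vectorize) extension
// installed (the connection string below is only an example):
//
//   DATABASE_URL=postgres://postgres:postgres@localhost:5432/postgres cargo run --release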

#[tokio::main]
async fn main() {
    bench_insert_triggers().await;
}