//
// Syd: rock-solid application kernel
// src/utils/syd-aes.rs: AES-CTR Encryption and Decryption Utility
//
// Copyright (c) 2024, 2025 Ali Polatel <alip@chesswob.org>
//
// SPDX-License-Identifier: GPL-3.0

use std::{
    os::unix::ffi::OsStrExt,
    process::ExitCode,
    time::{Duration, Instant},
};

use btoi::btoi;
use data_encoding::HEXLOWER_PERMISSIVE;
use nix::{
    errno::Errno,
    fcntl::{splice, OFlag, SpliceFFlags},
    unistd::{isatty, pipe2},
};
use syd::{
    config::PIPE_BUF,
    err::SydResult,
    hash::{
        aes_ctr_dec, aes_ctr_enc, aes_ctr_flush, aes_ctr_init, aes_ctr_setup, KeySerial,
        BLOCK_SIZE, IV,
    },
};
use zeroize::Zeroizing;

// Set global allocator to GrapheneOS allocator.
//
// Only selected on 64-bit targets with 4K pages, excluding Android, and
// only when neither coverage instrumentation nor the `prof` feature is
// enabled; otherwise the default (or the tcmalloc one below) is used.
#[cfg(all(
    not(coverage),
    not(feature = "prof"),
    not(target_os = "android"),
    target_page_size_4k,
    target_pointer_width = "64"
))]
#[global_allocator]
static GLOBAL: hardened_malloc::HardenedMalloc = hardened_malloc::HardenedMalloc;

// Set global allocator to tcmalloc if profiling is enabled.
//
// The `prof` feature excludes the hardened allocator above, so at most one
// `#[global_allocator]` is ever compiled in.
#[cfg(feature = "prof")]
#[global_allocator]
static GLOBAL: tcmalloc::TCMalloc = tcmalloc::TCMalloc;

fn process_data(encrypting: bool, key_id: KeySerial, iv: IV, verbose: bool) -> SydResult<()> {
    let fd = aes_ctr_setup(key_id)?;
    let fd = aes_ctr_init(&fd, false)?;

    if encrypting {
        aes_ctr_enc(&fd, &[], Some(&iv), true)?;
    } else {
        aes_ctr_dec(&fd, &[], Some(&iv), true)?;
    }

    let (pipe_rd, pipe_wr) = pipe2(OFlag::O_CLOEXEC)?;

    let mut nread = 0;
    let stime = Instant::now();
    let mut ltime = stime;
    let mut nbytes = 0;
    let mut nwrite = 0;
    let verbose = if verbose {
        isatty(std::io::stderr()).unwrap_or(false)
    } else {
        false
    };

    // SAFETY: This buffer holds plaintext,
    // we zero it on free and ensure it never swaps out.
    let (mut bufz, mut bufu) = if !encrypting {
        (Some(Zeroizing::new(vec![0u8; PIPE_BUF])), None)
    } else {
        (None, Some(vec![0u8; PIPE_BUF]))
    };
    let buf = if let Some(ref mut bufz) = bufz {
        bufz.as_mut()
    } else if let Some(ref mut bufu) = bufu {
        bufu
    } else {
        unreachable!()
    };

    loop {
        // Use splice to move data from standard input to pipe.
        match splice(
            std::io::stdin(),
            None,
            &pipe_wr,
            None,
            PIPE_BUF,
            SpliceFFlags::empty(),
        )? {
            0 => break, // EOF
            n => {
                match splice(&pipe_rd, None, &fd, None, n, SpliceFFlags::SPLICE_F_MORE)? {
                    0 => break, // EOF
                    n => nread += n,
                };

                while nread >= BLOCK_SIZE {
                    let nblock = (nread / BLOCK_SIZE) * BLOCK_SIZE;
                    let n = aes_ctr_flush(&fd, std::io::stdout(), buf, nblock)?;
                    nread -= n;
                    nbytes += n;
                    nwrite += 1;
                }

                if verbose {
                    let now = Instant::now();
                    if now.duration_since(ltime) >= Duration::from_millis(500) {
                        let elapsed = stime.elapsed();
                        let speed = nbytes as f64 / elapsed.as_secs_f64();
                        let output = format!(
                            "{} bytes ({:.2} GB, {:.2} GiB) processed, {:.2?} s, {:.2} MB/s",
                            nbytes,
                            nbytes as f64 / 1_000_000_000.0,
                            nbytes as f64 / (1 << 30) as f64,
                            elapsed,
                            speed / (1 << 20) as f64
                        );
                        eprint!("\r\x1B[K{output}");
                        ltime = now;
                    }
                }
            }
        }
    }

    if nread > 0 {
        // Finalize {en,de}cryption to flush final batch with `false`.
        //
        // Some kernel versions may incorrectly return EINVAL here.
        // Gracefully handle this errno and move on.
        match if encrypting {
            aes_ctr_enc(&fd, &[], None, false)
        } else {
            aes_ctr_dec(&fd, &[], None, false)
        } {
            Ok(_) | Err(Errno::EINVAL) => {}
            Err(errno) => return Err(errno.into()),
        }

        // {En,De}crypt the final batch.
        // SAFETY: Zero-out memory if decrypting.
        aes_ctr_flush(&fd, std::io::stdout(), buf, nread)?;
        if verbose {
            nbytes += nread;
            nwrite += 1;
        }
    }

    if verbose {
        let elapsed = stime.elapsed();
        eprintln!(
            "\n{} records of each {} bytes processed.\n{} bytes ({:.2} GB, {:.2} GiB) processed, {:.5?} s, {:.2} MB/s",
            nwrite,
            PIPE_BUF,
            nbytes,
            nbytes as f64 / 1_000_000_000.0,
            nbytes as f64 / (1 << 30) as f64,
            elapsed,
            nbytes as f64 / elapsed.as_secs_f64() / (1 << 20) as f64
        );
    }

    Ok(())
}

syd::main! {
    use lexopt::prelude::*;

    syd::set_sigpipe_dfl()?;

    // Command-line state: -e/-d, -k and -i are mandatory, -v is optional.
    let mut mode: Option<bool> = None;
    let mut key: Option<KeySerial> = None;
    let mut iv_hex: Option<String> = None;
    let mut verbose = false;

    let mut args = lexopt::Parser::from_env();
    while let Some(arg) = args.next()? {
        match arg {
            Short('h') => {
                help();
                return Ok(ExitCode::SUCCESS);
            }
            Short('v') => verbose = true,
            Short('e') => mode = Some(true),
            Short('d') => mode = Some(false),
            Short('k') => {
                let raw = args.value()?;
                key = Some(btoi::<KeySerial>(raw.as_bytes())?);
            }
            Short('i') => {
                let raw = args.value()?;
                iv_hex = Some(raw.parse::<String>()?);
            }
            _ => return Err(arg.unexpected().into()),
        }
    }

    // Reject the invocation if any mandatory option is missing.
    let is_enc = match mode {
        Some(is_enc) => is_enc,
        None => {
            eprintln!("syd-aes: Error: -e or -d options are required.");
            help();
            return Ok(ExitCode::FAILURE);
        }
    };
    let key_id = match key {
        Some(key_id) => key_id,
        None => {
            eprintln!("syd-aes: Error: -k option is required.");
            help();
            return Ok(ExitCode::FAILURE);
        }
    };
    let iv_hex = match iv_hex {
        Some(iv_hex) => iv_hex,
        None => {
            eprintln!("syd-aes: Error: -i option is required.");
            help();
            return Ok(ExitCode::FAILURE);
        }
    };

    // Decode the IV; it must be valid hex of exactly the IV length.
    let iv_bytes = HEXLOWER_PERMISSIVE
        .decode(iv_hex.as_bytes())
        .ok()
        .and_then(|bytes| bytes.as_slice().try_into().ok());
    let iv = match iv_bytes {
        Some(iv_bytes) => IV::new(iv_bytes),
        None => {
            eprintln!("syd-aes: Error: IV must be valid hex, and 128 bits (16 bytes) in length!");
            return Ok(ExitCode::FAILURE);
        }
    };

    process_data(is_enc, key_id, iv, verbose)?;
    Ok(ExitCode::SUCCESS)
}

/// Prints usage information for syd-aes to standard output.
fn help() {
    // The usage line lists every accepted option, including -v.
    println!("Usage: syd-aes [-hv] -e|-d -k <key-serial> -i <iv-hex>");
    println!("AES-CTR Encryption and Decryption Utility");
    println!("Reads from standard input and writes to standard output.");
    println!("  -h        Print this help message and exit.");
    println!("  -v        Enable verbose mode.");
    println!("  -e        Encrypt the input data.");
    println!("  -d        Decrypt the input data.");
    println!("  -k <key>  Key serial ID for keyrings(7) (32-bit integer)");
    println!("            Key must have search permission.");
    println!("  -i <iv>   Hex-encoded IV (128 bits)");
}
