//
// Syd: rock-solid application kernel
// src/mask.rs: Utilities to mask sensitive information in proc files
//
// Copyright (c) 2025 Ali Polatel <alip@chesswob.org>
// SPDX-License-Identifier: GPL-3.0

// SAFETY: This module has been liberated from unsafe code!
#![forbid(unsafe_code)]

use std::os::fd::AsFd;

use memchr::{memchr, memmem};
use nix::{errno::Errno, unistd::read};

use crate::fs::{retry_on_eintr, write_all};

/// A single textual rewrite for one proc_pid_status(5) line.
///
/// The rewrite only applies to a line that begins with `prefix`; within
/// the remainder of that line, the first occurrence of `needle` is swapped
/// for `repl`.
struct Patch {
    /// Field header (including the colon) the line has to start with.
    prefix: &'static [u8],
    /// Bytes to look for in the field value.
    needle: &'static [u8],
    /// Bytes written in place of `needle`.
    repl: &'static [u8],
}

/// proc_pid_status(5) field headers whose numeric value is reported as `0`.
///
/// Each header carries its trailing colon and is matched against the start
/// of a line, so `Seccomp:` can never be confused with `Seccomp_filters:`.
const PROC_STATUS_ZERO_FIELDS: &[&[u8]] = &[
    b"TracerPid:",
    b"NoNewPrivs:",
    b"Seccomp:",
    b"Seccomp_filters:",
];

/// Rewrites applied to speculation-control fields of proc_pid_status(5).
///
/// A forced mitigation state is reported as the non-forced value instead:
/// `force mitigated` becomes `vulnerable` for `Speculation_Store_Bypass:`
/// and `force disabled` becomes `enabled` for `SpeculationIndirectBranch:`.
/// The first entry whose `prefix` starts a line is the one applied.
const PROC_STATUS_SPEC_PATCHES: &[Patch] = &[
    // Speculative Store Bypass control.
    Patch {
        prefix: b"Speculation_Store_Bypass:",
        needle: b"force mitigated",
        repl: b"vulnerable",
    },
    // Indirect branch speculation control.
    Patch {
        prefix: b"SpeculationIndirectBranch:",
        needle: b"force disabled",
        repl: b"enabled",
    },
];

/// Zero out security-sensitive numeric fields of proc_pid_status(5).
///
/// If `line` starts with one of [`PROC_STATUS_ZERO_FIELDS`], the field name
/// and the spaces/tabs following it are written verbatim and the numeric
/// value is replaced with `0`. Anything after the leading digit run is
/// dropped; a trailing newline is emitted only if `line` itself ended with
/// one. A line whose value is already exactly `0`, or which has no digit
/// run after the whitespace, is written unchanged.
///
/// Returns `Ok(true)` if the line matched a field (and was written), or
/// `Ok(false)` if it did not match (and nothing was written).
#[inline]
fn proc_status_mask_num<Fd: AsFd>(out: Fd, line: &[u8]) -> Result<bool, Errno> {
    // Field names include their colon, so a cheap prefix test is exact and
    // avoids scanning the whole line for every field.
    let Some(field) = PROC_STATUS_ZERO_FIELDS
        .iter()
        .copied()
        .find(|&field| line.starts_with(field))
    else {
        return Ok(false);
    };

    let value = &line[field.len()..];
    // Preserve the exact whitespace between the colon and the value.
    let ws = value
        .iter()
        .take_while(|&&b| b == b' ' || b == b'\t')
        .count();
    let digits = value[ws..]
        .iter()
        .take_while(|b| b.is_ascii_digit())
        .count();

    // Nothing to mask: no numeric value, or it is already exactly "0".
    if digits == 0 || (digits == 1 && value[ws] == b'0') {
        write_all(&out, line)?;
        return Ok(true);
    }

    // Keep the line terminator as found: do not invent a newline for a
    // partial (unterminated) last line.
    let zero: &[u8] = if line.ends_with(b"\n") { b"0\n" } else { b"0" };

    // Cannot overflow: the sum is an index into `line`.
    #[allow(clippy::arithmetic_side_effects)]
    let keep = field.len() + ws;
    write_all(&out, &line[..keep])?;
    write_all(&out, zero)?;
    Ok(true)
}

/// Revert speculative-execution fields of proc_pid_status(5) to default.
///
/// If `line` starts with the `prefix` of an entry in
/// [`PROC_STATUS_SPEC_PATCHES`], the first occurrence of that entry's
/// `needle` after the prefix is replaced by its `repl` and the result is
/// written to `out`. If the needle is absent, the line is written unchanged.
///
/// Returns `Ok(true)` if the line matched a prefix (and was written), or
/// `Ok(false)` if it did not (and nothing was written).
#[inline]
fn proc_status_patch_spec<Fd: AsFd>(out: Fd, line: &[u8]) -> Result<bool, Errno> {
    // Anchor on the field header with a prefix test rather than searching
    // the whole line; this also yields the value area directly.
    let Some((patch, value)) = PROC_STATUS_SPEC_PATCHES
        .iter()
        .find_map(|p| line.strip_prefix(p.prefix).map(|value| (p, value)))
    else {
        return Ok(false);
    };

    // Search only the value area so the header itself is never rewritten.
    let Some(pos) = memmem::find(value, patch.needle) else {
        write_all(&out, line)?;
        return Ok(true);
    };

    // Cannot overflow: the sum is an index into `line`.
    #[allow(clippy::arithmetic_side_effects)]
    let (head, tail) = line.split_at(patch.prefix.len() + pos);
    write_all(&out, head)?;
    write_all(&out, patch.repl)?;
    write_all(&out, &tail[patch.needle.len()..])?;
    Ok(true)
}

/// Write one proc_pid_status(5) line to `out`, masking it when needed.
///
/// The numeric-field masker gets the first chance at the line, then the
/// speculation-field patcher; a line claimed by neither is copied verbatim.
#[inline]
fn proc_status_emit<Fd: AsFd>(out: Fd, line: &[u8]) -> Result<(), Errno> {
    let handled = proc_status_mask_num(&out, line)? || proc_status_patch_spec(&out, line)?;
    if !handled {
        write_all(out, line)?;
    }
    Ok(())
}

/// Masks security-sensitive information in proc_pid_status(5).
///
/// Reads `src` until `read` reports end-of-file and writes the content to
/// `dst` one line at a time, passing each line through `proc_status_emit`.
/// A final line without a trailing newline is still emitted.
///
/// # Errors
///
/// Returns the first error from reading `src` (reads go through
/// `retry_on_eintr`) or from writing `dst`. Returns `Errno::ENOMEM` if
/// memory for a line that spans several reads cannot be reserved.
pub(crate) fn mask_proc_pid_status<S: AsFd, D: AsFd>(src: S, dst: D) -> Result<(), Errno> {
    let mut buf = [0u8; 8192];
    // Start of a line whose terminating newline has not been read yet.
    let mut carry: Vec<u8> = Vec::new();

    loop {
        let n = retry_on_eintr(|| read(&src, &mut buf))?;
        if n == 0 {
            break;
        }

        let mut chunk = &buf[..n];
        while let Some(nl) = memchr(b'\n', chunk) {
            // Cannot overflow: `nl` is an index into `chunk`.
            #[allow(clippy::arithmetic_side_effects)]
            let (line, rest) = chunk.split_at(nl + 1);
            if carry.is_empty() {
                // Common case: the whole line arrived in this read, so emit
                // it straight from `buf` without staging a copy.
                proc_status_emit(&dst, line)?;
            } else {
                // The line began in an earlier read: complete it, then emit.
                carry.try_reserve(line.len()).map_err(|_| Errno::ENOMEM)?;
                carry.extend_from_slice(line);
                proc_status_emit(&dst, &carry)?;
                carry.clear();
            }
            chunk = rest;
        }

        // Stash an unterminated tail until a later read completes it.
        if !chunk.is_empty() {
            carry.try_reserve(chunk.len()).map_err(|_| Errno::ENOMEM)?;
            carry.extend_from_slice(chunk);
        }
    }

    if !carry.is_empty() {
        // proc_pid_status(5) lines are newline-terminated,
        // but handle partial last line defensively.
        proc_status_emit(&dst, &carry)?;
    }

    Ok(())
}

#[cfg(test)]
mod tests {
    use nix::unistd::{pipe, write};

    use super::*;

    /// Run `mask_proc_pid_status` over `input` and return everything it wrote.
    ///
    /// NOTE: the whole input is written into a blocking pipe before masking
    /// starts, and the output pipe is drained only afterwards, so inputs and
    /// outputs must stay small enough to fit in a pipe buffer.
    fn run_mask(input: &[u8]) -> Result<Vec<u8>, Errno> {
        let (src_rd, src_wr) = pipe()?;
        let (dst_rd, dst_wr) = pipe()?;

        // Push the whole input, then close the write end so the masker
        // observes end-of-file.
        let mut pending = input;
        while !pending.is_empty() {
            let written = write(&src_wr, pending)?;
            if written == 0 {
                break;
            }
            pending = &pending[written..];
        }
        drop(src_wr);

        mask_proc_pid_status(&src_rd, &dst_wr)?;

        // Close the write end so draining the output stops at end-of-file.
        drop(dst_wr);

        let mut produced = Vec::new();
        let mut scratch = [0u8; 4096];
        loop {
            let got = read(&dst_rd, &mut scratch)?;
            if got == 0 {
                break;
            }
            produced.extend_from_slice(&scratch[..got]);
        }
        Ok(produced)
    }

    #[test]
    fn test_mask_proc_pid_status_zero_simple_fields() {
        let input = b"TracerPid:\t123\nNoNewPrivs:\t1\nSeccomp:\t2\nSeccomp_filters:\t7\n";
        let want = b"TracerPid:\t0\nNoNewPrivs:\t0\nSeccomp:\t0\nSeccomp_filters:\t0\n";
        let masked = run_mask(input).unwrap();
        assert_eq!(masked, want);
    }

    #[test]
    fn test_mask_proc_pid_status_preserve_whitespace() {
        // Mixed spaces and tabs before the value must survive masking.
        let input = b"TracerPid:\t   456\nSeccomp:\t\t  2\n";
        let want = b"TracerPid:\t   0\nSeccomp:\t\t  0\n";
        let masked = run_mask(input).unwrap();
        assert_eq!(masked, want);
    }

    #[test]
    fn test_mask_proc_pid_status_zero_already_zero_passthrough() {
        let input = b"TracerPid:\t0\nNoNewPrivs:\t0\n";
        assert_eq!(run_mask(input).unwrap(), input);
    }

    #[test]
    fn test_mask_proc_pid_status_spec_store_bypass_patch() {
        let input = b"Speculation_Store_Bypass:\tthread force mitigated\n";
        let want = b"Speculation_Store_Bypass:\tthread vulnerable\n";
        let masked = run_mask(input).unwrap();
        assert_eq!(masked, want);
    }

    #[test]
    fn test_mask_proc_pid_status_spec_indirect_branch_patch() {
        let input = b"SpeculationIndirectBranch:\tconditional force disabled\n";
        let want = b"SpeculationIndirectBranch:\tconditional enabled\n";
        let masked = run_mask(input).unwrap();
        assert_eq!(masked, want);
    }

    #[test]
    fn test_mask_proc_pid_status_spec_lines_already_ok_passthrough() {
        let input = b"Speculation_Store_Bypass:\tthread vulnerable\n\
                      SpeculationIndirectBranch:\tconditional enabled\n";
        assert_eq!(run_mask(input).unwrap(), input);
    }

    #[test]
    fn test_mask_proc_pid_status_other_lines_unchanged() {
        let input = b"Name:\tcat\nState:\tS (sleeping)\nThreads:\t4\n";
        assert_eq!(run_mask(input).unwrap(), input);
    }

    #[test]
    fn test_mask_proc_pid_status_prefix_must_be_line_start() {
        // "Seccomp:" occurs mid-line here, so it is not a field header.
        let input = b"Name:\tSeccomp:\t2 (not a header)\n";
        assert_eq!(run_mask(input).unwrap(), input);
    }

    #[test]
    fn test_mask_proc_pid_status_long_line_carry_and_zeroing() {
        // A TracerPid line longer than the 8192-byte read buffer forces the
        // value to be split across reads.
        let mut line = b"TracerPid:\t".to_vec();
        line.resize(line.len() + 9000, b'9');
        line.push(b'\n');
        let want = b"TracerPid:\t0\n".to_vec();
        assert_eq!(run_mask(&line).unwrap(), want);
    }

    #[test]
    fn test_mask_proc_pid_status_long_nonmatching_passthrough() {
        let line: Vec<u8> = std::iter::repeat(b'A')
            .take(10000)
            .chain(std::iter::once(b'\n'))
            .collect();
        assert_eq!(run_mask(&line).unwrap(), line);
    }

    #[test]
    fn test_mask_proc_pid_status_combined_document_full() {
        let input = concat!(
            "Name:\tmyproc\n",
            "TracerPid:\t42\n",
            "Speculation_Store_Bypass:\tthread force mitigated\n",
            "NoNewPrivs:\t1\n",
            "SpeculationIndirectBranch:\tconditional force disabled\n",
            "Seccomp:\t2\n",
            "Threads:\t5\n",
            "Seccomp_filters:\t3\n",
        )
        .as_bytes();

        let want = concat!(
            "Name:\tmyproc\n",
            "TracerPid:\t0\n",
            "Speculation_Store_Bypass:\tthread vulnerable\n",
            "NoNewPrivs:\t0\n",
            "SpeculationIndirectBranch:\tconditional enabled\n",
            "Seccomp:\t0\n",
            "Threads:\t5\n",
            "Seccomp_filters:\t0\n",
        )
        .as_bytes();

        assert_eq!(run_mask(input).unwrap(), want);
    }

    #[test]
    fn test_mask_proc_pid_status_suffix_after_digits_is_dropped() {
        // Trailing garbage after the number (not expected from a real /proc)
        // is discarded in favour of a clean zero value.
        let input = b"TracerPid:\t123 extra_garbage\n";
        let want = b"TracerPid:\t0\n";
        let masked = run_mask(input).unwrap();
        assert_eq!(masked, want);
    }

    #[test]
    fn test_mask_proc_pid_status_no_final_newline_passthrough_nonmatching() {
        // An unterminated, non-masked last line is emitted exactly as read.
        let input = b"Name:\tno_nl_at_end";
        assert_eq!(run_mask(input).unwrap(), input);
    }
}
