#!/bin/bash
# Abort as soon as a command fails outside of a conditional context.
set -e
# Keep command tracing (xtrace) switched off.
set +x
# Copyright (C) 2022-2023 Mo Zhou <lumin@debian.org>
# MIT/Expat License.
#
# Nvidia CUDA Deep Neural Network Library installer script (Debian Specific)
#
# XXX: you can browse the following directory for updating the
#  URL_{amd64,ppc64el,arm64} shell variables below:
#
#   NEW: https://developer.download.nvidia.com/compute/cudnn/redist/cudnn/
#   OLD: https://developer.download.nvidia.com/compute/redist/cudnn/
#
#  More references when Nvidia breaks the above link:
#
#   https://github.com/archlinux/svntogit-community/blob/packages/cudnn/trunk/PKGBUILD
#   https://gitlab.archlinux.org/archlinux/packaging/packages/cudnn/-/blob/main/PKGBUILD
#   https://github.com/pytorch/builder/blob/main/common/install_cuda.sh
#
# XXX: [maintainer/user notes]
#
#  In case you want to upgrade to a newer version of cuDNN, you can just
#  browse the links above and find the binary tarballs you want.
#  Then copy the URLs and update the corresponding URL_* variables below,
#  together with the matching SHA512_* checksums. The rest of the shell
#  code can remain unchanged as long as upstream does not alter the file
#  paths inside the tarball.
#
#  To test whether the updated links work, you can just copy and paste
#  the commands shown at the end of the usage() function.
#  You don't have to test all three actions {-d, -u, -p}: as long as
#  the update (-u) action works without issue, the download (-d) is
#  good as well.
#
#  For a more thorough testing of install/purge, use piuparts instead.

## configs ####################################################################
# Upstream release handled by this script.
CUDA_VER=12
CUDNN_VER=8.9.2.26

# Per-architecture tarball URL and the SHA-512 digest expected for it.
URL_amd64='https://developer.download.nvidia.com/compute/cudnn/redist/cudnn/linux-x86_64/cudnn-linux-x86_64-8.9.2.26_cuda12-archive.tar.xz'
SHA512_amd64='e46a6abc348fe814f33c29d1596a4ec5a8041b61d9dc2b083b17f153a693a2493886288a43b3d8f682688da6d3a6f0a769cbe78b2b5570cf8009a195f7fe0ff4'
URL_ppc64el='https://developer.download.nvidia.com/compute/cudnn/redist/cudnn/linux-ppc64le/cudnn-linux-ppc64le-8.9.2.26_cuda12-archive.tar.xz'
SHA512_ppc64el='cc64b4af704e5c3e3b322f9bbd93b7ed4907616bf2b2d04ba48c69f48344eef6a5524028020f8fef6e50d86ab69c81d7d2208188a13ad462d8bc43dd8a5fedf6'
URL_arm64='https://developer.download.nvidia.com/compute/cudnn/redist/cudnn/linux-sbsa/cudnn-linux-sbsa-8.9.2.26_cuda12-archive.tar.xz'
SHA512_arm64='742f72bccbc58359a59c82ca9e8443ccd92dfff26b90e411bf392715b4180c6238724c674d715b032ee380b84c20519b3246aa4fd9cc0b49f96a93e1caf3beec'

# Installation defaults (see --prefix and --filelist).
PREFIX=/usr
FILELIST=

# Directory the tarball is downloaded into (see --tmpdir); the flag records
# whether the user replaced the automatically created directory.
TMPDIR=$(mktemp -d)
TMPDIR_IS_OVERRIDEN=0

# Debian architecture name and multiarch triplet of the host
# (see --arch and --multiarch).
ARCH=$(dpkg-architecture -qDEB_HOST_ARCH)
MULTIARCH=$(dpkg-architecture -qDEB_HOST_MULTIARCH)

## usage ######################################################################
# Print the command-line help text on stdout.
# Globals read: TMPDIR, FILELIST, CUDNN_VER, CUDA_VER.
# The architecture defaults shown are queried from dpkg-architecture each
# time the function runs. "$0" is quoted so that a script path containing
# whitespace is still reduced to its base name correctly.
usage () {
    cat << EOF
Usage: $(basename -- "$0") <-d|-u|-p|-h> [--prefix <prefix>] [--tmpdir <tmpdir>] ...
Arguments:
 -d|--download           download only (default: 0)
 -u|--update             update cudnn installation (default: 0)
 -p|--purge              purge cudnn installation (default: 0)
 -h|--help               display this help message
 --arch <arch>           override architecture (default: $(dpkg-architecture -qDEB_HOST_ARCH))
 --multiarch <multiarch> override multiarch triplet (default: $(dpkg-architecture -qDEB_HOST_MULTIARCH))
 --prefix <path>         override install prefix (default: /usr)
 --tmpdir <dir>          override temporary directory (default: ${TMPDIR})
 --filelist <path>       write installed file list (default: ${FILELIST})
Testing this script cross-architecture:
 $ ./update-nvidia-cudnn --arch amd64   --multiarch x86_64-linux-gnu      --tmpdir . --prefix fake {-d,-u,-p}
 $ ./update-nvidia-cudnn --arch ppc64el --multiarch powerpc64le-linux-gnu --tmpdir . --prefix fake {-d,-u,-p}
 $ ./update-nvidia-cudnn --arch arm64   --multiarch aarch64-linux-gnu     --tmpdir . --prefix fake {-d,-u,-p}
Version: cuDNN ${CUDNN_VER} for CUDA ${CUDA_VER}
EOF
}

## argument parsing ###########################################################
DOWNLOAD_ONLY=0
DO_UPDATE=0
DO_PURGE=0

# Parse the command line and store the result in global variables.
# Arguments: the script's positional parameters ("$@").
# Globals written: DOWNLOAD_ONLY, DO_UPDATE, DO_PURGE, PREFIX, ARCH,
#                  MULTIARCH, TMPDIR, TMPDIR_IS_OVERRIDEN, FILELIST.
# Exits 0 after printing the help for -h/--help, and 1 for an unknown
# argument or for an option whose value is missing.
parse_args () {
    while [[ $# -gt 0 ]]; do
        case "$1" in
            -d|--download)
                DOWNLOAD_ONLY=1; shift;;
            -u|--update)
                DO_UPDATE=1; shift;;
            -p|--purge)
                DO_PURGE=1; shift;;
            --prefix|--arch|--multiarch|--tmpdir|--filelist)
                # Without this check the second shift would fail and, under
                # "set -e", terminate the script without any explanation.
                if [[ $# -lt 2 ]]; then
                    echo "$0: option $1 requires an argument" >&2
                    exit 1
                fi
                # An explicitly empty value keeps the current setting.
                if [[ -n "$2" ]]; then
                    case "$1" in
                        --prefix)    PREFIX="$2";;
                        --arch)      ARCH="$2";;
                        --multiarch) MULTIARCH="$2";;
                        --tmpdir)    TMPDIR="$2"; TMPDIR_IS_OVERRIDEN=1;;
                        --filelist)  FILELIST="$2";;
                    esac
                fi
                shift 2;;
            -h|--help)
                usage; exit 0;;
            *)
                usage; exit 1;;
        esac
    done
}

# post processing
# Select the tarball URL and checksum matching ${ARCH}.
# Globals read:    ARCH, URL_<arch>, SHA512_<arch>, DOWNLOAD_ONLY, DO_UPDATE.
# Globals written: URL, SHA512.
# An unsupported architecture is always reported on stderr; it is fatal
# (exit 1) only when a download or update was requested, because purging
# and the help text do not need a tarball.
select_tarball () {
    case "${ARCH}" in
        amd64)
            URL="${URL_amd64}"; SHA512="${SHA512_amd64}";;
        ppc64el)
            URL="${URL_ppc64el}"; SHA512="${SHA512_ppc64el}";;
        arm64)
            URL="${URL_arm64}"; SHA512="${SHA512_arm64}";;
        *)
            echo "$0: Unsupported architecture ${ARCH}" >&2
            if [[ "${DOWNLOAD_ONLY}" -ne 0 || "${DO_UPDATE}" -ne 0 ]]; then
                exit 1
            fi
            ;;
    esac
}

parse_args "$@"
select_tarball

## functions ##################################################################
# Check a file against the expected SHA-512 digest.
# Globals read: SHA512 - expected digest as a hex string.
# Arguments:    $1 - path of the file to check.
# Output:       sha512sum's own report goes to stdout.
# Exits the script with status 1 when the check does not succeed.
verify_checksum() {
    # local: do not overwrite a caller's variable of the same name
    local dest="$1"
    # "<digest>  <file>" (two spaces) is the line format sha512sum itself
    # writes, so it is the safest form to feed back into "-c".
    if ! sha512sum -c - << EOF; then
${SHA512}  ${dest}
EOF
        echo "$0: sha512 checksum mismatch. aborting." >&2
        exit 1
    fi
}

# Download the cudnn tarball into ${TMPDIR}.
# Globals read: TMPDIR - directory the tarball is stored in (created if absent).
# Arguments:    $1 - URL of the upstream tarball.
# Output:       the path of the saved file on stdout, and nothing else; all
#               progress and diagnostics go to stderr because callers capture
#               stdout with $(...).
# Returns:      0 on success, 255 when neither curl nor wget is available,
#               1 for a missing URL or a failed download.
download_cudnn () {
    local url="${1}"
    if [[ -z "${url}" ]]; then
        echo "download_cudnn(): URL not specified" >&2
        return 1
    fi
    mkdir -p -- "${TMPDIR}" || return 1
    local dest
    dest="${TMPDIR}/$(basename -- "${url}")"

    # already exists? then no downloader is needed at all
    if [[ -f "${dest}" ]]; then
        echo "Skipping download as file already exists: ${dest}" >&2
        echo "${dest}"
        return 0
    fi

    # detect downloader; the command is kept as an array so that the URL
    # and the destination reach the tool as single, unmodified arguments
    local -a cmd=() insecure=()
    if command -v curl > /dev/null; then
        cmd=(curl --continue-at - -L "${url}" --output "${dest}")
        insecure=(--insecure)
    elif command -v wget > /dev/null; then
        cmd=(wget --continue --verbose --show-progress=off "${url}" -O "${dest}")
        insecure=(--no-check-certificate)
    else
        echo "$0: Error: no downloader available." >&2
        return 255
    fi

    echo "${cmd[*]}" >&2
    if ! "${cmd[@]}" >&2; then
        # Second attempt without certificate verification, using the flag
        # that belongs to the selected tool.
        echo "$0: download failed, retrying without certificate verification" >&2
        if ! "${cmd[@]}" "${insecure[@]}" >&2; then
            # Drop the partial file: a later run would otherwise take it
            # for a finished download and skip fetching it again.
            rm -f -- "${dest}"
            echo "Download failed." >&2
            return 1
        fi
    fi
    if [[ ! -f "${dest}" ]]; then
        echo "Download failed." >&2
        return 1
    fi
    # return string
    echo "${dest}"
}

# Install the files of an extracted cudnn tarball below a prefix.
# Globals read: MULTIARCH - multiarch triplet used in the target paths,
#               FILELIST  - when non-empty, the installed paths are appended
#                           to this file, one per line.
# Arguments:    $1 - directory holding the extracted tarball,
#               $2 - installation prefix, e.g. /usr/local
# Files are classified by their base name, with the same name rules that
# purge_cudnn_fallback uses, so a directory name can never make an unrelated
# file look like a library or a header. Everything else is reported as
# skipped.
# Returns 1 when an argument is missing.
install_cudnn () {
    local src="${1}"
    local dst="${2}"
    if [[ -z "${src}" || -z "${dst}" ]]; then
        echo "install_cudnn(): invalid argument" >&2
        return 1
    fi
    local libdir="${dst}/lib/${MULTIARCH}"
    local incdir="${dst}/include/${MULTIARCH}"
    local docdir="${dst}/share/doc/nvidia-cudnn"
    local -a installed=()
    local path name
    # NUL-delimited find output keeps paths containing whitespace or glob
    # characters intact.
    while IFS= read -r -d '' path; do
        [[ "${path}" == *cudnn?txz* ]] && continue
        name="${path##*/}"
        case "${name}" in
            libcudnn*.so*)
                # shared object file
                if [[ -L "${path}" ]]; then
                    # copy the link itself instead of the file it points to
                    mkdir -p -- "${libdir}"
                    cp -av -- "${path}" "${libdir}/"
                else
                    install -vDm0644 -t "${libdir}" -- "${path}"
                fi
                installed+=( "${libdir}/${name}" )
                ;;
            libcudnn*.a)
                # static library file
                install -vDm0644 -t "${libdir}" -- "${path}"
                installed+=( "${libdir}/${name}" )
                ;;
            cudnn*.h)
                # header file
                install -vDm0644 -t "${incdir}" -- "${path}"
                installed+=( "${incdir}/${name}" )
                ;;
            NVIDIA_SLA_cuDNN_Support.txt)
                # copyright file
                install -vDm0644 -t "${docdir}" -- "${path}"
                installed+=( "${docdir}/${name}" )
                ;;
            *)
                echo "Skipped ${path}"
                ;;
        esac
    done < <(find "${src}" \( -type f -o -type l \) -print0)
    echo "$0: Number of installed files: ${#installed[@]}"
    if [[ -n "${FILELIST}" ]]; then
        mkdir -p -- "$(dirname -- "${FILELIST}")"
        if [[ ${#installed[@]} -gt 0 ]]; then
            printf '%s\n' "${installed[@]}" >> "${FILELIST}"
        fi
        echo "$0: The list of installed files is recorded at ${FILELIST}"
    fi
}

# Dispatcher for the purge action.
# Arguments:    $1 - installation prefix, used by the fallback method only.
# Globals read: FILELIST
# With a non-empty ${FILELIST} the files recorded in that list are removed.
# Otherwise the removal falls back to the hand-written name matching rules
# applied below the prefix. As long as the user did not modify /usr/lib/
# (irrelevant to /usr/local) on their own, both methods are expected to
# lead to the same result.
purge_cudnn () {
    local target_prefix="${1}"
    if [[ -z "${FILELIST}" ]]; then
        purge_cudnn_fallback "${target_prefix}"
        return
    fi
    purge_cudnn_filelist "${FILELIST}"
}

# Purge cudnn by removing the files recorded in a file list.
# Arguments: $1 - path of the file list, one installed path per line.
# Returns:   1 when the list does not exist or a removal failed, 0 otherwise.
# Exits the script with status 2, before anything is removed, when a listed
# path does not exist.
# Each line is taken literally as one path: it is never split on whitespace
# and never glob-expanded, so the list cannot make rm touch other files.
purge_cudnn_filelist () {
    local filelist="${1}"
    if [[ ! -e "${filelist}" ]]; then
        echo "purge_cudnn_filelist(): invalid argument" >&2
        return 1
    fi
    local -a entries=()
    local entry
    while IFS= read -r entry || [[ -n "${entry}" ]]; do
        if [[ -n "${entry}" ]]; then
            entries+=( "${entry}" )
        fi
    done < "${filelist}"
    # first, validate the given file list
    for entry in "${entries[@]}"; do
        if [[ ! -e "${entry}" && ! -L "${entry}" ]]; then
            echo "Error: the given file list ${filelist} is invalid: file ${entry} does not exist" >&2
            exit 2
        fi
    done
    # then, remove the listed files
    local rc=0
    for entry in "${entries[@]}"; do
        # A path can be listed more than once because installations append
        # to the list; it only has to be removed the first time.
        if [[ -e "${entry}" || -L "${entry}" ]]; then
            rm -v -- "${entry}" || rc=1
        fi
    done
    rm -v -- "${filelist}" || rc=1
    return "${rc}"
}

# Purge cudnn from the given path (prefix) using name matching rules.
# Globals read: MULTIARCH - multiarch triplet used in the searched paths.
# Arguments:    $1 - installation prefix.
# Removes, below the prefix:
#   lib/${MULTIARCH}:     libcudnn*.so* (files and symlinks), libcudnn*.a
#   include/${MULTIARCH}: cudnn*.h
#   share/doc/nvidia-cudnn/NVIDIA_SLA_cuDNN_Support.txt
# Returns: 0 when nothing is left to remove (including when nothing was
#          installed), 1 for a missing argument or a failed removal.
purge_cudnn_fallback () {
    local dst="${1}"
    if [[ -z "${dst}" ]]; then
        echo "purge_cudnn_fallback(): invalid argument" >&2
        return 1
    fi
    local libdir="${dst}/lib/${MULTIARCH}"
    local incdir="${dst}/include/${MULTIARCH}"
    local -a victims=()
    local path
    # A directory that does not exist simply contributes no candidates.
    if [[ -d "${libdir}" ]]; then
        while IFS= read -r -d '' path; do
            victims+=( "${path}" )
        done < <(find "${libdir}" \( -type f -o -type l \) -name "libcudnn*.so*" -print0)
    fi
    if [[ -d "${incdir}" ]]; then
        while IFS= read -r -d '' path; do
            victims+=( "${path}" )
        done < <(find "${incdir}" -type f -name "cudnn*.h" -print0)
    fi
    if [[ -d "${libdir}" ]]; then
        while IFS= read -r -d '' path; do
            victims+=( "${path}" )
        done < <(find "${libdir}" -type f -name "libcudnn*.a" -print0)
    fi
    victims+=( "${dst}/share/doc/nvidia-cudnn/NVIDIA_SLA_cuDNN_Support.txt" )
    local rc=0
    for path in "${victims[@]}"; do
        # -L also catches symlinks whose target has already been removed
        if [[ -e "${path}" || -L "${path}" ]]; then
            rm -v -- "${path}" || rc=1
        fi
    done
    return "${rc}"
}

# Remove the temporary directory created at start-up when this run did not
# put anything into it. A directory chosen with --tmpdir belongs to the user
# and is never touched; rmdir only ever removes an empty directory.
discard_unused_tmpdir () {
    if [[ "${TMPDIR_IS_OVERRIDEN}" -eq 0 ]]; then
        rmdir -- "${TMPDIR}" 2> /dev/null || true
    fi
}

# flag check: must select one valid action
if [[ "${DOWNLOAD_ONLY}" -eq 0 && "${DO_UPDATE}" -eq 0 && "${DO_PURGE}" -eq 0 ]]; then
    usage
    discard_unused_tmpdir
    exit 0
fi

# trigger actions (when several were given: download, then update, then purge)
if [[ "${DOWNLOAD_ONLY}" -ne 0 ]]; then
    # The tarball stays in ${TMPDIR}; its path is the last line of output.
    path=$(download_cudnn "${URL}")
    verify_checksum "${path}"
    echo "${path}"
    exit 0
elif [[ "${DO_UPDATE}" -ne 0 ]]; then
    path=$(download_cudnn "${URL}")
    verify_checksum "${path}"
    tmpdir2=$(mktemp -d)
    # Under "set -e" a failing tar or install ends the script right away;
    # the trap makes sure the extracted tree is removed in that case too.
    trap 'rm -rf -- "${tmpdir2}"' EXIT
    echo Extracting files from the downloaded tarball...
    tar xvf "${path}" -C "${tmpdir2}/"
    echo Installing the files to system directories...
    install_cudnn "${tmpdir2}" "${PREFIX}"
    rm -rf -- "${tmpdir2}"
    trap - EXIT
    # cleanup
    if [[ "${TMPDIR_IS_OVERRIDEN}" -eq 0 ]]; then
        rm -v -- "${path}"
        rmdir -v -- "${TMPDIR}"
    fi
elif [[ "${DO_PURGE}" -ne 0 ]]; then
    echo "Purging cuDNN installation from ${PREFIX}"
    # Purging is best effort: a non-zero return must not fail the script.
    purge_cudnn "${PREFIX}" || true
    discard_unused_tmpdir
fi
