/************************************************************************
 * Derived from the BSD3-licensed
 * LAPACK routine (version 3.7.0) --
 *     Univ. of Tennessee, Univ. of California Berkeley,
 *     Univ. of Colorado Denver and NAG Ltd..
 *     December 2016
 * Copyright (C) 2024-2025 Advanced Micro Devices, Inc. All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions
 * are met:
 *
 * 1. Redistributions of source code must retain the above copyright
 *    notice, this list of conditions and the following disclaimer.
 *
 * 2. Redistributions in binary form must reproduce the above copyright
 *    notice, this list of conditions and the following disclaimer in the
 *    documentation and/or other materials provided with the distribution.
 *
 * THIS SOFTWARE IS PROVIDED BY THE AUTHOR AND CONTRIBUTORS ``AS IS'' AND
 * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
 * ARE DISCLAIMED.  IN NO EVENT SHALL THE AUTHOR OR CONTRIBUTORS BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
 * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS
 * OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION)
 * HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
 * LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY
 * OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF
 * SUCH DAMAGE.
 * *************************************************************************/

#pragma once

#include "lapack_device_functions.hpp"
#include "rocblas.hpp"
#include "rocsolver_run_specialized_kernels.hpp"

ROCSOLVER_BEGIN_NAMESPACE

/*************************************************************
    Templated kernels are instantiated in separate cpp
    files in order to improve compilation times and reduce
    the library size.
*************************************************************/

/** LARF_LEFT_KERNEL_SMALL applies the elementary reflector defined by (x, tau)
    from the left to the m-by-n matrix A of every batch instance (the side = 'L'
    case of LAPACK xLARF, where H = I - tau * x * x^H).
    NOTE(review): the exact conjugation convention depends on which operand
    dot<..., true, ...> conjugates -- confirm against lapack_device_functions.hpp.

    For each column j handled by this block:
        s       = dot(x, A(:,j))          (dot with its conjugation flag set)
        A(:,j) += (-tau * conj(s)) * x

    Launch layout (as configured by larf_run_small):
      - hipBlockIdx_x : batch instance;
      - hipBlockIdx_y : first column cid; the block processes columns
                        cid, cid + LARF_SSKER_BLOCKS, ...;
      - hipThreadIdx_x: row offset; MAX_THDS threads cooperate over the m rows.
    Preconditions: blockDim.x == MAX_THDS and m <= LARF_SSKER_MAX_DIM (x is
    staged in a fixed-size shared array). **/
template <int MAX_THDS, typename T, typename I, typename U>
ROCSOLVER_KERNEL void __launch_bounds__(MAX_THDS)
    larf_left_kernel_small(const I m,
                           const I n,
                           U xx,
                           const rocblas_stride shiftX,
                           const I incX,
                           const rocblas_stride strideX,
                           const T* tauA,
                           const rocblas_stride strideP,
                           U AA,
                           const rocblas_stride shiftA,
                           const I lda,
                           const rocblas_stride strideA)
{
    const I bid = hipBlockIdx_x; // batch instance
    const I rid = hipThreadIdx_x; // row offset handled by this thread
    const I cid = hipBlockIdx_y; // first column handled by this block

    // select batch instance
    T* x = load_ptr_batch<T>(xx, bid, shiftX, strideX);
    T* A = load_ptr_batch<T>(AA, bid, shiftA, strideA);
    const T* tau = tauA + bid * strideP;

    // shared variables
    __shared__ T sval[MAX_THDS];
    __shared__ T xs[LARF_SSKER_MAX_DIM];

    // load x into shared memory; a negative incX walks x backwards (BLAS
    // convention). Offsets are computed in 64 bits so large strides cannot
    // overflow a 32-bit I.
    const rocblas_stride start = (incX > 0 ? 0 : static_cast<rocblas_stride>(m - 1) * -incX);
    for(I i = rid; i < m; i += MAX_THDS)
        xs[i] = x[start + static_cast<rocblas_stride>(i) * incX];
    __syncthreads();

    // cid is uniform across the block, so every thread executes the same
    // number of iterations and the barriers below are reached by all of them
    for(I j = cid; j < n; j += LARF_SSKER_BLOCKS)
    {
        // column j of A (64-bit offset: j * lda may exceed the range of I)
        T* Aj = A + static_cast<rocblas_stride>(j) * lda;

        // gemv: the reduction result is left in sval[0]
        dot<MAX_THDS, true, T>(rid, m, xs, 1, Aj, 1, sval);
        __syncthreads();

        // ger
        const T temp = -tau[0] * conj(sval[0]);
        for(I i = rid; i < m; i += MAX_THDS)
            Aj[i] += temp * xs[i];

        // every thread must have read sval[0] before the next iteration's
        // dot() writes into sval again
        __syncthreads();
    }
}

/** LARF_RIGHT_KERNEL_SMALL applies the elementary reflector defined by (x, tau)
    from the right to the m-by-n matrix A of every batch instance (the side = 'R'
    case of LAPACK xLARF, i.e. A := A * (I - tau * x * x^H)).

    For each row i handled by this block:
        s       = dot(x, A(i,:))          (no conjugation)
        A(i,:) += (-tau * s) * conj(x)^T

    Launch layout (as configured by larf_run_small):
      - hipBlockIdx_x : batch instance;
      - hipBlockIdx_y : first row rid; the block processes rows
                        rid, rid + LARF_SSKER_BLOCKS, ...;
      - hipThreadIdx_x: column offset; MAX_THDS threads cooperate over the n columns.
    Preconditions: blockDim.x == MAX_THDS and n <= LARF_SSKER_MAX_DIM (x is
    staged in a fixed-size shared array). **/
template <int MAX_THDS, typename T, typename I, typename U>
ROCSOLVER_KERNEL void __launch_bounds__(MAX_THDS)
    larf_right_kernel_small(const I m,
                            const I n,
                            U xx,
                            const rocblas_stride shiftX,
                            const I incX,
                            const rocblas_stride strideX,
                            const T* tauA,
                            const rocblas_stride strideP,
                            U AA,
                            const rocblas_stride shiftA,
                            const I lda,
                            const rocblas_stride strideA)
{
    const I bid = hipBlockIdx_x; // batch instance
    const I cid = hipThreadIdx_x; // column offset handled by this thread
    const I rid = hipBlockIdx_y; // first row handled by this block

    // select batch instance
    T* x = load_ptr_batch<T>(xx, bid, shiftX, strideX);
    T* A = load_ptr_batch<T>(AA, bid, shiftA, strideA);
    const T* tau = tauA + bid * strideP;

    // shared variables
    __shared__ T sval[MAX_THDS];
    __shared__ T xs[LARF_SSKER_MAX_DIM];

    // load x into shared memory; a negative incX walks x backwards (BLAS
    // convention). Offsets are computed in 64 bits so large strides cannot
    // overflow a 32-bit I.
    const rocblas_stride start = (incX > 0 ? 0 : static_cast<rocblas_stride>(n - 1) * -incX);
    for(I j = cid; j < n; j += MAX_THDS)
        xs[j] = x[start + static_cast<rocblas_stride>(j) * incX];
    __syncthreads();

    // rid is uniform across the block, so every thread executes the same
    // number of iterations and the barriers below are reached by all of them
    for(I i = rid; i < m; i += LARF_SSKER_BLOCKS)
    {
        // gemv over row i (elements spaced lda apart); result left in sval[0]
        dot<MAX_THDS, false, T>(cid, n, xs, 1, A + i, lda, sval);
        __syncthreads();

        // ger (64-bit offset: j * lda may exceed the range of I)
        const T temp = -tau[0] * sval[0];
        for(I j = cid; j < n; j += MAX_THDS)
            A[i + static_cast<rocblas_stride>(j) * lda] += temp * conj(xs[j]);

        // every thread must have read sval[0] before the next iteration's
        // dot() writes into sval again
        __syncthreads();
    }
}

/*************************************************************
    Launchers of specialized kernels
*************************************************************/

/** LARF_RUN_SMALL launches the small-size LARF kernels on the handle's stream.

    Applies the reflector defined by (x, tau) to each of the batch_count
    matrices A: from the left when side == rocblas_side_left
    (larf_left_kernel_small), otherwise from the right (larf_right_kernel_small).

    Thread-block size: the smallest of 64, 128, 256, 512 that covers the
    reflector length (m for the left side, n for the right side), or 1024
    beyond that. Grid: (batch_count, min(other dimension, LARF_SSKER_BLOCKS)).
    The caller must ensure the reflector length does not exceed
    LARF_SSKER_MAX_DIM, because the kernels stage x in a fixed-size shared array.

    Empty problems (m, n or batch_count equal to 0) return immediately, because
    a zero grid dimension is not a valid launch configuration. The function
    always returns rocblas_status_success; it does not inspect the result of
    the kernel launches. **/
template <typename T, typename I, typename U>
rocblas_status larf_run_small(rocblas_handle handle,
                              const rocblas_side side,
                              const I m,
                              const I n,
                              U x,
                              const rocblas_stride shiftX,
                              const I incX,
                              const rocblas_stride strideX,
                              const T* tau,
                              const rocblas_stride strideP,
                              U A,
                              const rocblas_stride shiftA,
                              const I lda,
                              const rocblas_stride strideA,
                              const I batch_count)
{
    // quick return: nothing to update, and a zero grid dimension would be an
    // invalid launch configuration
    if(m == 0 || n == 0 || batch_count == 0)
        return rocblas_status_success;

    hipStream_t stream;
    rocblas_get_stream(handle, &stream);

    if(side == rocblas_side_left)
    {
        // x-dim: batch instances; y-dim: groups of columns of A
        dim3 grid(batch_count, min(n, LARF_SSKER_BLOCKS), 1);

        if(m <= 64)
            ROCSOLVER_LAUNCH_KERNEL((larf_left_kernel_small<64, T>), grid, dim3(64), 0, stream, m, n,
                                    x, shiftX, incX, strideX, tau, strideP, A, shiftA, lda, strideA);
        else if(m <= 128)
            ROCSOLVER_LAUNCH_KERNEL((larf_left_kernel_small<128, T>), grid, dim3(128), 0, stream, m,
                                    n, x, shiftX, incX, strideX, tau, strideP, A, shiftA, lda,
                                    strideA);
        else if(m <= 256)
            ROCSOLVER_LAUNCH_KERNEL((larf_left_kernel_small<256, T>), grid, dim3(256), 0, stream, m,
                                    n, x, shiftX, incX, strideX, tau, strideP, A, shiftA, lda,
                                    strideA);
        else if(m <= 512)
            ROCSOLVER_LAUNCH_KERNEL((larf_left_kernel_small<512, T>), grid, dim3(512), 0, stream, m,
                                    n, x, shiftX, incX, strideX, tau, strideP, A, shiftA, lda,
                                    strideA);
        else
            ROCSOLVER_LAUNCH_KERNEL((larf_left_kernel_small<1024, T>), grid, dim3(1024), 0, stream,
                                    m, n, x, shiftX, incX, strideX, tau, strideP, A, shiftA, lda,
                                    strideA);
    }
    else
    {
        // x-dim: batch instances; y-dim: groups of rows of A
        dim3 grid(batch_count, min(m, LARF_SSKER_BLOCKS), 1);

        if(n <= 64)
            ROCSOLVER_LAUNCH_KERNEL((larf_right_kernel_small<64, T>), grid, dim3(64), 0, stream, m, n,
                                    x, shiftX, incX, strideX, tau, strideP, A, shiftA, lda, strideA);
        else if(n <= 128)
            ROCSOLVER_LAUNCH_KERNEL((larf_right_kernel_small<128, T>), grid, dim3(128), 0, stream,
                                    m, n, x, shiftX, incX, strideX, tau, strideP, A, shiftA, lda,
                                    strideA);
        else if(n <= 256)
            ROCSOLVER_LAUNCH_KERNEL((larf_right_kernel_small<256, T>), grid, dim3(256), 0, stream,
                                    m, n, x, shiftX, incX, strideX, tau, strideP, A, shiftA, lda,
                                    strideA);
        else if(n <= 512)
            ROCSOLVER_LAUNCH_KERNEL((larf_right_kernel_small<512, T>), grid, dim3(512), 0, stream,
                                    m, n, x, shiftX, incX, strideX, tau, strideP, A, shiftA, lda,
                                    strideA);
        else
            ROCSOLVER_LAUNCH_KERNEL((larf_right_kernel_small<1024, T>), grid, dim3(1024), 0, stream,
                                    m, n, x, shiftX, incX, strideX, tau, strideP, A, shiftA, lda,
                                    strideA);
    }

    return rocblas_status_success;
}

/*************************************************************
    Instantiation macros
*************************************************************/

// Emits an explicit instantiation of larf_run_small for scalar type T,
// integer type I and pointer type U (U is what the kernels pass to
// load_ptr_batch to select a batch instance). The parameter list must stay in
// sync with the larf_run_small definition.
#define INSTANTIATE_LARF_SMALL(T, I, U)              \
    template rocblas_status larf_run_small<T, I, U>( \
        rocblas_handle handle,                       \
        const rocblas_side side,                     \
        const I m,                                   \
        const I n,                                   \
        U x,                                         \
        const rocblas_stride shiftX,                 \
        const I incX,                                \
        const rocblas_stride strideX,                \
        const T* tau,                                \
        const rocblas_stride strideP,                \
        U A,                                         \
        const rocblas_stride shiftA,                 \
        const I lda,                                 \
        const rocblas_stride strideA,                \
        const I batch_count)

ROCSOLVER_END_NAMESPACE
