// Copyright © Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier:  MIT

#include "PlatformUtils.hpp"

#if defined(__linux__)

#include "HipdnnException.hpp"
#include <dlfcn.h>

namespace hipdnn_backend::platform_utilities
{

std::filesystem::path getCurrentModuleDirectory()
{
    std::filesystem::path modulePath;

    Dl_info info;
    if(dladdr(reinterpret_cast<void const*>(&getCurrentModuleDirectory), &info) != 0
       && info.dli_fname != nullptr && info.dli_fname[0] != '\0')
    {
        modulePath = std::filesystem::path(info.dli_fname).parent_path();
    }
    else
    {
        throw HipdnnException(HIPDNN_STATUS_INTERNAL_ERROR, "Failed to get module file name.");
    }

    return std::filesystem::weakly_canonical(std::filesystem::absolute(modulePath));
}

PluginLibHandle openLibrary(const std::filesystem::path& libraryPath)
{
// Use RTLD_LAZY because librocblas.so currently contains an undefined symbol
// that is not resolved at runtime, leading to dlopen failure with RTLD_NOW.
// When this is fixed in rocblas, we should ideally switch back to RTLD_NOW.
#if __has_feature(address_sanitizer)
    // Address Sanitizer does not support RTLD_DEEPBIND, so we use RTLD_LAZY only
    PluginLibHandle handle = dlopen(libraryPath.string().c_str(), RTLD_LAZY);
#else
    PluginLibHandle handle = dlopen(libraryPath.string().c_str(), RTLD_LAZY | RTLD_DEEPBIND);
#endif

    if(handle == nullptr)
    {
        const char* error = dlerror();
        throw HipdnnException(HIPDNN_STATUS_BAD_PARAM,
                              "Failed to load library: " + libraryPath.string() + " (Error: "
                                  + (error != nullptr ? std::string(error) : "Unknown error")
                                  + ")");
    }

    return handle;
}

void closeLibrary(PluginLibHandle handle)
{
    dlclose(handle);
}

void* getSymbol(PluginLibHandle handle, const char* symbolName)
{
    void* symbol = dlsym(handle, symbolName);
    if(symbol == nullptr)
    {
        const char* error = dlerror();
        throw HipdnnException(HIPDNN_STATUS_PLUGIN_ERROR,
                              "Failed to get symbol: " + std::string(symbolName) + " (Error: "
                                  + (error != nullptr ? error : "Unknown error") + ")");
    }
    return symbol;
}

}

#endif // defined(__linux__)
