//===-- RISCVRegisterBankInfo.cpp -------------------------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
/// \file
/// This file implements the targeting of the RegisterBankInfo class for RISC-V.
/// \todo This should be generated by TableGen.
//===----------------------------------------------------------------------===//

#include "RISCVRegisterBankInfo.h"
#include "MCTargetDesc/RISCVMCTargetDesc.h"
#include "RISCVSubtarget.h"
#include "llvm/CodeGen/MachineRegisterInfo.h"
#include "llvm/CodeGen/RegisterBank.h"
#include "llvm/CodeGen/RegisterBankInfo.h"
#include "llvm/CodeGen/TargetRegisterInfo.h"

#define GET_TARGET_REGBANK_IMPL
#include "RISCVGenRegisterBank.inc"

namespace llvm {
namespace RISCV {

// Partial mappings used by the value mappings below. Each entry is
// {StartIdx, Length, RegBank} as laid out by RegisterBankInfo::PartialMapping,
// i.e. a whole 32- or 64-bit scalar living in a single bank. The entry order
// must match PartialMappingIdx.
const RegisterBankInfo::PartialMapping PartMappings[] = {
    {0, 32, GPRBRegBank},
    {0, 64, GPRBRegBank},
    {0, 32, FPRBRegBank},
    {0, 64, FPRBRegBank},
};

// Indices into PartMappings; keep in sync with the array above.
enum PartialMappingIdx {
  PMI_GPRB32 = 0,
  PMI_GPRB64 = 1,
  PMI_FPRB32 = 2,
  PMI_FPRB64 = 3,
};

// Value mappings handed out by getInstrMapping. Each bank/size combination is
// repeated three times in a row so that a pointer to the first entry of a group
// can be passed to getInstructionMapping together with an operand count, acting
// as a per-operand array for instructions with up to three operands.
const RegisterBankInfo::ValueMapping ValueMappings[] = {
    // Invalid value mapping.
    {nullptr, 0},
    // Maximum 3 GPR operands; 32 bit.
    {&PartMappings[PMI_GPRB32], 1},
    {&PartMappings[PMI_GPRB32], 1},
    {&PartMappings[PMI_GPRB32], 1},
    // Maximum 3 GPR operands; 64 bit.
    {&PartMappings[PMI_GPRB64], 1},
    {&PartMappings[PMI_GPRB64], 1},
    {&PartMappings[PMI_GPRB64], 1},
    // Maximum 3 FPR operands; 32 bit.
    {&PartMappings[PMI_FPRB32], 1},
    {&PartMappings[PMI_FPRB32], 1},
    {&PartMappings[PMI_FPRB32], 1},
    // Maximum 3 FPR operands; 64 bit.
    {&PartMappings[PMI_FPRB64], 1},
    {&PartMappings[PMI_FPRB64], 1},
    {&PartMappings[PMI_FPRB64], 1},
};

// Starting offset in ValueMappings of each three-entry group. InvalidIdx names
// the leading {nullptr, 0} entry. These values are hand-computed from the
// array layout above and must be updated if entries are added or reordered.
enum ValueMappingIdx {
  InvalidIdx = 0,
  GPRB32Idx = 1,
  GPRB64Idx = 4,
  FPRB32Idx = 7,
  FPRB64Idx = 10,
};
} // namespace RISCV
} // namespace llvm

using namespace llvm;

RISCVRegisterBankInfo::RISCVRegisterBankInfo(unsigned HwMode)
    : RISCVGenRegisterBankInfo(HwMode) {}

/// Return the register bank that register class \p RC belongs to.
///
/// Scalar integer classes map to GPRB, scalar floating-point classes to FPRB,
/// and vector register (group) classes, including the mask/V0 variants, map to
/// VRB. Any other class is a programming error. \p Ty is not consulted.
const RegisterBank &
RISCVRegisterBankInfo::getRegBankFromRegClass(const TargetRegisterClass &RC,
                                              LLT Ty) const {
  unsigned BankID;
  switch (RC.getID()) {
  // Integer register classes.
  case RISCV::GPRRegClassID:
  case RISCV::GPRF16RegClassID:
  case RISCV::GPRF32RegClassID:
  case RISCV::GPRNoX0RegClassID:
  case RISCV::GPRNoX0X2RegClassID:
  case RISCV::GPRJALRRegClassID:
  case RISCV::GPRTCRegClassID:
  case RISCV::GPRC_and_GPRTCRegClassID:
  case RISCV::GPRCRegClassID:
  case RISCV::GPRC_and_SR07RegClassID:
  case RISCV::SR07RegClassID:
  case RISCV::SPRegClassID:
  case RISCV::GPRX0RegClassID:
    BankID = RISCV::GPRBRegBankID;
    break;
  // Scalar floating-point register classes.
  case RISCV::FPR64RegClassID:
  case RISCV::FPR16RegClassID:
  case RISCV::FPR32RegClassID:
  case RISCV::FPR64CRegClassID:
  case RISCV::FPR32CRegClassID:
    BankID = RISCV::FPRBRegBankID;
    break;
  // Vector register and register-group classes.
  case RISCV::VMRegClassID:
  case RISCV::VRRegClassID:
  case RISCV::VRNoV0RegClassID:
  case RISCV::VRM2RegClassID:
  case RISCV::VRM2NoV0RegClassID:
  case RISCV::VRM4RegClassID:
  case RISCV::VRM4NoV0RegClassID:
  case RISCV::VMV0RegClassID:
  case RISCV::VRM2_with_sub_vrm1_0_in_VMV0RegClassID:
  case RISCV::VRM4_with_sub_vrm1_0_in_VMV0RegClassID:
  case RISCV::VRM8RegClassID:
  case RISCV::VRM8NoV0RegClassID:
  case RISCV::VRM8_with_sub_vrm1_0_in_VMV0RegClassID:
    BankID = RISCV::VRBRegBankID;
    break;
  default:
    llvm_unreachable("Register class not supported");
  }
  return getRegBank(BankID);
}

/// Return the FPRB value mapping for a scalar of \p Size bits.
/// Only 32 and 64 are expected (asserted). With assertions disabled, any size
/// other than 64 yields the 32-bit mapping.
static const RegisterBankInfo::ValueMapping *getFPValueMapping(unsigned Size) {
  assert(Size == 32 || Size == 64);
  if (Size == 64)
    return &RISCV::ValueMappings[RISCV::FPRB64Idx];
  return &RISCV::ValueMappings[RISCV::FPRB32Idx];
}

/// Returns whether opcode \p Opc is a pre-isel generic floating-point opcode,
/// having only floating-point operands.
/// FIXME: this is copied from target AArch64. Needs some code refactor here to
/// put this function in GlobalISel/Utils.cpp.
static bool isPreISelGenericFloatingPointOpcode(unsigned Opc) {
  switch (Opc) {
  default:
    return false;
  // Basic arithmetic and sign manipulation.
  case TargetOpcode::G_FADD:
  case TargetOpcode::G_FSUB:
  case TargetOpcode::G_FMUL:
  case TargetOpcode::G_FMA:
  case TargetOpcode::G_FDIV:
  case TargetOpcode::G_FNEG:
  case TargetOpcode::G_FABS:
  case TargetOpcode::G_FCOPYSIGN:
  case TargetOpcode::G_FSQRT:
  // Constants and FP-to-FP format conversions.
  case TargetOpcode::G_FCONSTANT:
  case TargetOpcode::G_FPEXT:
  case TargetOpcode::G_FPTRUNC:
  // Rounding.
  case TargetOpcode::G_FCEIL:
  case TargetOpcode::G_FFLOOR:
  case TargetOpcode::G_FNEARBYINT:
  case TargetOpcode::G_FRINT:
  case TargetOpcode::G_INTRINSIC_TRUNC:
  case TargetOpcode::G_INTRINSIC_ROUND:
  case TargetOpcode::G_INTRINSIC_ROUNDEVEN:
  // Transcendental functions.
  case TargetOpcode::G_FCOS:
  case TargetOpcode::G_FSIN:
  case TargetOpcode::G_FLOG10:
  case TargetOpcode::G_FLOG:
  case TargetOpcode::G_FLOG2:
  case TargetOpcode::G_FEXP:
  // Minimum / maximum.
  case TargetOpcode::G_FMAXNUM:
  case TargetOpcode::G_FMINNUM:
  case TargetOpcode::G_FMAXIMUM:
  case TargetOpcode::G_FMINIMUM:
    return true;
  }
}

/// Returns true if \p MI is known to work on floating-point values. That is,
/// it is a generic FP opcode, or it is a COPY whose destination register is
/// already on the FPR bank.
// TODO: Make this more like AArch64?
bool RISCVRegisterBankInfo::hasFPConstraints(
    const MachineInstr &MI, const MachineRegisterInfo &MRI,
    const TargetRegisterInfo &TRI) const {
  const unsigned Opc = MI.getOpcode();
  if (isPreISelGenericFloatingPointOpcode(Opc))
    return true;

  // A COPY may be feeding floating-point instructions; trust the bank of its
  // destination. The bank is only queried for COPYs (short-circuit).
  return Opc == TargetOpcode::COPY &&
         getRegBank(MI.getOperand(0).getReg(), MRI, TRI) ==
             &RISCV::FPRBRegBank;
}

bool RISCVRegisterBankInfo::onlyUsesFP(const MachineInstr &MI,
                                       const MachineRegisterInfo &MRI,
                                       const TargetRegisterInfo &TRI) const {
  switch (MI.getOpcode()) {
  case TargetOpcode::G_FPTOSI:
  case TargetOpcode::G_FPTOUI:
  case TargetOpcode::G_FCMP:
    return true;
  default:
    break;
  }

  return hasFPConstraints(MI, MRI, TRI);
}

/// Returns true if \p MI is known to produce a floating-point result, so that
/// its definition is best placed on FPRB.
bool RISCVRegisterBankInfo::onlyDefinesFP(const MachineInstr &MI,
                                          const MachineRegisterInfo &MRI,
                                          const TargetRegisterInfo &TRI) const {
  // Integer-to-FP conversions always define an FP value even though their
  // source is an integer.
  const unsigned Opc = MI.getOpcode();
  if (Opc == TargetOpcode::G_SITOFP || Opc == TargetOpcode::G_UITOFP)
    return true;

  return hasFPConstraints(MI, MRI, TRI);
}

/// Returns true if at least one non-debug user of \p Def satisfies
/// onlyUsesFP. Stops at the first such user.
bool RISCVRegisterBankInfo::anyUseOnlyUseFP(
    Register Def, const MachineRegisterInfo &MRI,
    const TargetRegisterInfo &TRI) const {
  for (const MachineInstr &User : MRI.use_nodbg_instructions(Def)) {
    if (onlyUsesFP(User, MRI, TRI))
      return true;
  }
  return false;
}

/// Compute the default register bank mapping for \p MI.
///
/// Non-generic instructions and G_PHI first try getInstrMappingImpl and use its
/// result when it is valid. Otherwise:
///  - integer arithmetic, pointer, truncation/extension and extending loads map
///    every operand to GPRB;
///  - simple FP arithmetic maps every operand to FPRB sized by the result;
///  - G_LOAD, G_STORE, G_SELECT and G_IMPLICIT_DEF choose between GPRB and FPRB
///    by looking at neighbouring FP users/definitions;
///  - when GPRs are 32 bits wide, 64-bit scalars in loads, stores, selects,
///    merges and unmerges go to FPRB (asserted to require the D extension);
///  - everything else maps generic FP opcodes to FPRB and other typed register
///    operands to GPRB.
const RegisterBankInfo::InstructionMapping &
RISCVRegisterBankInfo::getInstrMapping(const MachineInstr &MI) const {
  const unsigned Opc = MI.getOpcode();

  // Try the default logic for non-generic instructions that are either copies
  // or already have some operands assigned to banks.
  if (!isPreISelGenericOpcode(Opc) || Opc == TargetOpcode::G_PHI) {
    const InstructionMapping &Mapping = getInstrMappingImpl(MI);
    if (Mapping.isValid())
      return Mapping;
  }

  const MachineFunction &MF = *MI.getParent()->getParent();
  const MachineRegisterInfo &MRI = MF.getRegInfo();
  const TargetSubtargetInfo &STI = MF.getSubtarget();
  const TargetRegisterInfo &TRI = *STI.getRegisterInfo();

  unsigned GPRSize = getMaximumSize(RISCV::GPRBRegBankID);
  assert((GPRSize == 32 || GPRSize == 64) && "Unexpected GPR size");

  unsigned NumOperands = MI.getNumOperands();
  // Points at a run of identical GPR mappings, usable as a per-operand array.
  const ValueMapping *GPRValueMapping =
      &RISCV::ValueMappings[GPRSize == 64 ? RISCV::GPRB64Idx
                                          : RISCV::GPRB32Idx];

  // Instructions whose operands all share a single mapping.
  switch (Opc) {
  case TargetOpcode::G_ADD:
  case TargetOpcode::G_SUB:
  case TargetOpcode::G_SHL:
  case TargetOpcode::G_ASHR:
  case TargetOpcode::G_LSHR:
  case TargetOpcode::G_AND:
  case TargetOpcode::G_OR:
  case TargetOpcode::G_XOR:
  case TargetOpcode::G_MUL:
  case TargetOpcode::G_SDIV:
  case TargetOpcode::G_SREM:
  case TargetOpcode::G_SMULH:
  case TargetOpcode::G_SMAX:
  case TargetOpcode::G_SMIN:
  case TargetOpcode::G_UDIV:
  case TargetOpcode::G_UREM:
  case TargetOpcode::G_UMULH:
  case TargetOpcode::G_UMAX:
  case TargetOpcode::G_UMIN:
  case TargetOpcode::G_PTR_ADD:
  case TargetOpcode::G_PTRTOINT:
  case TargetOpcode::G_INTTOPTR:
  case TargetOpcode::G_TRUNC:
  case TargetOpcode::G_ANYEXT:
  case TargetOpcode::G_SEXT:
  case TargetOpcode::G_ZEXT:
  case TargetOpcode::G_SEXTLOAD:
  case TargetOpcode::G_ZEXTLOAD:
    return getInstructionMapping(DefaultMappingID, /*Cost=*/1, GPRValueMapping,
                                 NumOperands);
  case TargetOpcode::G_FADD:
  case TargetOpcode::G_FSUB:
  case TargetOpcode::G_FMUL:
  case TargetOpcode::G_FDIV:
  case TargetOpcode::G_FABS:
  case TargetOpcode::G_FNEG:
  case TargetOpcode::G_FSQRT:
  case TargetOpcode::G_FMAXNUM:
  case TargetOpcode::G_FMINNUM: {
    LLT Ty = MRI.getType(MI.getOperand(0).getReg());
    return getInstructionMapping(DefaultMappingID, /*Cost=*/1,
                                 getFPValueMapping(Ty.getSizeInBits()),
                                 NumOperands);
  }
  case TargetOpcode::G_IMPLICIT_DEF: {
    Register Dst = MI.getOperand(0).getReg();
    auto Mapping = GPRValueMapping;
    // FIXME: May need to do a better job determining when to use FPRB.
    // For example, the look through COPY case:
    // %0:_(s32) = G_IMPLICIT_DEF
    // %1:_(s32) = COPY %0
    // $f10_d = COPY %1(s32)
    if (anyUseOnlyUseFP(Dst, MRI, TRI))
      Mapping = getFPValueMapping(MRI.getType(Dst).getSizeInBits());
    return getInstructionMapping(DefaultMappingID, /*Cost=*/1, Mapping,
                                 NumOperands);
  }
  }

  // Instructions that need a per-operand mapping. Entries left null are not
  // given a mapping.
  SmallVector<const ValueMapping *, 4> OpdsMapping(NumOperands);

  switch (Opc) {
  case TargetOpcode::G_LOAD: {
    LLT Ty = MRI.getType(MI.getOperand(0).getReg());
    OpdsMapping[0] = GPRValueMapping;
    OpdsMapping[1] = GPRValueMapping;
    // Use FPR64 for s64 loads on rv32.
    if (GPRSize == 32 && Ty.getSizeInBits() == 64) {
      assert(MF.getSubtarget<RISCVSubtarget>().hasStdExtD());
      OpdsMapping[0] = getFPValueMapping(Ty.getSizeInBits());
      break;
    }

    // Check if that load feeds fp instructions.
    // In that case, we want the default mapping to be on FPR
    // instead of blind map every scalar to GPR.
    if (anyUseOnlyUseFP(MI.getOperand(0).getReg(), MRI, TRI))
      // If we have at least one direct use in a FP instruction,
      // assume this was a floating point load in the IR. If it was
      // not, we would have had a bitcast before reaching that
      // instruction.
      OpdsMapping[0] = getFPValueMapping(Ty.getSizeInBits());

    break;
  }
  case TargetOpcode::G_STORE: {
    LLT Ty = MRI.getType(MI.getOperand(0).getReg());
    OpdsMapping[0] = GPRValueMapping;
    OpdsMapping[1] = GPRValueMapping;
    // Use FPR64 for s64 stores on rv32.
    if (GPRSize == 32 && Ty.getSizeInBits() == 64) {
      assert(MF.getSubtarget<RISCVSubtarget>().hasStdExtD());
      OpdsMapping[0] = getFPValueMapping(Ty.getSizeInBits());
      break;
    }

    // Store the value from FPR when it was produced by an FP instruction.
    const MachineInstr *DefMI = MRI.getVRegDef(MI.getOperand(0).getReg());
    if (onlyDefinesFP(*DefMI, MRI, TRI))
      OpdsMapping[0] = getFPValueMapping(Ty.getSizeInBits());
    break;
  }
  case TargetOpcode::G_SELECT: {
    LLT Ty = MRI.getType(MI.getOperand(0).getReg());

    // Try to minimize the number of copies. If we have more floating point
    // constrained values than not, then we'll put everything on FPR. Otherwise,
    // everything has to be on GPR.
    unsigned NumFP = 0;

    // Use FPR64 for s64 select on rv32.
    if (GPRSize == 32 && Ty.getSizeInBits() == 64) {
      // Same requirement as the other s64-on-rv32 cases above and below.
      assert(MF.getSubtarget<RISCVSubtarget>().hasStdExtD());
      NumFP = 3;
    } else {
      // Check if the uses of the result always produce floating point values.
      //
      // For example:
      //
      // %z = G_SELECT %cond %x %y
      // fpr = G_FOO %z ...
      if (anyUseOnlyUseFP(MI.getOperand(0).getReg(), MRI, TRI))
        ++NumFP;

      // Check if the defs of the source values always produce floating point
      // values.
      //
      // For example:
      //
      // %x = G_SOMETHING_ALWAYS_FLOAT %a ...
      // %z = G_SELECT %cond %x %y
      //
      // Also check whether or not the sources have already been decided to be
      // FPR. Keep track of this.
      //
      // This doesn't check the condition, since the condition is always an
      // integer.
      for (unsigned Idx = 2; Idx < 4; ++Idx) {
        Register VReg = MI.getOperand(Idx).getReg();
        const MachineInstr *DefMI = MRI.getVRegDef(VReg);
        if (getRegBank(VReg, MRI, TRI) == &RISCV::FPRBRegBank ||
            onlyDefinesFP(*DefMI, MRI, TRI))
          ++NumFP;
      }
    }

    // Condition operand is always GPR.
    OpdsMapping[1] = GPRValueMapping;

    const ValueMapping *Mapping = GPRValueMapping;
    if (NumFP >= 2)
      Mapping = getFPValueMapping(Ty.getSizeInBits());

    OpdsMapping[0] = OpdsMapping[2] = OpdsMapping[3] = Mapping;
    break;
  }
  case TargetOpcode::G_FPTOSI:
  case TargetOpcode::G_FPTOUI:
  case RISCV::G_FCLASS: {
    LLT Ty = MRI.getType(MI.getOperand(1).getReg());
    OpdsMapping[0] = GPRValueMapping;
    OpdsMapping[1] = getFPValueMapping(Ty.getSizeInBits());
    break;
  }
  case TargetOpcode::G_SITOFP:
  case TargetOpcode::G_UITOFP: {
    LLT Ty = MRI.getType(MI.getOperand(0).getReg());
    OpdsMapping[0] = getFPValueMapping(Ty.getSizeInBits());
    OpdsMapping[1] = GPRValueMapping;
    break;
  }
  case TargetOpcode::G_FCMP: {
    LLT Ty = MRI.getType(MI.getOperand(2).getReg());

    unsigned Size = Ty.getSizeInBits();
    assert((Size == 32 || Size == 64) && "Unsupported size for G_FCMP");

    // Operand 1 is the predicate and gets no register mapping.
    OpdsMapping[0] = GPRValueMapping;
    OpdsMapping[2] = OpdsMapping[3] = getFPValueMapping(Size);
    break;
  }
  case TargetOpcode::G_MERGE_VALUES: {
    // Use FPR64 for s64 merge on rv32.
    LLT Ty = MRI.getType(MI.getOperand(0).getReg());
    if (GPRSize == 32 && Ty.getSizeInBits() == 64) {
      assert(MF.getSubtarget<RISCVSubtarget>().hasStdExtD());
      OpdsMapping[0] = getFPValueMapping(Ty.getSizeInBits());
      OpdsMapping[1] = GPRValueMapping;
      OpdsMapping[2] = GPRValueMapping;
    }
    break;
  }
  case TargetOpcode::G_UNMERGE_VALUES: {
    // Use FPR64 for s64 unmerge on rv32.
    LLT Ty = MRI.getType(MI.getOperand(2).getReg());
    if (GPRSize == 32 && Ty.getSizeInBits() == 64) {
      assert(MF.getSubtarget<RISCVSubtarget>().hasStdExtD());
      OpdsMapping[0] = GPRValueMapping;
      OpdsMapping[1] = GPRValueMapping;
      OpdsMapping[2] = getFPValueMapping(Ty.getSizeInBits());
    }
    break;
  }
  default: {
    // By default map all scalars to GPR, except that every typed register
    // operand of a generic FP opcode goes to FPR. The opcode test does not
    // depend on the operand, so evaluate it once.
    const bool IsFPOpcode = isPreISelGenericFloatingPointOpcode(Opc);
    for (unsigned Idx = 0; Idx < NumOperands; ++Idx) {
      const MachineOperand &MO = MI.getOperand(Idx);
      if (!MO.isReg() || !MO.getReg())
        continue;
      LLT Ty = MRI.getType(MO.getReg());
      // Register operands without a valid LLT get no mapping.
      if (!Ty.isValid())
        continue;

      OpdsMapping[Idx] = IsFPOpcode ? getFPValueMapping(Ty.getSizeInBits())
                                    : GPRValueMapping;
    }
    break;
  }
  }

  return getInstructionMapping(DefaultMappingID, /*Cost=*/1,
                               getOperandsMapping(OpdsMapping), NumOperands);
}
