/-
Copyright (c) 2022 Microsoft Corporation. All rights reserved.
Released under Apache 2.0 license as described in the file LICENSE.
Authors: Leonardo de Moura, Joachim Breitner
-/
module

prelude
public import Lean.Elab.PreDefinition.Basic
import Lean.Elab.PreDefinition.Eqns
import Lean.Meta.Tactic.Apply
import Lean.Meta.Tactic.Split
public import Lean.Meta.Tactic.Simp.Types
import Lean.Meta.Tactic.Simp.Main
import Lean.Meta.Tactic.Simp.BuiltinSimprocs

/-!
This module is responsible for proving the unfolding equation for functions defined
by well-founded recursion. It uses `WellFounded.fix_eq`, and then has to undo
the changes to matchers that `WF.Fix` did using `MatcherApp.addArg`.

This is done using a single-pass `simp` traversal of the expression that looks
for expressions that were modified that way, and rewrites them back using the
rather specialized `_arg_pusher` theorem that is generated by `mkMatchArgPusher`.
-/

namespace Lean.Elab.WF
open Meta
open Eqns

/--
Rewrites the left-hand side of an equality goal `e x = rhs` whose left-hand side
delta-reduces to a saturated `WellFounded.fix` application, using `WellFounded.fix_eq`.

The goal `mvarId` is assigned `Eq.trans h ?new`, where `h` is the instance of
`WellFounded.fix_eq` obtained from the unfolded left-hand side, and the goal `?new`
of type `F x (fun y _ => e y) = rhs` is returned.

Throws an error if the left-hand side cannot be delta-reduced, or if the result is not
an application of `WellFounded.fix` to exactly six arguments. The target is expected to
be an equality; anything else hits `unreachable!`.
-/
def rwFixEq (mvarId : MVarId) : MetaM MVarId := mvarId.withContext do
  let target ← mvarId.getType'
  let some (_, lhs, rhs) := target.eq? | unreachable!

  -- lhs should be an application of the declNameNonrec, which unfolds to an
  -- application of fix in one step
  let some lhs' ← delta? lhs | throwError "rwFixEq: cannot delta-reduce {lhs}"
  -- `throwTacticEx` is an ordinary function (not an interpolating macro like `throwError`),
  -- so the message has to be built with `m!` for `{lhs'}` to be substituted.
  let_expr WellFounded.fix _α _C _r _hwf F x := lhs'
    | throwTacticEx `rwFixEq mvarId m!"expected saturated fixed-point application in {lhs'}"
  -- `fix_eq` takes the same universe levels and the same arguments as the `fix` application.
  let h := mkAppN (mkConst ``WellFounded.fix_eq lhs'.getAppFn.constLevels!) lhs'.getAppArgs

  -- We used to just rewrite with `fix_eq` and continue with whatever RHS that produces, but that
  -- would include more copies of `fix` resulting in large and confusing terms.
  -- Instead we manually construct the new term in terms of the current functions,
  -- which should be headed by the `declNameNonRec`, and should be defeq to the expected type

  -- if lhs == e x and lhs' == fix .., then lhsNew := e x = F x (fun y _ => e y)
  let ftype := (← inferType (mkApp F x)).bindingDomain!
  -- Abstract over the two binders of `ftype` (the argument and its proof of being smaller);
  -- only the first one is passed on to `e`.
  let f' ← forallBoundedTelescope ftype (some 2) fun ys _ => do
    mkLambdaFVars ys (.app lhs.appFn! ys[0]!)
  let lhsNew := mkApp2 F x f'
  let targetNew ← mkEq lhsNew rhs
  let mvarNew ← mkFreshExprSyntheticOpaqueMVar targetNew
  mvarId.assign (← mkEqTrans h mvarNew)
  return mvarNew.mvarId!

/--
Checks whether the motive of `matcherApp` has the shape `fun xs => A → B`, with exactly one
lambda per discriminant and a body that is a `∀` whose codomain does not mention the bound
variable. If so, returns `fun xs => B`; otherwise returns `none`.
-/
def isForallMotive (matcherApp : MatcherApp) : MetaM (Option Expr) := do
  let numDiscrs := matcherApp.discrs.size
  lambdaBoundedTelescope matcherApp.motive numDiscrs fun discrVars body => do
    -- The motive must bind every discriminant …
    unless discrVars.size == numDiscrs do return none
    -- … and its body must be a non-dependent arrow.
    unless body.isForall do return none
    let codomain := body.bindingBody!
    if codomain.hasLooseBVar 0 then return none
    return some (← mkLambdaFVars discrVars codomain)


/--
Variant of `Split.splitMatch` that also copes with `casesOn` applications.

If `e` is a matcher application, the goal is split with `Split.splitMatch`. Otherwise `e`
is expected to have a single discriminant that is a free variable, located right after the
parameters and the motive, and the goal is split by running `cases` on it.
-/
def splitMatchOrCasesOn (mvarId : MVarId) (e : Expr) (matcherInfo : MatcherInfo) : MetaM (List MVarId) := do
  if (← isMatcherApp e) then
    return (← Split.splitMatch mvarId e)
  -- Not a matcher: only the single-discriminant shape is handled here.
  assert! matcherInfo.numDiscrs = 1
  -- Skip the parameters and the motive to reach the discriminant.
  let major := e.getAppArgs[matcherInfo.numParams + 1]!
  assert! major.isFVar
  let caseGoals ← mvarId.cases major.fvarId!
  return caseGoals.toList.map (·.mvarId)

/--
Generates a theorem named `_arg_pusher` under the private name derived from `matcherName`,
and returns that name. The theorem has the form
```
matcher._arg_pusher params motive α β (f : ∀ (x : α), β x) rel x1 x2 alt1 ..
  :
  matcher params (motive := fun x1 x2 => ((y : α) → rel x1 x2 y → β y) → motive x1 x2)
    x1 x2
    (alt1 := fun z1 z2 z3 f => alt1 z1 z2 z3 f) …
    (fun y _h => f y)
  =
  matcher params (motive := motive)
    x1 x2
    (alt1 := fun z1 z2 z3 => alt1 z1 z2 z3 (fun y _ => f y)) …
```
Note the argument order: parameters, motive, `α`, `β`, `f`, `rel`, then the discriminants,
then the alternatives. All of them are bound as explicit (default) binders. The alternatives
`alt1 ..` are the ones of the left-hand side, i.e. they take the extra function argument.

The theorem's universe parameters are those of the matcher followed by two fresh ones for
`α` and `β`. The proof splits on the right-hand side and closes every resulting goal with `rfl`.
-/
def mkMatchArgPusher (matcherName : Name) (matcherInfo : MatcherInfo) : MetaM Name := do
  let name := (mkPrivateName (← getEnv) matcherName) ++ `_arg_pusher
  realizeConst matcherName name do
    let matcherVal ← getConstVal matcherName
    -- Introduce the matcher's parameters and its (original) motive.
    forallBoundedTelescope matcherVal.type (some (matcherInfo.numParams + 1)) fun xs _ => do
      let params := xs[*...matcherInfo.numParams]
      let motive' := xs[matcherInfo.numParams]!
      -- Fresh universe parameter names for `α` and `β`.
      let u ← mkFreshUserName `u
      let v ← mkFreshUserName `v
      withLocalDeclD `α (.sort (.param u)) fun alpha => do
      withLocalDeclD `β (← mkArrow alpha (.sort (.param v))) fun beta => do
      withLocalDeclD `f (.forallE `x alpha (mkApp beta (.bvar 0)) .default) fun f => do
      -- `rel` takes the same arguments as the motive plus one extra `x : α`, and lives in `Sort 0`.
      let relType ← forallTelescope (← inferType motive') fun xs _ =>
        mkForallFVars xs (.forallE `x alpha (.sort 0) .default)
      withLocalDeclD `rel relType fun rel => do
      -- The modified motive: `fun xs => ((y : α) → rel xs y → β y) → motive' xs`.
      let motive ← forallTelescope (← inferType motive') fun xs _ => do
        let motiveBody := mkAppN motive' xs
        let extraArgType := .forallE `y alpha (.forallE `h (mkAppN rel (xs.push (.bvar 0))) (mkApp beta (.bvar 1)) .default) .default
        let motiveBody ← mkArrow extraArgType motiveBody
        mkLambdaFVars xs motiveBody

      -- Universe level of the modified motive's body; used as elimination level on the left-hand side.
      let uElim ← lambdaBoundedTelescope motive matcherInfo.numDiscrs fun _ motiveBody => do
        getLevel motiveBody
      let us := matcherVal.levelParams ++ [u, v]
      -- Levels for the right-hand side matcher (original motive) …
      let matcherLevels' := matcherVal.levelParams.map mkLevelParam
      -- … and for the left-hand side matcher (modified motive). Without an elimination
      -- universe position the modified motive has to be a proposition.
      let matcherLevels ← match matcherInfo.uElimPos? with
        | none     =>
          unless uElim.isZero do
            throwError "unexpected matcher application for {.ofConstName matcherName}, motive is not a proposition"
          pure matcherLevels'
        | some pos =>
          pure <| (matcherLevels'.toArray.set! pos uElim).toList
      let lhs := .const matcherName matcherLevels
      let rhs := .const matcherName matcherLevels'
      let lhs := mkAppN lhs params
      let rhs := mkAppN rhs params
      let lhs := mkApp lhs motive
      let rhs := mkApp rhs motive'
      -- Discriminants come first; both sides share them.
      forallBoundedTelescope (← inferType lhs) matcherInfo.numDiscrs fun discrs _ => do
      let lhs := mkAppN lhs discrs
      let rhs := mkAppN rhs discrs
      -- The alternatives are typed according to the left-hand side, i.e. with the extra argument.
      forallBoundedTelescope (← inferType lhs) matcherInfo.numAlts fun alts _ => do
      let lhs := mkAppN lhs alts

      -- On the right-hand side each alternative is the left-hand one applied to `fun y _ => f y`.
      let mut rhs := rhs
      for alt in alts, altNumParams in matcherInfo.altNumParams do
        let alt' ← forallBoundedTelescope (← inferType alt) altNumParams fun ys altBodyType => do
          assert! altBodyType.isForall
          let altArg ← forallBoundedTelescope altBodyType.bindingDomain! (some 2) fun ys _ => do
            mkLambdaFVars ys (.app f ys[0]!)
          mkLambdaFVars ys (mkAppN alt (ys.push altArg))
        rhs := mkApp rhs alt'

      -- The extra argument passed to the left-hand side matcher: `fun y _h => f y`.
      let extraArg := .lam `y alpha (.lam `h (mkAppN rel (discrs.push (.bvar 0))) (mkApp f (.bvar 1)) .default) .default
      let lhs := mkApp lhs extraArg
      let goal ← mkEq lhs rhs

      -- Prove the equation by case-splitting on the right-hand side and closing each case with `rfl`.
      let value ← mkFreshExprSyntheticOpaqueMVar goal
      let mvarId := value.mvarId!
      let mvarIds ← splitMatchOrCasesOn mvarId rhs matcherInfo
      for mvarId in mvarIds do
        mvarId.refl
      let value ← instantiateMVars value
      let type ← mkForallFVars (params ++ #[motive', alpha, beta, f, rel] ++ discrs ++ alts) goal
      let value ← mkLambdaFVars (params ++ #[motive', alpha, beta, f, rel] ++ discrs ++ alts) value
      addDecl <| Declaration.thmDecl { name, levelParams := us, type, value}
  return name

-- Simproc that looks for a matcher (or `casesOn`) application whose first extra argument has
-- the form `fun (x : α) p => f x` and whose motive takes a corresponding extra argument, and
-- rewrites it with the `_arg_pusher` theorem generated by `mkMatchArgPusher`, so that the
-- function is passed to the alternatives directly. Any further extra arguments are re-applied
-- to the result. Expressions that do not have the expected shape are left alone (`.continue`).
builtin_simproc_decl matcherPushArg (_) := fun e => do
  let e := e.headBeta
  let some matcherApp ← matchMatcherApp? e (alsoCasesOn := true) | return .continue
  -- Check that the first remaining argument is of the form `(fun (x : α) p => (f x : β x))`
  let some fArg := matcherApp.remaining[0]? | return .continue
  unless fArg.isLambda do return .continue
  unless fArg.bindingBody!.isLambda do return .continue
  let fBody := fArg.bindingBody!.bindingBody!
  unless fBody.isApp do return .continue
  -- The body must not mention `p` (bvar 0), must end in `x` (bvar 1), and `f` must not mention `x`.
  if fBody.hasLooseBVar 0 then return .continue
  unless fBody.appArg! == .bvar 1 do return .continue
  if fBody.appFn!.hasLooseBVar 1 then return .continue

  let fExpr := fBody.appFn!
  let fExprType ← inferType fExpr
  let fExprType ← withTransparency .all (whnfForall fExprType)
  -- Decline instead of panicking if the type of `f` cannot be exposed as a `∀`.
  unless fExprType.isForall do return .continue
  let alpha := fExprType.bindingDomain!
  let beta := .lam fExprType.bindingName! fExprType.bindingDomain! fExprType.bindingBody! .default

  -- Check that the motive has an extra parameter (from MatcherApp.addArg)
  let some motive' ← isForallMotive matcherApp |  return .continue
  -- Extract `rel` from the extra parameter's type `(y : α) → rel xs y → β y`; decline if the
  -- parameter's type is not syntactically of that two-binder shape.
  let rel? : Option Expr ← lambdaTelescope matcherApp.motive fun xs motiveBody => do
    unless motiveBody.isForall do return none
    let motiveBodyArg := motiveBody.bindingDomain!
    unless motiveBodyArg.isForall && motiveBodyArg.bindingBody!.isForall do return none
    return some (← mkLambdaFVars xs (.lam motiveBodyArg.bindingName! motiveBodyArg.bindingDomain! motiveBodyArg.bindingBody!.bindingDomain! .default))
  let some rel := rel? | return .continue

  let argPusher ← mkMatchArgPusher matcherApp.matcherName matcherApp.toMatcherInfo
  -- Let's infer the level parameters:
  let proof ← withTransparency .all <| mkAppOptM
    argPusher ((matcherApp.params ++ #[motive', alpha, beta, fExpr, rel] ++ matcherApp.discrs ++ matcherApp.alts).map some)
  let some (_, _, rhs) := (← inferType proof).eq? | throwError "matcherPushArg: expected equality:{indentExpr (← inferType proof)}"
  let step : Simp.Result := { expr := rhs, proof? := some proof }
  -- Re-apply the arguments that followed the pushed one.
  let step ← step.addExtraArgs matcherApp.remaining[1...*]
  return .continue (some step)

/--
Closes the goal `mvarId` by running a single-pass `simp` that uses only the `matcherPushArg`
simproc (as a pre-procedure) and, if a goal remains afterwards, finishing it with `rfl`.
Everything runs at `.all` transparency. A failure of the final `rfl` is reported with a
message mentioning `declName`.
-/
def mkUnfoldProof (declName : Name) (mvarId : MVarId) : MetaM Unit := withTransparency .all do
  let simpConfig : Simp.Config :=
    { dsimp := false, etaStruct := .none, letToHave := false, singlePass := true }
  let ctx ← Simp.mkContext (config := simpConfig)
  let noSimprocs : Simp.SimprocsArray := {}
  let simprocs ← noSimprocs.add ``matcherPushArg (post := false)
  let (remaining?, _) ← simpTarget mvarId ctx (simprocs := simprocs)
  -- `simp` may already have closed the goal.
  let some mvarId' := remaining? | return ()
  prependError m!"failed to finish proof for equational theorem for '{.ofConstName declName}'" do
    mvarId'.refl

/--
Proves and adds the unfolding theorem `preDef.declName xs = body` for a function defined by
well-founded recursion, where `xs` and `body` come from the lambda telescope of `preDef.value`.

The proof first transports the goal along `wfPreprocessProof`, unfolds the left-hand side
when `preDef` is not itself the unary function `unaryPreDefName`, rewrites with `rwFixEq`,
and finishes with `mkUnfoldProof`. Errors are prefixed with the name of the theorem that
could not be derived.
-/
public def mkUnfoldEq (preDef : PreDefinition) (unaryPreDefName : Name) (wfPreprocessProof : Simp.Result) : MetaM Unit := do
  let name := mkEqLikeNameFor (← getEnv) preDef.declName unfoldThmSuffix
  prependError m!"Cannot derive unfold equation {name}" do
  withOptions (tactic.hygienic.set · false) do
  withoutExporting do
    lambdaTelescope preDef.value fun xs body => do
      let levels := preDef.levelParams.map mkLevelParam
      let eqn ← mkEq (mkAppN (Lean.mkConst preDef.declName levels) xs) body

      let proofMVar ← mkFreshExprSyntheticOpaqueMVar eqn
      -- Lift the preprocessing result to the applied function, then to the whole equation.
      let preprocessed ← wfPreprocessProof.addExtraArgs xs
      let congrResult ← Simp.mkCongr { expr := eqn.appFn! } preprocessed
      let goal ← applySimpResultToTarget proofMVar.mvarId! eqn congrResult
      -- Non-unary functions are defined via the unary one and need one extra unfolding step.
      let goal ← if preDef.declName == unaryPreDefName then pure goal else deltaLHS goal
      let goal ← rwFixEq goal
      mkUnfoldProof preDef.declName goal

      let proof ← instantiateMVars proofMVar
      let type ← letToHave (← mkForallFVars xs eqn)
      let value ← mkLambdaFVars xs proof
      addDecl <| Declaration.thmDecl { name, levelParams := preDef.levelParams, type, value }
      inferDefEqAttr name
      trace[Elab.definition.wf] "mkUnfoldEq defined {.ofConstName name}"

/--
Derives the unfolding theorem of an individual function from the unfolding theorem of the
unary function `unaryPreDefName` (e.g. `foo._unary` or `foo._binary`).

After unfolding the left-hand side of `preDef.declName xs = body`, the unary theorem is
expected to apply directly (up to defeq) without leaving any goals; otherwise an error is
thrown.
-/
public def mkBinaryUnfoldEq (preDef : PreDefinition) (unaryPreDefName : Name) : MetaM Unit := do
  let name := mkEqLikeNameFor (← getEnv) preDef.declName unfoldThmSuffix
  let unaryEqName := mkEqLikeNameFor (← getEnv) unaryPreDefName unfoldThmSuffix
  prependError m!"Cannot derive {name} from {unaryEqName}" do
  withOptions (tactic.hygienic.set · false) do
    lambdaTelescope preDef.value fun xs body => do
      let levels := preDef.levelParams.map mkLevelParam
      let eqn ← mkEq (mkAppN (Lean.mkConst preDef.declName levels) xs) body
      let proofMVar ← mkFreshExprSyntheticOpaqueMVar eqn
      -- Unfold the function on the left-hand side, then close the goal with the unary theorem.
      let mvarId ← deltaLHS proofMVar.mvarId!
      let leftover ← mvarId.applyConst unaryEqName
      if !leftover.isEmpty then
        throwError "Failed to apply '{unaryEqName}' to '{mvarId}'"

      let proof ← instantiateMVars proofMVar
      let type ← letToHave (← mkForallFVars xs eqn)
      let value ← mkLambdaFVars xs proof
      addDecl <| Declaration.thmDecl { name, levelParams := preDef.levelParams, type, value }
      inferDefEqAttr name
      trace[Elab.definition.wf] "mkBinaryUnfoldEq defined {.ofConstName name}"

-- Registers the trace class `Elab.definition.wf.eqns` at initialization time.
-- NOTE(review): the `trace[...]` calls in this file use `Elab.definition.wf`, not
-- `Elab.definition.wf.eqns`; presumably that class is registered elsewhere — confirm.
builtin_initialize
  registerTraceClass `Elab.definition.wf.eqns

end Lean.Elab.WF
