/*
 * Decompiled with CFR 0.152.
 */
package org.opensearch.sql.ast.analysis;

import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.function.Predicate;
import java.util.function.UnaryOperator;
import java.util.stream.Collectors;
import org.opensearch.sql.ast.AbstractNodeVisitor;
import org.opensearch.sql.ast.AstNodeUtils;
import org.opensearch.sql.ast.Node;
import org.opensearch.sql.ast.analysis.FieldResolutionContext;
import org.opensearch.sql.ast.analysis.FieldResolutionResult;
import org.opensearch.sql.ast.expression.AggregateFunction;
import org.opensearch.sql.ast.expression.Alias;
import org.opensearch.sql.ast.expression.AllFields;
import org.opensearch.sql.ast.expression.Field;
import org.opensearch.sql.ast.expression.Function;
import org.opensearch.sql.ast.expression.Let;
import org.opensearch.sql.ast.expression.Literal;
import org.opensearch.sql.ast.expression.PatternMode;
import org.opensearch.sql.ast.expression.QualifiedName;
import org.opensearch.sql.ast.expression.Span;
import org.opensearch.sql.ast.expression.UnresolvedExpression;
import org.opensearch.sql.ast.tree.AddColTotals;
import org.opensearch.sql.ast.tree.AddTotals;
import org.opensearch.sql.ast.tree.Aggregation;
import org.opensearch.sql.ast.tree.Append;
import org.opensearch.sql.ast.tree.AppendCol;
import org.opensearch.sql.ast.tree.AppendPipe;
import org.opensearch.sql.ast.tree.Bin;
import org.opensearch.sql.ast.tree.Chart;
import org.opensearch.sql.ast.tree.Dedupe;
import org.opensearch.sql.ast.tree.Eval;
import org.opensearch.sql.ast.tree.Expand;
import org.opensearch.sql.ast.tree.FillNull;
import org.opensearch.sql.ast.tree.Filter;
import org.opensearch.sql.ast.tree.Flatten;
import org.opensearch.sql.ast.tree.Head;
import org.opensearch.sql.ast.tree.Join;
import org.opensearch.sql.ast.tree.Lookup;
import org.opensearch.sql.ast.tree.Multisearch;
import org.opensearch.sql.ast.tree.MvCombine;
import org.opensearch.sql.ast.tree.Parse;
import org.opensearch.sql.ast.tree.Patterns;
import org.opensearch.sql.ast.tree.Project;
import org.opensearch.sql.ast.tree.RareTopN;
import org.opensearch.sql.ast.tree.Regex;
import org.opensearch.sql.ast.tree.Relation;
import org.opensearch.sql.ast.tree.Rename;
import org.opensearch.sql.ast.tree.Replace;
import org.opensearch.sql.ast.tree.Reverse;
import org.opensearch.sql.ast.tree.Rex;
import org.opensearch.sql.ast.tree.SPath;
import org.opensearch.sql.ast.tree.Search;
import org.opensearch.sql.ast.tree.Sort;
import org.opensearch.sql.ast.tree.StreamWindow;
import org.opensearch.sql.ast.tree.SubqueryAlias;
import org.opensearch.sql.ast.tree.Transpose;
import org.opensearch.sql.ast.tree.Trendline;
import org.opensearch.sql.ast.tree.UnresolvedPlan;
import org.opensearch.sql.ast.tree.Values;
import org.opensearch.sql.ast.tree.Window;
import org.opensearch.sql.calcite.utils.WildcardUtils;
import org.opensearch.sql.expression.parse.RegexCommonUtils;

/**
 * Computes, for the leaves of an unresolved plan, which input fields the commands above them need.
 *
 * <p>Analysis starts at the node passed to {@link #analyze} and proceeds into its children. Each
 * supported command derives the requirements of its input from the current requirements
 * ({@code context.getCurrentRequirements()}) and the fields it reads itself, pushes them onto the
 * context while its children are visited, and pops them afterwards. {@link Relation} nodes and
 * {@link SPath} nodes without a path record the requirements they observe with
 * {@code context.setResult}; those recorded results are what {@link #analyze} returns.
 *
 * <p>Every supported {@code visitXxx} method returns exactly the node it was given;
 * {@link #acceptAndVerifyNodeVisited} relies on that identity to reject unsupported commands.
 */
public class FieldResolutionVisitor
extends AbstractNodeVisitor<Node, FieldResolutionContext> {
    /** Field name used to stand for "every field" (emitted for {@link AllFields}). */
    private static final String ALL_FIELDS = "*";

    /**
     * Runs field resolution over {@code plan}.
     *
     * @param plan root node of the unresolved plan
     * @return the requirements recorded per plan node during the walk
     * @throws IllegalArgumentException if the plan contains an unsupported command, or a
     *     combination of commands that field resolution rejects
     */
    public Map<UnresolvedPlan, FieldResolutionResult> analyze(UnresolvedPlan plan) {
        FieldResolutionContext context = new FieldResolutionContext();
        this.acceptAndVerifyNodeVisited(plan, context);
        return context.getResults();
    }

    /**
     * Visits every child of {@code node}, verifying that each child is a supported command.
     *
     * @return always {@code null}. NOTE(review): presumably the base visitor's default handling of
     *     node types without a dedicated override ends up here, so this {@code null} is what makes
     *     {@link #acceptAndVerifyNodeVisited} reject them — confirm against AbstractNodeVisitor.
     */
    @Override
    public Node visitChildren(Node node, FieldResolutionContext context) {
        for (Node child : node.getChild()) {
            this.acceptAndVerifyNodeVisited(child, context);
        }
        return null;
    }

    /**
     * Dispatches {@code node} to this visitor and fails unless a supported visit method handled it.
     *
     * @throws IllegalArgumentException if the visit did not return {@code node} itself
     */
    private void acceptAndVerifyNodeVisited(Node node, FieldResolutionContext context) {
        Node result = node.accept(this, context);
        if (result != node) {
            throw new IllegalArgumentException("Unsupported command for field resolution: " + node.getClass().getSimpleName());
        }
    }

    /**
     * Visits the children of {@code node} with {@code requirements} pushed onto the context, then
     * pops them again.
     *
     * @return {@code node}, so callers can return it directly from a visit method
     */
    private Node visitChildrenUnder(Node node, FieldResolutionContext context, FieldResolutionResult requirements) {
        context.pushRequirements(requirements);
        this.visitChildren(node, context);
        context.popRequirements();
        return node;
    }

    /**
     * Visits the children of {@code node} requiring {@code fields} on top of the current requirements
     * (combined with {@code FieldResolutionResult#or}).
     *
     * @return {@code node}
     */
    private Node visitChildrenAlsoRequiring(Node node, FieldResolutionContext context, Set<String> fields) {
        return this.visitChildrenUnder(node, context, context.getCurrentRequirements().or(fields));
    }

    /**
     * A lone {@code *} projection passes requirements through unchanged. Any other projection is
     * combined via {@code and} with the current requirements, with wildcard patterns kept separate
     * from plain field names.
     */
    @Override
    public Node visitProject(Project node, FieldResolutionContext context) {
        boolean isSingleSelectAll = node.getProjectList().size() == 1 && node.getProjectList().get(0) instanceof AllFields;
        if (isSingleSelectAll) {
            this.visitChildren(node, context);
            return node;
        }
        Set<String> projectFields = new HashSet<>();
        Set<String> wildcardPatterns = new HashSet<>();
        for (UnresolvedExpression expr : node.getProjectList()) {
            for (String field : this.extractFieldsFromExpression(expr)) {
                if (WildcardUtils.containsWildcard(field)) {
                    wildcardPatterns.add(field);
                } else {
                    projectFields.add(field);
                }
            }
        }
        FieldResolutionResult current = context.getCurrentRequirements();
        return this.visitChildrenUnder(node, context, current.and(new FieldResolutionResult(projectFields, wildcardPatterns)));
    }

    /**
     * Adds the fields referenced by the filter condition.
     *
     * @throws IllegalArgumentException if the condition contains a subquery
     */
    @Override
    public Node visitFilter(Filter node, FieldResolutionContext context) {
        if (AstNodeUtils.containsSubqueryExpression(node.getCondition())) {
            throw new IllegalArgumentException("Filter by subquery is not supported with field resolution.");
        }
        Set<String> filterFields = this.extractFieldsFromExpression(node.getCondition());
        return this.visitChildrenAlsoRequiring(node, context, filterFields);
    }

    /**
     * Replaces the downstream requirements with the group-by, span and aggregate-argument fields,
     * since only those are read from the aggregation's input.
     */
    @Override
    public Node visitAggregation(Aggregation node, FieldResolutionContext context) {
        Set<String> aggFields = new HashSet<>();
        for (UnresolvedExpression groupExpr : node.getGroupExprList()) {
            aggFields.addAll(this.extractFieldsFromExpression(groupExpr));
        }
        if (node.getSpan() != null) {
            aggFields.addAll(this.extractFieldsFromExpression(node.getSpan()));
        }
        for (UnresolvedExpression aggExpr : node.getAggExprList()) {
            aggFields.addAll(this.extractFieldsFromAggregation(aggExpr));
        }
        return this.visitChildrenUnder(node, context, new FieldResolutionResult(aggFields));
    }

    /**
     * A spath with a path is analyzed through its eval rewrite. A path-less spath records the
     * current requirements as its result and requires its input field.
     *
     * @throws IllegalArgumentException if the current requirements contain a partial wildcard
     */
    @Override
    public Node visitSpath(SPath node, FieldResolutionContext context) {
        if (node.getPath() != null) {
            // Analyze the equivalent eval, but report this SPath node as visited: returning the
            // rewritten Eval would fail the identity check in acceptAndVerifyNodeVisited.
            this.visitEval(node.rewriteAsEval(), context);
            return node;
        }
        FieldResolutionResult requirements = context.getCurrentRequirements();
        context.setResult(node, requirements);
        if (requirements.hasPartialWildcards()) {
            throw new IllegalArgumentException("Spath command cannot be used with partial wildcard such as `prefix*`.");
        }
        return this.visitChildrenAlsoRequiring(node, context, Set.of(node.getInField()));
    }

    /** Adds the sort key fields. */
    @Override
    public Node visitSort(Sort node, FieldResolutionContext context) {
        Set<String> sortFields = new HashSet<>();
        for (Field sortField : node.getSortList()) {
            sortFields.addAll(this.extractFieldsFromExpression(sortField));
        }
        return this.visitChildrenAlsoRequiring(node, context, sortFields);
    }

    /**
     * Removes the fields the eval assigns and adds the fields its expressions read. The wildcard
     * part of the current requirements is passed through unchanged.
     */
    @Override
    public Node visitEval(Eval node, FieldResolutionContext context) {
        FieldResolutionResult currentReq = context.getCurrentRequirements();
        Set<String> requiredFields = new HashSet<>(currentReq.getRegularFields());
        List<Let> letExprs = node.getExpressionList();
        // Walk assignments last-to-first: each one satisfies requirements on its target for the
        // assignments after it, and its inputs become requirements for those before it. This way a
        // field produced and consumed within the same eval (`eval a = 1, b = a`) is not requested
        // from the input.
        for (int i = letExprs.size() - 1; i >= 0; i--) {
            Let letExpr = letExprs.get(i);
            requiredFields.remove(letExpr.getVar().getField().toString());
            requiredFields.addAll(this.extractFieldsFromExpression(letExpr.getExpression()));
        }
        return this.visitChildrenUnder(node, context, new FieldResolutionResult(requiredFields, currentReq.getWildcard()));
    }

    /**
     * Collects the field names referenced by {@code expr}.
     *
     * <p>Handling by expression type:
     * <ul>
     *   <li>A {@link Field} contributes its name.</li>
     *   <li>{@link AllFields} contributes {@code "*"}.</li>
     *   <li>A {@link QualifiedName} contributes its string form.</li>
     *   <li>{@link Alias}, {@link Function} and {@link Span} recurse into their delegated
     *       expression, arguments and span field respectively.</li>
     *   <li>A {@link Literal} contributes nothing.</li>
     *   <li>Any other expression recurses into its expression children.</li>
     * </ul>
     *
     * @param expr expression to inspect; {@code null} yields an empty set
     * @return a new, mutable set that callers may add to
     */
    private Set<String> extractFieldsFromExpression(UnresolvedExpression expr) {
        Set<String> fields = new HashSet<>();
        if (expr == null) {
            return fields;
        }
        if (expr instanceof Field) {
            fields.add(((Field) expr).getField().toString());
        } else if (expr instanceof AllFields) {
            fields.add(ALL_FIELDS);
        } else if (expr instanceof QualifiedName) {
            fields.add(expr.toString());
        } else if (expr instanceof Alias) {
            fields.addAll(this.extractFieldsFromExpression(((Alias) expr).getDelegated()));
        } else if (expr instanceof Function) {
            for (UnresolvedExpression arg : ((Function) expr).getFuncArgs()) {
                fields.addAll(this.extractFieldsFromExpression(arg));
            }
        } else if (expr instanceof Span) {
            fields.addAll(this.extractFieldsFromExpression(((Span) expr).getField()));
        } else if (!(expr instanceof Literal)) {
            for (Node child : expr.getChild()) {
                if (child instanceof UnresolvedExpression) {
                    fields.addAll(this.extractFieldsFromExpression((UnresolvedExpression) child));
                }
            }
        }
        return fields;
    }

    /**
     * Splits the current requirements and the join's own condition/key fields between the two
     * sides.
     *
     * <p>A field prefixed with the other side's alias is dropped for this side; a field prefixed
     * with this side's alias has the prefix stripped. Both sides keep the current wildcard. Each
     * side is verified like any other child.
     */
    @Override
    public Node visitJoin(Join node, FieldResolutionContext context) {
        Set<String> joinFields = new HashSet<>();
        if (node.getJoinCondition().isPresent()) {
            joinFields.addAll(this.extractFieldsFromExpression(node.getJoinCondition().get()));
        }
        if (node.getJoinFields().isPresent()) {
            for (Field field : node.getJoinFields().get()) {
                joinFields.addAll(this.extractFieldsFromExpression(field));
            }
        }
        FieldResolutionResult currentReq = context.getCurrentRequirements();
        Set<String> baseRequiredFields = new HashSet<>(currentReq.getRegularFields());
        String leftAlias = node.getLeftAlias().orElse(null);
        String rightAlias = node.getRightAlias().orElse(null);

        Set<String> leftFields = this.collectFieldsByAlias(baseRequiredFields, leftAlias, rightAlias);
        leftFields.addAll(this.collectFieldsByAlias(joinFields, leftAlias, rightAlias));
        Set<String> rightFields = this.collectFieldsByAlias(baseRequiredFields, rightAlias, leftAlias);
        rightFields.addAll(this.collectFieldsByAlias(joinFields, rightAlias, leftAlias));

        if (node.getLeft() != null) {
            context.pushRequirements(new FieldResolutionResult(leftFields, currentReq.getWildcard()));
            this.acceptAndVerifyNodeVisited(node.getLeft(), context);
            context.popRequirements();
        }
        if (node.getRight() != null) {
            context.pushRequirements(new FieldResolutionResult(rightFields, currentReq.getWildcard()));
            this.acceptAndVerifyNodeVisited(node.getRight(), context);
            context.popRequirements();
        }
        return node;
    }

    /** Returns a mapping that strips a leading {@code alias + "."} from a field name, if present. */
    private static UnaryOperator<String> removeAlias(String alias) {
        return field -> FieldResolutionVisitor.hasAlias(field, alias) ? field.substring(alias.length() + 1) : field;
    }

    /** Returns a predicate that rejects field names prefixed with {@code alias + "."}. */
    private static Predicate<String> excludeAlias(String alias) {
        return field -> !FieldResolutionVisitor.hasAlias(field, alias);
    }

    /**
     * Keeps the fields not qualified by {@code excludedAlias} and strips any {@code alias} prefix.
     *
     * @return a new mutable {@link HashSet}; callers add further fields to it
     */
    private Set<String> collectFieldsByAlias(Set<String> fields, String alias, String excludedAlias) {
        return fields.stream()
                .filter(FieldResolutionVisitor.excludeAlias(excludedAlias))
                .map(FieldResolutionVisitor.removeAlias(alias))
                .collect(Collectors.toCollection(HashSet::new));
    }

    /** Whether {@code field} starts with {@code alias + "."}; always false for a null alias. */
    private static boolean hasAlias(String field, String alias) {
        return alias != null && field.startsWith(alias + ".");
    }

    /** Passes requirements through unchanged. */
    @Override
    public Node visitSubqueryAlias(SubqueryAlias node, FieldResolutionContext context) {
        this.visitChildren(node, context);
        return node;
    }

    /** Records a copy of the current requirements as this relation's result. */
    @Override
    public Node visitRelation(Relation node, FieldResolutionContext context) {
        FieldResolutionResult currentReq = context.getCurrentRequirements();
        context.setResult(node, new FieldResolutionResult(currentReq.getRegularFields(), currentReq.getWildcard()));
        return node;
    }

    /** Passes requirements through unchanged. */
    @Override
    public Node visitSearch(Search node, FieldResolutionContext context) {
        this.visitChildren(node, context);
        return node;
    }

    /** Adds the field the regex is matched against. */
    @Override
    public Node visitRegex(Regex node, FieldResolutionContext context) {
        return this.visitChildrenAlsoRequiring(node, context, this.extractFieldsFromExpression(node.getField()));
    }

    /**
     * Removes the pattern's named-group candidates from the requirements (rex produces them) and
     * adds the source field.
     */
    @Override
    public Node visitRex(Rex node, FieldResolutionContext context) {
        Set<String> rexFields = this.extractFieldsFromExpression(node.getField());
        String patternStr = (String) node.getPattern().getValue();
        List<String> namedGroups = RegexCommonUtils.getNamedGroupCandidates(patternStr);
        return this.visitChildrenUnder(node, context, context.getCurrentRequirements().exclude(namedGroups).or(rexFields));
    }

    /** Adds the binned field. */
    @Override
    public Node visitBin(Bin node, FieldResolutionContext context) {
        return this.visitChildrenAlsoRequiring(node, context, this.extractFieldsFromExpression(node.getField()));
    }

    /** Adds the parsed source field. */
    @Override
    public Node visitParse(Parse node, FieldResolutionContext context) {
        return this.visitChildrenAlsoRequiring(node, context, this.extractFieldsFromExpression(node.getSourceField()));
    }

    /**
     * Removes the fields patterns produces and adds the source and partition-by fields.
     *
     * <p>The produced fields are:
     * <ul>
     *   <li>the alias, or {@code patterns_field} when there is none;</li>
     *   <li>{@code pattern_count} and {@code sample_logs} in aggregation mode;</li>
     *   <li>{@code tokens} when numbered tokens are shown.</li>
     * </ul>
     */
    @Override
    public Node visitPatterns(Patterns node, FieldResolutionContext context) {
        Set<String> patternFields = this.extractFieldsFromExpression(node.getSourceField());
        for (UnresolvedExpression partitionBy : node.getPartitionByList()) {
            patternFields.addAll(this.extractFieldsFromExpression(partitionBy));
        }
        Set<String> addedFields = new HashSet<>();
        addedFields.add(node.getAlias() != null ? node.getAlias() : "patterns_field");
        if (node.getPatternMode() == PatternMode.AGGREGATION) {
            addedFields.add("pattern_count");
            addedFields.add("sample_logs");
        }
        if (node.getShowNumberedToken() != null && Boolean.parseBoolean(node.getShowNumberedToken().toString())) {
            addedFields.add("tokens");
        }
        return this.visitChildrenUnder(node, context, context.getCurrentRequirements().exclude(addedFields).or(patternFields));
    }

    /** Passes requirements through unchanged. */
    @Override
    public Node visitReverse(Reverse node, FieldResolutionContext context) {
        this.visitChildren(node, context);
        return node;
    }

    /** Passes requirements through unchanged. */
    @Override
    public Node visitHead(Head node, FieldResolutionContext context) {
        this.visitChildren(node, context);
        return node;
    }

    /**
     * Passes requirements through unchanged.
     *
     * <p>NOTE(review): the rename mappings are not applied, so a downstream requirement on a
     * renamed target is propagated under the new name rather than the original one — confirm this
     * is intended.
     */
    @Override
    public Node visitRename(Rename node, FieldResolutionContext context) {
        this.visitChildren(node, context);
        return node;
    }

    /** Adds the dedupe key fields. */
    @Override
    public Node visitDedupe(Dedupe node, FieldResolutionContext context) {
        Set<String> dedupeFields = new HashSet<>();
        for (Field field : node.getFields()) {
            dedupeFields.addAll(this.extractFieldsFromExpression(field));
        }
        return this.visitChildrenAlsoRequiring(node, context, dedupeFields);
    }

    /** Adds the fields read by the window functions and the group-by list. */
    @Override
    public Node visitWindow(Window node, FieldResolutionContext context) {
        Set<String> windowFields = new HashSet<>();
        for (UnresolvedExpression windowFunc : node.getWindowFunctionList()) {
            windowFields.addAll(this.extractFieldsFromExpression(windowFunc));
        }
        if (node.getGroupList() != null) {
            for (UnresolvedExpression groupExpr : node.getGroupList()) {
                windowFields.addAll(this.extractFieldsFromExpression(groupExpr));
            }
        }
        return this.visitChildrenAlsoRequiring(node, context, windowFields);
    }

    /**
     * Adds the fields read by the window functions, the group-by list and the reset-before and
     * reset-after expressions.
     */
    @Override
    public Node visitStreamWindow(StreamWindow node, FieldResolutionContext context) {
        Set<String> streamWindowFields = new HashSet<>();
        for (UnresolvedExpression windowFunc : node.getWindowFunctionList()) {
            streamWindowFields.addAll(this.extractFieldsFromExpression(windowFunc));
        }
        if (node.getGroupList() != null) {
            for (UnresolvedExpression groupExpr : node.getGroupList()) {
                streamWindowFields.addAll(this.extractFieldsFromExpression(groupExpr));
            }
        }
        if (node.getResetBefore() != null) {
            streamWindowFields.addAll(this.extractFieldsFromExpression(node.getResetBefore()));
        }
        if (node.getResetAfter() != null) {
            streamWindowFields.addAll(this.extractFieldsFromExpression(node.getResetAfter()));
        }
        return this.visitChildrenAlsoRequiring(node, context, streamWindowFields);
    }

    /**
     * Adds the explicitly listed fillnull fields.
     *
     * @throws IllegalArgumentException if fillnull targets all fields instead of a field list
     */
    @Override
    public Node visitFillNull(FillNull node, FieldResolutionContext context) {
        if (node.isAgainstAllFields()) {
            throw new IllegalArgumentException("Fields need to be specified with fillnull command");
        }
        Set<String> fields = new HashSet<>();
        node.getFields().forEach(field -> fields.addAll(this.extractFieldsFromExpression((UnresolvedExpression) field)));
        return this.visitChildrenAlsoRequiring(node, context, fields);
    }

    /** Visits the sub-search and then the main input, both under the current requirements. */
    @Override
    public Node visitAppendCol(AppendCol node, FieldResolutionContext context) {
        this.acceptAndVerifyNodeVisited(node.getSubSearch(), context);
        this.visitChildren(node, context);
        return node;
    }

    /** Visits the sub-search and then the main input, both under the current requirements. */
    @Override
    public Node visitAppend(Append node, FieldResolutionContext context) {
        this.acceptAndVerifyNodeVisited(node.getSubSearch(), context);
        this.visitChildren(node, context);
        return node;
    }

    /** Visits the sub-query and then the main input, both under the current requirements. */
    @Override
    public Node visitAppendPipe(AppendPipe node, FieldResolutionContext context) {
        this.acceptAndVerifyNodeVisited(node.getSubQuery(), context);
        this.visitChildren(node, context);
        return node;
    }

    /**
     * Always rejects lookup.
     *
     * @throws IllegalArgumentException unconditionally
     */
    @Override
    public Node visitLookup(Lookup node, FieldResolutionContext context) {
        throw new IllegalArgumentException("Lookup command cannot be used together with spath command");
    }

    /** Accepts the node without visiting any children. */
    @Override
    public Node visitValues(Values node, FieldResolutionContext context) {
        return node;
    }

    /** Visits every sub-search and then the main input under the current requirements. */
    @Override
    public Node visitMultisearch(Multisearch node, FieldResolutionContext context) {
        node.getSubsearches().forEach(subsearch -> this.acceptAndVerifyNodeVisited((Node) subsearch, context));
        this.visitChildren(node, context);
        return node;
    }

    /** Adds the fields replace operates on. */
    @Override
    public Node visitReplace(Replace node, FieldResolutionContext context) {
        Set<String> fields = new HashSet<>();
        node.getFieldList().forEach(field -> fields.addAll(this.extractFieldsFromExpression((UnresolvedExpression) field)));
        return this.visitChildrenAlsoRequiring(node, context, fields);
    }

    /** Adds the flattened field. */
    @Override
    public Node visitFlatten(Flatten node, FieldResolutionContext context) {
        return this.visitChildrenAlsoRequiring(node, context, this.extractFieldsFromExpression(node.getField()));
    }

    /** Adds each computation's data field and the optional sort-by field. */
    @Override
    public Node visitTrendline(Trendline node, FieldResolutionContext context) {
        Set<String> trendlineFields = new HashSet<>();
        for (Trendline.TrendlineComputation computation : node.getComputations()) {
            trendlineFields.addAll(this.extractFieldsFromExpression(computation.getDataField()));
        }
        if (node.getSortByField().isPresent()) {
            trendlineFields.addAll(this.extractFieldsFromExpression(node.getSortByField().get()));
        }
        return this.visitChildrenAlsoRequiring(node, context, trendlineFields);
    }

    /** Passes requirements through unchanged. */
    @Override
    public Node visitTranspose(Transpose node, FieldResolutionContext context) {
        this.visitChildren(node, context);
        return node;
    }

    /**
     * Replaces the downstream requirements with the aggregation's argument fields plus the row and
     * column split fields.
     */
    @Override
    public Node visitChart(Chart node, FieldResolutionContext context) {
        Set<String> chartFields = this.extractFieldsFromAggregation(node.getAggregationFunction());
        if (node.getRowSplit() != null) {
            chartFields.addAll(this.extractFieldsFromExpression(node.getRowSplit()));
        }
        if (node.getColumnSplit() != null) {
            chartFields.addAll(this.extractFieldsFromExpression(node.getColumnSplit()));
        }
        return this.visitChildrenUnder(node, context, new FieldResolutionResult(chartFields));
    }

    /** Replaces the downstream requirements with the counted fields and the group-by fields. */
    @Override
    public Node visitRareTopN(RareTopN node, FieldResolutionContext context) {
        Set<String> rareTopNFields = new HashSet<>();
        for (Field field : node.getFields()) {
            rareTopNFields.addAll(this.extractFieldsFromExpression(field));
        }
        for (UnresolvedExpression groupExpr : node.getGroupExprList()) {
            rareTopNFields.addAll(this.extractFieldsFromExpression(groupExpr));
        }
        return this.visitChildrenUnder(node, context, new FieldResolutionResult(rareTopNFields));
    }

    /** Passes requirements through unchanged. */
    @Override
    public Node visitAddTotals(AddTotals node, FieldResolutionContext context) {
        this.visitChildren(node, context);
        return node;
    }

    /** Passes requirements through unchanged. */
    @Override
    public Node visitAddColTotals(AddColTotals node, FieldResolutionContext context) {
        this.visitChildren(node, context);
        return node;
    }

    /** Adds the expanded field. */
    @Override
    public Node visitExpand(Expand node, FieldResolutionContext context) {
        return this.visitChildrenAlsoRequiring(node, context, this.extractFieldsFromExpression(node.getField()));
    }

    /**
     * Adds the combined field and sets the wildcard to {@code "*"}, i.e. requests every input
     * field. NOTE(review): presumably because the other fields act as the grouping key — confirm.
     */
    @Override
    public Node visitMvCombine(MvCombine node, FieldResolutionContext context) {
        Set<String> mvCombineFields = this.extractFieldsFromExpression(node.getField());
        FieldResolutionResult current = context.getCurrentRequirements();
        Set<String> regularFields = new HashSet<>(current.getRegularFields());
        regularFields.addAll(mvCombineFields);
        return this.visitChildrenUnder(node, context, new FieldResolutionResult(regularFields, Set.of(ALL_FIELDS)));
    }

    /**
     * Collects the fields read by an aggregate call, unwrapping aliases.
     *
     * <p>Only {@link AggregateFunction} contributes fields (its field and argument list). The
     * {@code "*"} marker, as in {@code count(*)}, is dropped.
     *
     * @return a new mutable {@link HashSet}
     */
    private Set<String> extractFieldsFromAggregation(UnresolvedExpression expr) {
        if (expr instanceof Alias) {
            return this.extractFieldsFromAggregation(((Alias) expr).getDelegated());
        }
        Set<String> fields = new HashSet<>();
        if (expr instanceof AggregateFunction) {
            AggregateFunction aggFunc = (AggregateFunction) expr;
            if (aggFunc.getField() != null) {
                fields.addAll(this.extractFieldsFromExpression(aggFunc.getField()));
            }
            if (aggFunc.getArgList() != null) {
                for (UnresolvedExpression arg : aggFunc.getArgList()) {
                    fields.addAll(this.extractFieldsFromExpression(arg));
                }
            }
        }
        return this.excludeAllFieldsWildcard(fields);
    }

    /**
     * Returns a copy of {@code fields} without the {@code "*"} marker.
     *
     * @return a new mutable {@link HashSet}. {@code Collectors.toSet()} does not guarantee
     *     mutability, and {@code visitChart} adds to this result.
     */
    private Set<String> excludeAllFieldsWildcard(Set<String> fields) {
        return fields.stream()
                .filter(f -> !ALL_FIELDS.equals(f))
                .collect(Collectors.toCollection(HashSet::new));
    }
}

