/*
 * Decompiled with CFR 0.152.
 */
package org.apache.calcite.rel.rules;

import com.google.common.collect.ImmutableList;
import java.util.ArrayList;
import java.util.List;
import org.apache.calcite.plan.RelOptRuleCall;
import org.apache.calcite.plan.RelRule;
import org.apache.calcite.rel.RelNode;
import org.apache.calcite.rel.core.Aggregate;
import org.apache.calcite.rel.core.AggregateCall;
import org.apache.calcite.rel.core.Values;
import org.apache.calcite.rel.rules.ImmutableAggregateGroupingSetsToUnionRule;
import org.apache.calcite.rel.rules.SubstitutionRule;
import org.apache.calcite.rel.type.RelDataType;
import org.apache.calcite.rex.RexBuilder;
import org.apache.calcite.rex.RexInputRef;
import org.apache.calcite.rex.RexNode;
import org.apache.calcite.tools.RelBuilder;
import org.apache.calcite.util.ImmutableBitSet;
import org.immutables.value.Value;

/**
 * Planner rule that rewrites an {@link Aggregate} whose group type is not
 * SIMPLE (i.e. one with several grouping sets) into a {@code UNION ALL} of
 * simple aggregates, one branch per grouping set.
 *
 * <p>Each branch aggregates the original input on the keys of one grouping set
 * and then projects the result back to the original aggregate's row type:
 * <ul>
 *   <li>group keys that belong to the grouping set are forwarded;</li>
 *   <li>group keys that do not belong to it become typed {@code NULL}
 *       literals;</li>
 *   <li>{@code GROUPING(...)} calls are replaced by the constant bit mask they
 *       evaluate to for that grouping set;</li>
 *   <li>all other aggregate calls are computed by the branch's aggregate.</li>
 * </ul>
 *
 * <p>The rule does not transform aggregates that contain {@code GROUP_ID()} or
 * {@code GROUPING_ID(...)} calls.
 */
@Value.Enclosing
public class AggregateGroupingSetsToUnionRule
extends RelRule<AggregateGroupingSetsToUnionRule.Config>
implements SubstitutionRule {
    protected AggregateGroupingSetsToUnionRule(Config config) {
        super(config);
    }

    @Override
    public void onMatch(RelOptRuleCall call) {
        final Aggregate aggregate = call.rel(0);
        // Decline before building anything: simple aggregates need no
        // rewrite, and GROUP_ID / GROUPING_ID are not supported.
        if (Aggregate.isSimple(aggregate)
            || containsUnsupportedGroupingFunction(aggregate)) {
            return;
        }
        final RelBuilder relBuilder = call.builder();
        final List<RelNode> unionInputs = new ArrayList<>();
        for (ImmutableBitSet subGroupSet : aggregate.getGroupSets()) {
            unionInputs.add(buildBranch(aggregate, subGroupSet, relBuilder));
        }
        // UNION ALL: rows from different grouping sets must all be kept.
        relBuilder.pushAll(unionInputs).union(true, unionInputs.size());
        call.transformTo(relBuilder.build());
    }

    /**
     * Returns whether {@code aggregate} contains a {@code GROUP_ID} or
     * {@code GROUPING_ID} call, which this rule does not rewrite.
     */
    private static boolean containsUnsupportedGroupingFunction(Aggregate aggregate) {
        for (AggregateCall aggCall : aggregate.getAggCallList()) {
            switch (aggCall.getAggregation().getKind()) {
                case GROUP_ID:
                case GROUPING_ID:
                    return true;
                default:
                    break;
            }
        }
        return false;
    }

    /**
     * Builds the UNION ALL input for a single grouping set.
     *
     * <p>The returned relational expression has the same field names as
     * {@code aggregate}. It consists of a simple aggregate of the original
     * input on {@code subGroupSet}, followed by a projection that restores the
     * original aggregate's column layout.
     *
     * @param aggregate   the multi-grouping-set aggregate being rewritten
     * @param subGroupSet the grouping set this branch computes
     * @param relBuilder  builder used to construct the branch; its stack is
     *                    left as it was found
     * @return the branch for {@code subGroupSet}
     */
    private static RelNode buildBranch(Aggregate aggregate,
            ImmutableBitSet subGroupSet, RelBuilder relBuilder) {
        final RexBuilder rexBuilder = aggregate.getCluster().getRexBuilder();
        final RelNode input = aggregate.getInput();
        final RelDataType inputRowType = input.getRowType();
        // Row type of the branch aggregate's group-key columns; used to type
        // the references to those columns in the projection.
        final RelDataType subAggregateType = Aggregate.deriveRowType(
            relBuilder.getTypeFactory(), inputRowType, false, subGroupSet,
            ImmutableList.of(subGroupSet), ImmutableList.of());

        final List<RexNode> projects = new ArrayList<>();
        // Group keys of the original aggregate, in ascending key order.
        for (int groupKey : aggregate.getGroupSet()) {
            if (subGroupSet.get(groupKey)) {
                projects.add(RexInputRef.of(subGroupSet.indexOf(groupKey), subAggregateType));
            } else {
                // The key is rolled up in this grouping set: emit NULL typed
                // like the input column.
                projects.add(rexBuilder.makeNullLiteral(
                    inputRowType.getFieldList().get(groupKey).getType()));
            }
        }

        final List<AggregateCall> subAggCalls = new ArrayList<>();
        for (AggregateCall aggCall : aggregate.getAggCallList()) {
            switch (aggCall.getAggregation().getKind()) {
                case GROUPING:
                    // GROUPING(...) is a constant within a single grouping set.
                    final int groupingValue =
                        evaluateGroupingFunction(subGroupSet, aggCall.getArgList());
                    projects.add(rexBuilder.makeLiteral(groupingValue, aggCall.getType(), true));
                    break;
                default:
                    // The branch aggregate outputs its group keys first, then
                    // its aggregate calls in the order they were added.
                    projects.add(new RexInputRef(
                        subGroupSet.cardinality() + subAggCalls.size(), aggCall.getType()));
                    subAggCalls.add(aggCall);
                    break;
            }
        }

        relBuilder.push(input);
        relBuilder
            .aggregate(relBuilder.groupKey(subGroupSet), subAggCalls)
            .project(projects, aggregate.getRowType().getFieldNames());
        return relBuilder.build();
    }

    /**
     * Evaluates {@code GROUPING(arg_0, ..., arg_{n-1})} for one grouping set.
     *
     * <p>The last argument maps to the least significant bit. A bit is 1 when
     * the corresponding column is NOT in {@code groupSet} (it is rolled up),
     * and 0 when it is.
     *
     * @param groupSet   the grouping set being evaluated
     * @param argIndices input column indices passed to {@code GROUPING}
     * @return the bit mask as an {@code int}
     * @throws IllegalArgumentException if there are 32 or more arguments,
     *         which would not fit in a non-negative {@code int}
     */
    private static int evaluateGroupingFunction(ImmutableBitSet groupSet, List<Integer> argIndices) {
        final int argCount = argIndices.size();
        if (argCount >= 32) {
            throw new IllegalArgumentException("Too many grouping keys. Maximum is 31 for grouping functions.");
        }
        int result = 0;
        for (int k = 0; k < argCount; ++k) {
            final int index = argIndices.get(argCount - 1 - k);
            if (!groupSet.get(index)) {
                result |= 1 << k;
            }
        }
        return result;
    }

    /** Rule configuration. */
    @Value.Immutable
    public static interface Config
    extends RelRule.Config {
        public static final Config DEFAULT = ImmutableAggregateGroupingSetsToUnionRule.Config.of().withOperandFor(Aggregate.class, Values.class);

        @Override
        default public AggregateGroupingSetsToUnionRule toRule() {
            return new AggregateGroupingSetsToUnionRule(this);
        }

        /**
         * Defines an operand tree that matches any {@code aggregateClass} node
         * whose group type is not SIMPLE, with any inputs.
         *
         * <p>NOTE(review): {@code valuesClass} is not used by the operand built
         * here; it is kept for signature compatibility.
         */
        default public Config withOperandFor(Class<? extends Aggregate> aggregateClass, Class<? extends Values> valuesClass) {
            return this.withOperandSupplier(b0 -> b0.operand(aggregateClass)
                    .predicate(aggregate -> aggregate.getGroupType() != Aggregate.Group.SIMPLE)
                    .anyInputs())
                .as(Config.class);
        }
    }
}

