package org.apache.hadoop.hive.ql.exec.vector.aggregation;

import java.util.Arrays;
import java.util.List;
import junit.framework.Assert;
import org.apache.hadoop.hive.common.type.DataTypePhysicalVariation;
import org.apache.hadoop.hive.conf.HiveConf;
import org.apache.hadoop.hive.ql.exec.FunctionRegistry;
import org.apache.hadoop.hive.ql.exec.vector.LongColumnVector;
import org.apache.hadoop.hive.ql.exec.vector.VectorAggregationBufferRow;
import org.apache.hadoop.hive.ql.exec.vector.VectorAggregationDesc;
import org.apache.hadoop.hive.ql.exec.vector.VectorExtractRow;
import org.apache.hadoop.hive.ql.exec.vector.VectorRandomBatchSource;
import org.apache.hadoop.hive.ql.exec.vector.VectorRandomRowSource;
import org.apache.hadoop.hive.ql.exec.vector.VectorizationContext;
import org.apache.hadoop.hive.ql.exec.vector.VectorizedRowBatch;
import org.apache.hadoop.hive.ql.exec.vector.VectorizedRowBatchCtx;
import org.apache.hadoop.hive.ql.exec.vector.expressions.VectorExpression;
import org.apache.hadoop.hive.ql.exec.vector.expressions.aggregates.VectorAggregateExpression;
import org.apache.hadoop.hive.ql.metadata.HiveException;
import org.apache.hadoop.hive.ql.metadata.VirtualColumn;
import org.apache.hadoop.hive.ql.optimizer.physical.Vectorizer;
import org.apache.hadoop.hive.ql.parse.SemanticException;
import org.apache.hadoop.hive.ql.plan.ExprNodeDesc;
import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFCount;
import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFEvaluator;
import org.apache.hadoop.hive.serde2.io.ShortWritable;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorUtils;
import org.apache.hadoop.hive.serde2.typeinfo.PrimitiveTypeInfo;
import org.apache.hadoop.hive.serde2.typeinfo.TypeInfo;
import org.apache.hadoop.hive.serde2.typeinfo.TypeInfoUtils;

/**
 * Shared harness for the vectorized aggregation tests.
 *
 * <p>An aggregation is executed once per {@link AggregationTestMode}: row-at-a-time through a
 * {@link GenericUDAFEvaluator}, whose output is used as the expected result, and through the
 * {@link VectorAggregateExpression} chosen by the {@link Vectorizer}. Rows are grouped by a
 * {@code smallint} key held in column 0; rows whose key is NULL form one extra group stored at
 * index {@code maxKeyCount}, so every per-mode result array has {@code maxKeyCount + 1} slots.
 */
public class AggregationBase {

    /**
     * The execution paths that are compared. {@code ROW_MODE} must stay first: the results stored
     * at index 0 are treated as the expected values by {@code verifyAggregationResults}.
     */
    public enum AggregationTestMode {
        ROW_MODE,
        VECTOR_EXPRESSION;

        static final int count = values().length;
    }

    /**
     * Resolves the evaluator of a single-argument aggregate function.
     *
     * @param aggregationName the registered aggregate function name
     * @param typeInfo the type of the single argument passed to the aggregate
     * @return the evaluator the function's resolver selects for that argument type
     * @throws SemanticException if resolution fails for the given name or argument type
     */
    public static GenericUDAFEvaluator getEvaluator(String aggregationName, TypeInfo typeInfo)
            throws SemanticException {
        TypeInfo[] argumentTypes = new TypeInfo[] {typeInfo};
        return FunctionRegistry.getGenericUDAFResolver(aggregationName).getEvaluator(argumentTypes);
    }

    /**
     * Computes the expected per-key results by feeding every row through the row-mode evaluator.
     *
     * <p>Column 0 of each row is the {@link ShortWritable} grouping key (NULL keys are grouped at
     * index {@code maxKeyCount}); column 1 is the aggregated argument, except for a
     * {@code count(*)} evaluator, which is called with no arguments. In {@code PARTIAL1} and
     * {@code PARTIAL2} mode the partial result is collected, otherwise the final one. Non-null
     * results are converted to standard writable objects of {@code outputTypeInfo}.
     *
     * @param typeInfo the aggregated argument type (not used here)
     * @param evaluator the row-mode evaluator; NOTE(review): assumed already initialized by caller
     * @param outputTypeInfo the type of the value produced for {@code udafEvaluatorMode}
     * @param udafEvaluatorMode selects {@code terminatePartial} vs {@code terminate}
     * @param maxKeyCount index of the NULL-key group; non-null keys must be below it
     * @param columns column names (not used here)
     * @param parameters aggregate parameter expressions (not used here)
     * @param randomRows input rows
     * @param rowInspector row object inspector (not used here)
     * @param results receives one value per key that had at least one row; other slots are left
     *     untouched. Must have at least {@code maxKeyCount + 1} slots.
     * @return always {@code true}
     * @throws Exception propagated from the evaluator
     */
    protected static boolean doRowTest(TypeInfo typeInfo, GenericUDAFEvaluator evaluator,
            TypeInfo outputTypeInfo, GenericUDAFEvaluator.Mode udafEvaluatorMode, int maxKeyCount,
            List<String> columns, List<ExprNodeDesc> parameters, Object[][] randomRows,
            ObjectInspector rowInspector, Object[] results) throws Exception {

        // One buffer per key value plus a trailing slot for the NULL-key group.
        GenericUDAFEvaluator.AggregationBuffer[] buffers =
                new GenericUDAFEvaluator.AggregationBuffer[maxKeyCount + 1];
        ObjectInspector outputInspector =
                TypeInfoUtils.getStandardWritableObjectInspectorFromTypeInfo(outputTypeInfo);

        // count(*) counts rows and takes no argument; every other aggregate takes column 1.
        boolean countAllColumns = evaluator instanceof GenericUDAFCount.GenericUDAFCountEvaluator
                && ((GenericUDAFCount.GenericUDAFCountEvaluator) evaluator).getCountAllColumns();
        Object[] arguments = countAllColumns ? new Object[0] : new Object[1];

        for (Object[] row : randomRows) {
            ShortWritable keyWritable = (ShortWritable) row[0];
            int key = keyWritable == null ? maxKeyCount : keyWritable.get();
            GenericUDAFEvaluator.AggregationBuffer buffer = buffers[key];
            if (buffer == null) {
                buffer = evaluator.getNewAggregationBuffer();
                buffers[key] = buffer;
            }
            if (!countAllColumns) {
                arguments[0] = row[1];
            }
            evaluator.aggregate(buffer, arguments);
        }

        boolean isPrimitive = outputTypeInfo instanceof PrimitiveTypeInfo;
        boolean isPartial = udafEvaluatorMode == GenericUDAFEvaluator.Mode.PARTIAL1
                || udafEvaluatorMode == GenericUDAFEvaluator.Mode.PARTIAL2;
        for (int key = 0; key < buffers.length; key++) {
            GenericUDAFEvaluator.AggregationBuffer buffer = buffers[key];
            if (buffer == null) {
                continue; // no row had this key
            }
            Object result = isPartial
                    ? evaluator.terminatePartial(buffer)
                    : evaluator.terminate(buffer);
            if (result == null) {
                results[key] = null;
            } else if (isPrimitive) {
                results[key] = VectorRandomRowSource.getWritablePrimitiveObject(
                        (PrimitiveTypeInfo) outputTypeInfo, outputInspector, result);
            } else {
                results[key] = ObjectInspectorUtils.copyToStandardObject(result, outputInspector,
                        ObjectInspectorUtils.ObjectInspectorCopyOption.WRITABLE);
            }
        }
        return true;
    }

    /**
     * Extracts every row of {@code outputBatch} and stores its single column at
     * {@code results[keys[row]]}. Primitive values are copied to standard writable objects, since
     * {@code scratchRow} is reused for every row.
     */
    private static void extractResultObjects(VectorizedRowBatch outputBatch, int[] keys,
            VectorExtractRow resultExtractRow, TypeInfo outputTypeInfo, Object[] scratchRow,
            Object[] results) {
        boolean isPrimitive = outputTypeInfo instanceof PrimitiveTypeInfo;
        ObjectInspector objectInspector = isPrimitive
                ? TypeInfoUtils.getStandardWritableObjectInspectorFromTypeInfo(outputTypeInfo)
                : null;
        for (int row = 0; row < outputBatch.size; row++) {
            resultExtractRow.extractRow(outputBatch, row, scratchRow);
            if (isPrimitive) {
                results[keys[row]] = ObjectInspectorUtils.copyToStandardObject(scratchRow[0],
                        objectInspector, ObjectInspectorUtils.ObjectInspectorCopyOption.WRITABLE);
            } else {
                results[keys[row]] = scratchRow[0];
            }
        }
    }

    /**
     * Computes the per-key results through the vectorized aggregation chosen by the
     * {@link Vectorizer}, storing them in the same layout as {@link #doRowTest}.
     *
     * @param aggregationName the aggregate function name
     * @param typeInfo the aggregated argument type (not used here)
     * @param evaluator the row-mode evaluator the vector descriptor is derived from
     * @param outputTypeInfo the type of the value produced for {@code udafEvaluatorMode}
     * @param udafEvaluatorMode the evaluator mode being vectorized
     * @param maxKeyCount index of the NULL-key group; non-null keys must be below it
     * @param columns column names for the vectorization context
     * @param columnNames column names of the input batch
     * @param typeInfos column types for the vectorization context
     * @param dataTypePhysicalVariations column physical variations for the vectorization context
     * @param parameterList aggregate parameter expressions
     * @param batchSource source of the input batches
     * @param results receives one value per key that had at least one row; other slots are left
     *     untouched. Must have at least {@code maxKeyCount + 1} slots.
     * @return always {@code true}
     * @throws Exception if no vector aggregation exists, it cannot be instantiated, or it fails
     */
    protected static boolean doVectorTest(String aggregationName, TypeInfo typeInfo,
            GenericUDAFEvaluator evaluator, TypeInfo outputTypeInfo,
            GenericUDAFEvaluator.Mode udafEvaluatorMode, int maxKeyCount, List<String> columns,
            String[] columnNames, TypeInfo[] typeInfos,
            DataTypePhysicalVariation[] dataTypePhysicalVariations,
            List<ExprNodeDesc> parameterList, VectorRandomBatchSource batchSource,
            Object[] results) throws Exception {

        HiveConf hiveConf = new HiveConf();
        VectorizationContext vectorizationContext = new VectorizationContext("name", columns,
                Arrays.asList(typeInfos), Arrays.asList(dataTypePhysicalVariations), hiveConf);

        VectorAggregationDesc vecAggrDesc = (VectorAggregationDesc) Vectorizer
                .getVectorAggregationDesc(aggregationName, parameterList, evaluator,
                        outputTypeInfo, udafEvaluatorMode, vectorizationContext).left;
        if (vecAggrDesc == null) {
            Assert.fail("No vector aggregation expression found for aggregationName "
                    + aggregationName + " udafEvaluatorMode " + udafEvaluatorMode
                    + " parameterList " + parameterList + " outputTypeInfo " + outputTypeInfo);
        }

        VectorAggregateExpression vecAggrExpr = createVectorAggregateExpression(vecAggrDesc);
        VectorExpression.doTransientInit(vecAggrExpr.getInputExpression(), hiveConf);

        VectorAggregationBufferRow[] bufferRowsByKey = aggregateBatches(vecAggrExpr,
                vectorizationContext, columnNames, maxKeyCount, batchSource);
        collectVectorResults(vecAggrExpr, outputTypeInfo, bufferRowsByKey, results);
        return true;
    }

    /**
     * Instantiates the vector aggregation class named by {@code vecAggrDesc} through its
     * {@code (VectorAggregationDesc)} constructor. Only reflective failures are wrapped; the
     * original cause is preserved.
     */
    private static VectorAggregateExpression createVectorAggregateExpression(
            VectorAggregationDesc vecAggrDesc) throws HiveException {
        Class<? extends VectorAggregateExpression> vecAggrClass = vecAggrDesc.getVecAggrClass();
        try {
            return vecAggrClass.getConstructor(VectorAggregationDesc.class).newInstance(vecAggrDesc);
        } catch (NoSuchMethodException | SecurityException e) {
            throw new HiveException("Constructor " + vecAggrClass.getSimpleName()
                    + "(VectorAggregationDesc) not available", e);
        } catch (ReflectiveOperationException | IllegalArgumentException e) {
            throw new HiveException("Failed to create " + vecAggrClass.getSimpleName()
                    + "(VectorAggregationDesc) object ", e);
        }
    }

    /**
     * Feeds every batch of {@code batchSource} to {@code vecAggrExpr}, routing each row to the
     * buffer row of its key (column 0; NULL keys go to slot {@code maxKeyCount}).
     *
     * @return buffer rows indexed by key; a slot is {@code null} when no row had that key
     */
    private static VectorAggregationBufferRow[] aggregateBatches(
            VectorAggregateExpression vecAggrExpr, VectorizationContext vectorizationContext,
            String[] columnNames, int maxKeyCount, VectorRandomBatchSource batchSource)
            throws Exception {
        VectorRandomRowSource rowSource = batchSource.getRowSource();
        VectorizedRowBatch batch = new VectorizedRowBatchCtx(columnNames, rowSource.typeInfos(),
                rowSource.dataTypePhysicalVariations(), (int[]) null, 0, 0, (VirtualColumn[]) null,
                vectorizationContext.getScratchColumnTypeNames(),
                vectorizationContext.getScratchDataTypePhysicalVariations())
                .createVectorizedRowBatch();

        VectorAggregationBufferRow[] bufferRowsByKey = new VectorAggregationBufferRow[maxKeyCount + 1];
        // Maps each row position of the current batch to its key's buffer row. Positions past
        // batch.size are never read, so the array can be reused across batches.
        VectorAggregationBufferRow[] bufferRowsByPosition =
                new VectorAggregationBufferRow[VectorizedRowBatch.DEFAULT_SIZE];

        batchSource.resetBatchIteration();
        while (batchSource.fillNextBatch(batch)) {
            LongColumnVector keyColVector = (LongColumnVector) batch.cols[0];
            for (int position = 0; position < batch.size; position++) {
                int batchIndex = batch.selectedInUse ? batch.selected[position] : position;
                // A repeating vector keeps its single value and null flag at index 0.
                int keyIndex = keyColVector.isRepeating ? 0 : batchIndex;
                int key = (keyColVector.noNulls || !keyColVector.isNull[keyIndex])
                        ? (short) keyColVector.vector[keyIndex]
                        : maxKeyCount;

                VectorAggregationBufferRow bufferRow = bufferRowsByKey[key];
                if (bufferRow == null) {
                    VectorAggregateExpression.AggregationBuffer buffer =
                            vecAggrExpr.getNewAggregationBuffer();
                    buffer.reset();
                    bufferRow = new VectorAggregationBufferRow(
                            new VectorAggregateExpression.AggregationBuffer[] {buffer});
                    bufferRowsByKey[key] = bufferRow;
                }
                bufferRowsByPosition[position] = bufferRow;
            }
            vecAggrExpr.aggregateInputSelection(bufferRowsByPosition, 0, batch);
        }
        return bufferRowsByKey;
    }

    /**
     * Writes each non-null buffer row's result into {@code results[key]}, going through an output
     * batch and {@link VectorExtractRow} so values are in writable form. The output batch is
     * flushed whenever it fills up.
     */
    private static void collectVectorResults(VectorAggregateExpression vecAggrExpr,
            TypeInfo outputTypeInfo, VectorAggregationBufferRow[] bufferRowsByKey,
            Object[] results) throws Exception {
        VectorizedRowBatch outputBatch = new VectorizedRowBatchCtx(new String[] {"output"},
                new TypeInfo[] {outputTypeInfo},
                new DataTypePhysicalVariation[] {vecAggrExpr.getOutputDataTypePhysicalVariation()},
                (int[]) null, 0, 0, (VirtualColumn[]) null, new String[0],
                new DataTypePhysicalVariation[0]).createVectorizedRowBatch();

        // keys[row] is the key whose result occupies that row of outputBatch.
        int[] keys = new int[VectorizedRowBatch.DEFAULT_SIZE];
        VectorExtractRow resultExtractRow = new VectorExtractRow();
        resultExtractRow.init(new TypeInfo[] {outputTypeInfo}, new int[] {0});
        Object[] scratchRow = new Object[1];

        for (int key = 0; key < bufferRowsByKey.length; key++) {
            VectorAggregationBufferRow bufferRow = bufferRowsByKey[key];
            if (bufferRow == null) {
                continue;
            }
            if (outputBatch.size == keys.length) {
                extractResultObjects(outputBatch, keys, resultExtractRow, outputTypeInfo,
                        scratchRow, results);
                outputBatch.reset();
            }
            int row = outputBatch.size++;
            keys[row] = key;
            vecAggrExpr.assignRowColumn(outputBatch, row, 0, bufferRow.getAggregationBuffer(0));
        }
        if (outputBatch.size > 0) {
            extractResultObjects(outputBatch, keys, resultExtractRow, outputTypeInfo, scratchRow,
                    results);
        }
    }

    /**
     * Compares two non-null results; primitive values are first normalized with
     * {@link VectorRandomRowSource#getWritablePrimitiveObject}.
     */
    private boolean compareObjects(Object expected, Object actual, TypeInfo typeInfo,
            ObjectInspector objectInspector) {
        if (typeInfo instanceof PrimitiveTypeInfo) {
            PrimitiveTypeInfo primitiveTypeInfo = (PrimitiveTypeInfo) typeInfo;
            return VectorRandomRowSource
                    .getWritablePrimitiveObject(primitiveTypeInfo, objectInspector, expected)
                    .equals(VectorRandomRowSource.getWritablePrimitiveObject(
                            primitiveTypeInfo, objectInspector, actual));
        }
        return expected.equals(actual);
    }

    /**
     * Runs the aggregation in every {@link AggregationTestMode}, in declaration order. For mode
     * index {@code m}, a fresh {@code Object[maxKeyCount + 1]} is stored at
     * {@code resultsArray[m]} and filled by that mode. Stops early if a mode reports
     * {@code false}.
     *
     * @param resultsArray must have at least {@link AggregationTestMode#count} slots
     * @throws Exception propagated from the row or vector run
     */
    public void executeAggregationTests(String aggregationName, TypeInfo typeInfo,
            GenericUDAFEvaluator evaluator, TypeInfo outputTypeInfo,
            GenericUDAFEvaluator.Mode udafEvaluatorMode, int maxKeyCount, List<String> columns,
            String[] columnNames, List<ExprNodeDesc> parameters, Object[][] randomRows,
            VectorRandomRowSource rowSource, VectorRandomBatchSource batchSource,
            boolean tryDecimal64, Object[] resultsArray) throws Exception {
        for (int m = 0; m < AggregationTestMode.count; m++) {
            Object[] results = new Object[maxKeyCount + 1];
            resultsArray[m] = results;
            AggregationTestMode testMode = AggregationTestMode.values()[m];
            boolean succeeded;
            switch (testMode) {
                case ROW_MODE:
                    succeeded = doRowTest(typeInfo, evaluator, outputTypeInfo, udafEvaluatorMode,
                            maxKeyCount, columns, parameters, randomRows,
                            rowSource.rowStructObjectInspector(), results);
                    break;
                case VECTOR_EXPRESSION:
                    succeeded = doVectorTest(aggregationName, typeInfo, evaluator, outputTypeInfo,
                            udafEvaluatorMode, maxKeyCount, columns, columnNames,
                            rowSource.typeInfos(), rowSource.dataTypePhysicalVariations(),
                            parameters, batchSource, results);
                    break;
                default:
                    throw new RuntimeException("Unexpected Hash Aggregation test mode " + testMode);
            }
            if (!succeeded) {
                return;
            }
        }
    }

    /**
     * Asserts that every non-row mode produced the same per-key results as row mode (index 0).
     * A key matches when both results are NULL, or both are non-null and compare equal.
     *
     * @param resultsArray per-mode result arrays as filled by {@code executeAggregationTests}
     */
    public void verifyAggregationResults(TypeInfo typeInfo, TypeInfo outputTypeInfo,
            int maxKeyCount, GenericUDAFEvaluator.Mode udafEvaluatorMode, Object[] resultsArray) {
        Object[] expectedResults = (Object[]) resultsArray[0];
        ObjectInspector objectInspector =
                TypeInfoUtils.getStandardWritableObjectInspectorFromTypeInfo(outputTypeInfo);

        for (int m = 1; m < AggregationTestMode.count; m++) {
            AggregationTestMode testMode = AggregationTestMode.values()[m];
            Object[] actualResults = (Object[]) resultsArray[m];
            // Bounded loop: the previous while(true) form never terminated after the last key.
            for (int key = 0; key < maxKeyCount + 1; key++) {
                Object expected = expectedResults[key];
                Object actual = actualResults[key];
                if (expected == null || actual == null) {
                    if (expected != null || actual != null) {
                        Assert.fail("Key " + key + " typeName " + typeInfo.getTypeName()
                                + " outputTypeName " + outputTypeInfo.getTypeName() + " " + testMode
                                + " result is NULL "
                                + (actual == null ? "YES" : "NO result " + actual.toString())
                                + " does not match row-mode expected result is NULL "
                                + (expected == null ? "YES" : "NO result " + expected.toString())
                                + " udafEvaluatorMode " + udafEvaluatorMode);
                    }
                } else if (!compareObjects(expected, actual, outputTypeInfo, objectInspector)) {
                    Assert.fail("Key " + key + " typeName " + typeInfo.getTypeName()
                            + " outputTypeName " + outputTypeInfo.getTypeName() + " " + testMode
                            + " result " + actual.toString()
                            + " (" + actual.getClass().getSimpleName()
                            + ") does not match row-mode expected result " + expected.toString()
                            + " (" + expected.getClass().getSimpleName()
                            + ") udafEvaluatorMode " + udafEvaluatorMode);
                }
            }
        }
    }
}
