package org.apache.hadoop.hive.ql.exec.vector.aggregation;

import java.util.ArrayList;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Random;
import java.util.Set;
import junit.framework.Assert;
import org.apache.hadoop.hive.common.type.DataTypePhysicalVariation;
import org.apache.hadoop.hive.ql.exec.vector.VectorRandomBatchSource;
import org.apache.hadoop.hive.ql.exec.vector.VectorRandomRowSource;
import org.apache.hadoop.hive.ql.exec.vector.aggregation.AggregationBase;
import org.apache.hadoop.hive.ql.io.protobuf.SampleProtos;
import org.apache.hadoop.hive.ql.plan.ExprNodeColumnDesc;
import org.apache.hadoop.hive.ql.plan.ExprNodeDesc;
import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFCount;
import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFEvaluator;
import org.apache.hadoop.hive.serde2.io.HiveDecimalWritable;
import org.apache.hadoop.hive.serde2.io.ShortWritable;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.WritableShortObjectInspector;
import org.apache.hadoop.hive.serde2.typeinfo.CharTypeInfo;
import org.apache.hadoop.hive.serde2.typeinfo.DecimalTypeInfo;
import org.apache.hadoop.hive.serde2.typeinfo.PrimitiveTypeInfo;
import org.apache.hadoop.hive.serde2.typeinfo.TypeInfo;
import org.apache.hadoop.hive.serde2.typeinfo.TypeInfoFactory;
import org.apache.hadoop.hive.serde2.typeinfo.TypeInfoUtils;
import org.apache.hadoop.hive.serde2.typeinfo.VarcharTypeInfo;
import org.junit.Ignore;
import org.junit.Test;

/* loaded from: input_file:org/apache/hadoop/hive/ql/exec/vector/aggregation/TestVectorAggregation.class */
public class TestVectorAggregation extends AggregationBase {
    /** Aggregation names routed through the variance-family branches; filled in by the static initializer. */
    private static final Set<String> varianceNames = new HashSet<String>();

    /** Input types iterated by {@code doIntegerTests}; assigned in the static initializer. */
    private static TypeInfo[] integerTypeInfos;

    /** Input types iterated by {@code doFloatingTests}; assigned in the static initializer. */
    private static TypeInfo[] floatingTypeInfos;

    /** Input types iterated by {@code doDecimalTests}; assigned in the static initializer. */
    private static TypeInfo[] decimalTypeInfos;

    /** Input types iterated by {@code doStringFamilyTests}; assigned in the static initializer. */
    private static TypeInfo[] stringFamilyTypeInfos;

    /** Number of random rows requested from each {@code VectorRandomRowSource}. */
    private static final int TEST_ROW_COUNT = 100000;

    /** Runs the {@code avg} aggregation over every type in {@code integerTypeInfos} with a fixed seed. */
    @Test
    public void testAvgIntegers() throws Exception {
        final Random seeded = new Random(7743L);
        doIntegerTests("avg", seeded);
    }

    /** Runs the {@code avg} aggregation over every type in {@code floatingTypeInfos} with a fixed seed. */
    @Test
    public void testAvgFloating() throws Exception {
        final Random seeded = new Random(7743L);
        doFloatingTests("avg", seeded);
    }

    /** Runs the {@code avg} aggregation over {@code decimalTypeInfos} without requesting DECIMAL_64. */
    @Test
    public void testAvgDecimal() throws Exception {
        final Random seeded = new Random(7743L);
        final boolean tryDecimal64 = false;
        doDecimalTests("avg", seeded, tryDecimal64);
    }

    /** Runs the {@code avg} aggregation over {@code decimalTypeInfos} with the DECIMAL_64 flag set. */
    @Test
    public void testAvgDecimal64() throws Exception {
        final Random seeded = new Random(7743L);
        final boolean tryDecimal64 = true;
        doDecimalTests("avg", seeded, tryDecimal64);
    }

    /** Runs the {@code avg} aggregation over timestamp input with a fixed seed. */
    @Test
    public void testAvgTimestamp() throws Exception {
        final Random seeded = new Random(7743L);
        final TypeInfo inputType = TypeInfoFactory.timestampTypeInfo;
        doTests(seeded, "avg", inputType);
    }

    /** Runs {@code count} over short, long, double, decimal(18,10) and string input, sharing one seeded generator. */
    @Test
    public void testCount() throws Exception {
        final Random seeded = new Random(7743L);
        final TypeInfo[] inputTypes = {
            TypeInfoFactory.shortTypeInfo,
            TypeInfoFactory.longTypeInfo,
            TypeInfoFactory.doubleTypeInfo,
            new DecimalTypeInfo(18, 10),
            TypeInfoFactory.stringTypeInfo,
        };
        // Order matters: every run draws from the same Random instance.
        for (TypeInfo inputType : inputTypes) {
            doTests(seeded, "count", inputType);
        }
    }

    /**
     * Runs {@code count} with the count-star flag set (and DECIMAL_64 off) over short, long,
     * double, decimal(18,10) and string input, sharing one seeded generator.
     */
    @Test
    public void testCountStar() throws Exception {
        final Random seeded = new Random(7743L);
        final TypeInfo[] inputTypes = {
            TypeInfoFactory.shortTypeInfo,
            TypeInfoFactory.longTypeInfo,
            TypeInfoFactory.doubleTypeInfo,
            new DecimalTypeInfo(18, 10),
            TypeInfoFactory.stringTypeInfo,
        };
        // Order matters: every run draws from the same Random instance.
        for (TypeInfo inputType : inputTypes) {
            doTests(seeded, "count", inputType, true, false);
        }
    }

    /**
     * Runs {@code max} over the integer, floating, decimal (with and without the DECIMAL_64 flag),
     * timestamp, interval-day-time and string-family inputs, sharing one seeded generator.
     */
    @Test
    public void testMax() throws Exception {
        final String aggregation = "max";
        final Random seeded = new Random(7743L);

        doIntegerTests(aggregation, seeded);
        doFloatingTests(aggregation, seeded);

        doDecimalTests(aggregation, seeded, false);
        doDecimalTests(aggregation, seeded, true);

        doTests(seeded, aggregation, TypeInfoFactory.timestampTypeInfo);
        doTests(seeded, aggregation, TypeInfoFactory.intervalDayTimeTypeInfo);

        doStringFamilyTests(aggregation, seeded);
    }

    /**
     * Runs {@code min} over the integer, floating, decimal (with and without the DECIMAL_64 flag),
     * timestamp, interval-day-time and string-family inputs, sharing one seeded generator.
     */
    @Test
    public void testMin() throws Exception {
        final String aggregation = "min";
        final Random seeded = new Random(7743L);

        doIntegerTests(aggregation, seeded);
        doFloatingTests(aggregation, seeded);

        doDecimalTests(aggregation, seeded, false);
        doDecimalTests(aggregation, seeded, true);

        doTests(seeded, aggregation, TypeInfoFactory.timestampTypeInfo);
        doTests(seeded, aggregation, TypeInfoFactory.intervalDayTimeTypeInfo);

        doStringFamilyTests(aggregation, seeded);
    }

    /**
     * Runs {@code sum} over short, long and double input, the decimal types (with and without the
     * DECIMAL_64 flag) and timestamp input, sharing one seeded generator.
     */
    @Test
    public void testSum() throws Exception {
        final String aggregation = "sum";
        final Random seeded = new Random(7743L);

        doTests(seeded, aggregation, TypeInfoFactory.shortTypeInfo);
        doTests(seeded, aggregation, TypeInfoFactory.longTypeInfo);
        doTests(seeded, aggregation, TypeInfoFactory.doubleTypeInfo);

        doDecimalTests(aggregation, seeded, false);
        doDecimalTests(aggregation, seeded, true);

        doTests(seeded, aggregation, TypeInfoFactory.timestampTypeInfo);
    }

    /**
     * Runs {@code bloom_filter} over the integer, floating, decimal (DECIMAL_64 flag off),
     * timestamp and string-family inputs. Currently disabled through {@code @Ignore}.
     */
    @Test
    @Ignore
    public void testBloomFilter() throws Exception {
        final String aggregation = "bloom_filter";
        final Random seeded = new Random(7743L);

        doIntegerTests(aggregation, seeded);
        doFloatingTests(aggregation, seeded);
        doDecimalTests(aggregation, seeded, false);
        doTests(seeded, aggregation, TypeInfoFactory.timestampTypeInfo);
        doStringFamilyTests(aggregation, seeded);
    }

    /** Runs every aggregation named in {@code varianceNames} over the integer input types. */
    @Test
    public void testVarianceIntegers() throws Exception {
        final Random seeded = new Random(7743L);
        for (String varianceName : varianceNames) {
            doIntegerTests(varianceName, seeded);
        }
    }

    /** Runs every aggregation named in {@code varianceNames} over the floating input types. */
    @Test
    public void testVarianceFloating() throws Exception {
        final Random seeded = new Random(7743L);
        for (String varianceName : varianceNames) {
            doFloatingTests(varianceName, seeded);
        }
    }

    /** Runs every aggregation named in {@code varianceNames} over the decimal input types, DECIMAL_64 flag off. */
    @Test
    public void testVarianceDecimal() throws Exception {
        final Random seeded = new Random(7743L);
        for (String varianceName : varianceNames) {
            doDecimalTests(varianceName, seeded, false);
        }
    }

    /** Runs every aggregation named in {@code varianceNames} over timestamp input. */
    @Test
    public void testVarianceTimestamp() throws Exception {
        final Random seeded = new Random(7743L);
        for (String varianceName : varianceNames) {
            doTests(seeded, varianceName, TypeInfoFactory.timestampTypeInfo);
        }
    }

    /**
     * Invokes {@code doTests} once per entry of {@code integerTypeInfos}, in array order.
     *
     * @param str    aggregation name forwarded to {@code doTests}
     * @param random shared random source forwarded to {@code doTests}
     */
    private void doIntegerTests(String str, Random random) throws Exception {
        final TypeInfo[] inputTypes = integerTypeInfos;
        for (int idx = 0; idx < inputTypes.length; idx++) {
            doTests(random, str, inputTypes[idx]);
        }
    }

    /**
     * Invokes {@code doTests} once per entry of {@code floatingTypeInfos}, in array order.
     *
     * @param str    aggregation name forwarded to {@code doTests}
     * @param random shared random source forwarded to {@code doTests}
     */
    private void doFloatingTests(String str, Random random) throws Exception {
        final TypeInfo[] inputTypes = floatingTypeInfos;
        for (int idx = 0; idx < inputTypes.length; idx++) {
            doTests(random, str, inputTypes[idx]);
        }
    }

    /**
     * Invokes the five-argument {@code doTests} once per entry of {@code decimalTypeInfos},
     * always with the count-star flag off.
     *
     * @param str    aggregation name forwarded to {@code doTests}
     * @param random shared random source forwarded to {@code doTests}
     * @param z      DECIMAL_64 flag forwarded as the last argument of {@code doTests}
     */
    private void doDecimalTests(String str, Random random, boolean z) throws Exception {
        final TypeInfo[] inputTypes = decimalTypeInfos;
        final boolean countStar = false;
        for (int idx = 0; idx < inputTypes.length; idx++) {
            doTests(random, str, inputTypes[idx], countStar, z);
        }
    }

    /**
     * Invokes {@code doTests} once per entry of {@code stringFamilyTypeInfos}, in array order.
     *
     * @param str    aggregation name forwarded to {@code doTests}
     * @param random shared random source forwarded to {@code doTests}
     */
    private void doStringFamilyTests(String str, Random random) throws Exception {
        final TypeInfo[] inputTypes = stringFamilyTypeInfos;
        for (int idx = 0; idx < inputTypes.length; idx++) {
            doTests(random, str, inputTypes[idx]);
        }
    }

    /**
     * Decides whether the DECIMAL_64 physical variation should be used for a type.
     *
     * @param z        caller's request to try DECIMAL_64; when {@code false} the answer is always {@code false}
     * @param typeInfo type under test; anything other than a {@link DecimalTypeInfo} yields {@code false}
     * @return the result of {@code HiveDecimalWritable.isPrecisionDecimal64} for the decimal's
     *         precision when both conditions hold, otherwise {@code false}
     */
    private boolean checkDecimal64(boolean z, TypeInfo typeInfo) {
        if (!z) {
            return false;
        }
        if (!(typeInfo instanceof DecimalTypeInfo)) {
            return false;
        }
        final DecimalTypeInfo decimalTypeInfo = (DecimalTypeInfo) typeInfo;
        final int precision = decimalTypeInfo.getPrecision();
        return HiveDecimalWritable.isPrecisionDecimal64(precision);
    }

    /**
     * Draws a number in {@code [1, i]} with linearly decreasing weights: 1 has weight {@code i},
     * 2 has weight {@code i - 1}, and so on down to {@code i} with weight 1.
     *
     * @param random source of randomness; one {@code nextInt} call is made
     * @param i      largest value that can be returned; also the weight of the value 1
     * @return the selected value, between 1 and {@code i} inclusive
     */
    public static int getLinearRandomNumber(Random random, int i) {
        // Pick a point in the triangular total i + (i - 1) + ... + 1.
        int remaining = random.nextInt((i * (i + 1)) / 2);
        int picked = 0;
        // Walk the weight bands from widest to narrowest until the point is passed.
        for (int weight = i; remaining >= 0; weight--) {
            remaining -= weight;
            picked++;
        }
        return picked;
    }

    /**
     * Builds merge-mode input rows from previously computed results, then calls
     * {@code executeAggregationTests} and {@code verifyAggregationResults} in the given mode.
     *
     * <p>Each generated row gets a SHORT key in column 0 and, in column 1, one non-null entry of
     * {@code objArr[0]}; the entries are consumed cyclically.
     *
     * @param mode           evaluator mode to initialise and run (the callers in this class pass
     *                       PARTIAL2 or FINAL)
     * @param random         random source for row-source setup and batch creation
     * @param str            aggregation name
     * @param typeInfo       input type used to look up the evaluator
     * @param generationSpec generation spec for the key column (column 0)
     * @param list           column names as a list
     * @param strArr         column names as an array
     * @param i              number of leading slots of {@code objArr[0]} to cycle through
     * @param i2             divisor applied to {@code i} to get the merge key count
     * @param typeInfo2      type of the values in {@code objArr[0]}; used for column 1
     * @param objArr         results array from an earlier run; element 0 must be an {@code Object[]}
     * @throws IllegalStateException if the first {@code i} entries of {@code objArr[0]} are all null
     */
    private void doMerge(GenericUDAFEvaluator.Mode mode, Random random, String str, TypeInfo typeInfo, VectorRandomRowSource.GenerationSpec generationSpec, List<String> list, String[] strArr, int i, int i2, TypeInfo typeInfo2, Object[] objArr) throws Exception {
        List<VectorRandomRowSource.GenerationSpec> mergeGenerationSpecs =
                new ArrayList<VectorRandomRowSource.GenerationSpec>();
        List<DataTypePhysicalVariation> mergePhysicalVariations = new ArrayList<DataTypePhysicalVariation>();

        // Column 0: key column, described by the caller's spec.
        mergeGenerationSpecs.add(generationSpec);
        mergePhysicalVariations.add(DataTypePhysicalVariation.NONE);

        // Column 1: generation is omitted because the values are copied in from objArr[0] below.
        mergeGenerationSpecs.add(VectorRandomRowSource.GenerationSpec.createOmitGeneration(typeInfo2));
        mergePhysicalVariations.add(DataTypePhysicalVariation.NONE);

        List<ExprNodeDesc> mergeParameters = new ArrayList<ExprNodeDesc>();
        mergeParameters.add(new ExprNodeColumnDesc(typeInfo2, "col1", "table", false));

        final int mergeParameterCount = mergeParameters.size();
        ObjectInspector[] mergeParameterInspectors = new ObjectInspector[mergeParameterCount];
        for (int p = 0; p < mergeParameterCount; p++) {
            mergeParameterInspectors[p] = TypeInfoUtils.getStandardWritableObjectInspectorFromTypeInfo(
                    mergeParameters.get(p).getTypeInfo());
        }

        VectorRandomRowSource mergeRowSource = new VectorRandomRowSource();
        mergeRowSource.initGenerationSpecSchema(random, mergeGenerationSpecs, 0, false, true, mergePhysicalVariations);
        Object[][] mergeRandomRows = mergeRowSource.randomRows(TEST_ROW_COUNT);

        final int mergeMaxKeyCount = i / i2;
        Object[] previousResults = (Object[]) objArr[0];

        // An int cursor (not short) so key counts above Short.MAX_VALUE cannot wrap negative.
        int cursor = 0;
        for (Object[] mergeRow : mergeRandomRows) {
            // Advance cyclically to the next non-null previous result.
            int nullsSkipped = 0;
            while (true) {
                if (cursor >= i) {
                    cursor = 0;
                }
                if (previousResults[cursor] != null) {
                    break;
                }
                cursor++;
                nullsSkipped++;
                if (nullsSkipped >= i) {
                    // A full cycle of nulls would otherwise spin forever.
                    throw new IllegalStateException(
                            "No non-null results available to merge for aggregation " + str);
                }
            }
            mergeRow[0] = new ShortWritable((short) (cursor % mergeMaxKeyCount));
            mergeRow[1] = previousResults[cursor];
            cursor++;
        }

        VectorRandomBatchSource mergeBatchSource =
                VectorRandomBatchSource.createInterestingBatches(random, mergeRowSource, mergeRandomRows, null);

        GenericUDAFEvaluator mergeEvaluator = getEvaluator(str, typeInfo);
        TypeInfo mergeOutputTypeInfo =
                TypeInfoUtils.getTypeInfoFromObjectInspector(mergeEvaluator.init(mode, mergeParameterInspectors));

        Object[] mergeResultsArray = new Object[AggregationBase.AggregationTestMode.count];
        executeAggregationTests(str, typeInfo2, mergeEvaluator, mergeOutputTypeInfo, mode, mergeMaxKeyCount,
                list, strArr, mergeParameters, mergeRandomRows, mergeRowSource, mergeBatchSource, false,
                mergeResultsArray);
        verifyAggregationResults(typeInfo2, mergeOutputTypeInfo, mergeMaxKeyCount, mode, mergeResultsArray);
    }

    /**
     * Convenience overload: runs the five-argument {@code doTests} with both the count-star
     * flag and the DECIMAL_64 flag turned off.
     *
     * @param random   shared random source
     * @param str      aggregation name
     * @param typeInfo input type to aggregate
     */
    private void doTests(Random random, String str, TypeInfo typeInfo) throws Exception {
        final boolean countStar = false;
        final boolean tryDecimal64 = false;
        doTests(random, str, typeInfo, countStar, tryDecimal64);
    }

    /**
     * Drives one aggregation over one input type through the evaluator modes.
     *
     * <p>Steps, in order:
     * <ol>
     *   <li>PARTIAL1: generate rows (SHORT key in column 0, {@code typeInfo} values in column 1),
     *       then call {@code executeAggregationTests} and {@code verifyAggregationResults}.</li>
     *   <li>COMPLETE: repeated on freshly generated rows, only for the variance family and {@code avg}.</li>
     *   <li>PARTIAL2: {@code doMerge} over the PARTIAL1 results, only for the variance family and {@code avg}.</li>
     *   <li>FINAL: {@code doMerge} over the PARTIAL1 results, for the variance family,
     *       {@code avg}, {@code bloom_filter} and {@code count}.</li>
     * </ol>
     *
     * @param random   shared random source; the order of draws is part of the test's determinism
     * @param str      aggregation name; must be in {@code varianceNames} or one of
     *                 avg, bloom_filter, count, max, min, sum
     * @param typeInfo type of the aggregated column
     * @param z        count-star flag: when {@code true} no parameter expression is passed and the
     *                 evaluator (asserted to be a count evaluator) is told to count all columns
     * @param z2       request to use the DECIMAL_64 physical variation when {@code checkDecimal64} allows it
     * @throws RuntimeException if {@code str} is not a recognised aggregation name (raised after the PARTIAL1 pass)
     */
    private void doTests(Random random, String str, TypeInfo typeInfo, boolean z, boolean z2) throws Exception {
        // Bound for generated keys, also passed on as the key count.
        final int dataAggrMaxKeyCount = 20000;
        // Divisor that doMerge applies to the key count.
        final int reductionFactor = 16;

        List<VectorRandomRowSource.GenerationSpec> generationSpecs =
                new ArrayList<VectorRandomRowSource.GenerationSpec>();
        List<DataTypePhysicalVariation> physicalVariations = new ArrayList<DataTypePhysicalVariation>();

        // Column 0: SHORT key; generation is omitted because keys are assigned explicitly below.
        PrimitiveTypeInfo keyTypeInfo = TypeInfoFactory.shortTypeInfo;
        VectorRandomRowSource.GenerationSpec keyGenerationSpec =
                VectorRandomRowSource.GenerationSpec.createOmitGeneration(keyTypeInfo);
        generationSpecs.add(keyGenerationSpec);
        physicalVariations.add(DataTypePhysicalVariation.NONE);

        // Column 1: values of the type under test.
        final boolean decimal64Enabled = checkDecimal64(z2, typeInfo);
        generationSpecs.add(VectorRandomRowSource.GenerationSpec.createSameType(typeInfo));
        physicalVariations.add(
                decimal64Enabled ? DataTypePhysicalVariation.DECIMAL_64 : DataTypePhysicalVariation.NONE);

        List<String> columns = new ArrayList<String>();
        columns.add("col0");
        columns.add("col1");

        ExprNodeColumnDesc col1Expr = new ExprNodeColumnDesc(typeInfo, "col1", "table", false);
        List<ExprNodeDesc> parameters = new ArrayList<ExprNodeDesc>();
        if (!z) {
            // The count-star variant takes no parameter expressions.
            parameters.add(col1Expr);
        }
        final int parameterCount = parameters.size();
        ObjectInspector[] parameterInspectors = new ObjectInspector[parameterCount];
        for (int p = 0; p < parameterCount; p++) {
            parameterInspectors[p] = TypeInfoUtils.getStandardWritableObjectInspectorFromTypeInfo(
                    parameters.get(p).getTypeInfo());
        }

        String[] columnNames = columns.toArray(new String[0]);
        WritableShortObjectInspector keyObjectInspector =
                (WritableShortObjectInspector) VectorRandomRowSource.getObjectInspector(keyTypeInfo);

        /*
         * PARTIAL1.
         */
        VectorRandomRowSource partial1RowSource = new VectorRandomRowSource();
        // The fourth argument is false only for bloom_filter.
        partial1RowSource.initGenerationSpecSchema(
                random, generationSpecs, 0, !str.equals("bloom_filter"), true, physicalVariations);
        Object[][] partial1RandomRows = partial1RowSource.randomRows(TEST_ROW_COUNT);
        for (Object[] row : partial1RandomRows) {
            row[0] = keyObjectInspector.create((short) getLinearRandomNumber(random, dataAggrMaxKeyCount));
        }
        VectorRandomBatchSource partial1BatchSource =
                VectorRandomBatchSource.createInterestingBatches(random, partial1RowSource, partial1RandomRows, null);

        GenericUDAFEvaluator partial1Evaluator = getEvaluator(str, typeInfo);
        if (z) {
            Assert.assertTrue(partial1Evaluator instanceof GenericUDAFCount.GenericUDAFCountEvaluator);
            ((GenericUDAFCount.GenericUDAFCountEvaluator) partial1Evaluator).setCountAllColumns(true);
        }

        GenericUDAFEvaluator.Mode partial1Mode = GenericUDAFEvaluator.Mode.PARTIAL1;
        TypeInfo partial1OutputTypeInfo = TypeInfoUtils.getTypeInfoFromObjectInspector(
                partial1Evaluator.init(partial1Mode, parameterInspectors));
        Object[] partial1ResultsArray = new Object[AggregationBase.AggregationTestMode.count];
        executeAggregationTests(str, typeInfo, partial1Evaluator, partial1OutputTypeInfo, partial1Mode,
                dataAggrMaxKeyCount, columns, columnNames, parameters, partial1RandomRows, partial1RowSource,
                partial1BatchSource, z2, partial1ResultsArray);
        verifyAggregationResults(typeInfo, partial1OutputTypeInfo, dataAggrMaxKeyCount, partial1Mode,
                partial1ResultsArray);

        /*
         * COMPLETE, on freshly generated rows.
         */
        if (runsCompleteAndPartial2(str)) {
            VectorRandomRowSource completeRowSource = new VectorRandomRowSource();
            completeRowSource.initGenerationSpecSchema(random, generationSpecs, 0, true, true, physicalVariations);
            Object[][] completeRandomRows = completeRowSource.randomRows(TEST_ROW_COUNT);
            for (Object[] row : completeRandomRows) {
                row[0] = keyObjectInspector.create((short) getLinearRandomNumber(random, dataAggrMaxKeyCount));
            }
            VectorRandomBatchSource completeBatchSource = VectorRandomBatchSource.createInterestingBatches(
                    random, completeRowSource, completeRandomRows, null);

            GenericUDAFEvaluator completeEvaluator = getEvaluator(str, typeInfo);
            GenericUDAFEvaluator.Mode completeMode = GenericUDAFEvaluator.Mode.COMPLETE;
            TypeInfo completeOutputTypeInfo = TypeInfoUtils.getTypeInfoFromObjectInspector(
                    completeEvaluator.init(completeMode, parameterInspectors));
            Object[] completeResultsArray = new Object[AggregationBase.AggregationTestMode.count];
            executeAggregationTests(str, typeInfo, completeEvaluator, completeOutputTypeInfo, completeMode,
                    dataAggrMaxKeyCount, columns, columnNames, parameters, completeRandomRows, completeRowSource,
                    completeBatchSource, z2, completeResultsArray);
            verifyAggregationResults(typeInfo, completeOutputTypeInfo, dataAggrMaxKeyCount, completeMode,
                    completeResultsArray);
        }

        /*
         * PARTIAL2, merging the PARTIAL1 results.
         */
        if (runsCompleteAndPartial2(str)) {
            doMerge(GenericUDAFEvaluator.Mode.PARTIAL2, random, str, typeInfo, keyGenerationSpec, columns,
                    columnNames, dataAggrMaxKeyCount, reductionFactor, partial1OutputTypeInfo,
                    partial1ResultsArray);
        }

        /*
         * FINAL, merging the PARTIAL1 results.
         */
        if (runsFinal(str)) {
            doMerge(GenericUDAFEvaluator.Mode.FINAL, random, str, typeInfo, keyGenerationSpec, columns,
                    columnNames, dataAggrMaxKeyCount, reductionFactor, partial1OutputTypeInfo,
                    partial1ResultsArray);
        }
    }

    /**
     * Tells whether the COMPLETE and PARTIAL2 passes are run for an aggregation: yes for the
     * variance family and {@code avg}, no for bloom_filter, count, max, min and sum.
     *
     * @throws RuntimeException for any other name
     */
    private static boolean runsCompleteAndPartial2(String aggregationName) {
        if (varianceNames.contains(aggregationName)) {
            return true;
        }
        switch (aggregationName) {
            case "avg":
                return true;
            case "bloom_filter":
            case "count":
            case "max":
            case "min":
            case "sum":
                return false;
            default:
                throw new RuntimeException("Unexpected aggregation name " + aggregationName);
        }
    }

    /**
     * Tells whether the FINAL pass is run for an aggregation: yes for the variance family,
     * {@code avg}, {@code bloom_filter} and {@code count}; no for max, min and sum.
     *
     * @throws RuntimeException for any other name
     */
    private static boolean runsFinal(String aggregationName) {
        if (varianceNames.contains(aggregationName)) {
            return true;
        }
        switch (aggregationName) {
            case "avg":
            case "bloom_filter":
            case "count":
                return true;
            case "max":
            case "min":
            case "sum":
                return false;
            default:
                throw new RuntimeException("Unexpected aggregation name " + aggregationName);
        }
    }

    // Populates the type tables iterated by the do*Tests helpers and the variance-family name set.
    static {
        integerTypeInfos = new TypeInfo[] {
            TypeInfoFactory.byteTypeInfo,
            TypeInfoFactory.shortTypeInfo,
            TypeInfoFactory.intTypeInfo,
            TypeInfoFactory.longTypeInfo,
        };

        floatingTypeInfos = new TypeInfo[] {
            TypeInfoFactory.doubleTypeInfo,
        };

        // (precision, scale) pairs.
        decimalTypeInfos = new TypeInfo[] {
            new DecimalTypeInfo(38, 18),
            new DecimalTypeInfo(25, 2),
            new DecimalTypeInfo(19, 4),
            new DecimalTypeInfo(18, 10),
            new DecimalTypeInfo(17, 3),
            new DecimalTypeInfo(12, 2),
            new DecimalTypeInfo(7, 1),
        };

        stringFamilyTypeInfos = new TypeInfo[] {
            TypeInfoFactory.stringTypeInfo,
            new CharTypeInfo(25),
            new CharTypeInfo(10),
            new VarcharTypeInfo(20),
            new VarcharTypeInfo(15),
        };

        // Insertion order is kept as-is; the variance tests iterate this set with a shared Random.
        varianceNames.add("variance");
        varianceNames.add("var_samp");
        varianceNames.add("std");
        varianceNames.add("stddev_samp");
    }
}
