package org.apache.spark.examples.ml;

import java.util.Arrays;
import org.apache.commons.lang.time.DateUtils;
import org.apache.spark.ml.Pipeline;
import org.apache.spark.ml.PipelineStage;
import org.apache.spark.ml.classification.LogisticRegression;
import org.apache.spark.ml.evaluation.BinaryClassificationEvaluator;
import org.apache.spark.ml.feature.HashingTF;
import org.apache.spark.ml.feature.Tokenizer;
import org.apache.spark.ml.tuning.CrossValidator;
import org.apache.spark.ml.tuning.ParamGridBuilder;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SparkSession;

/* loaded from: input_file:org/apache/spark/examples/ml/JavaModelSelectionViaCrossValidationExample.class */
/**
 * Example that tunes a text-classification pipeline with k-fold cross-validation.
 *
 * <p>The pipeline is {@link Tokenizer} -> {@link HashingTF} -> {@link LogisticRegression}.
 * A {@link CrossValidator} searches a grid over {@code HashingTF.numFeatures} and
 * {@code LogisticRegression.regParam}, scoring each candidate with a
 * {@link BinaryClassificationEvaluator}, and the fitted result is then used to
 * predict on a handful of unlabeled documents.
 */
public class JavaModelSelectionViaCrossValidationExample {
    /**
     * Builds the training set, runs the cross-validated grid search, and prints the
     * predictions for the test documents to standard output.
     *
     * @param strArr command-line arguments (not used)
     */
    public static void main(String[] strArr) {
        SparkSession spark = SparkSession.builder()
            .appName("JavaModelSelectionViaCrossValidationExample")
            .getOrCreate();
        try {
            // Labeled training documents as (id, text, label).
            Dataset<Row> training = spark.createDataFrame(Arrays.asList(
                new JavaLabeledDocument(0L, "a b c d e spark", 1.0d),
                new JavaLabeledDocument(1L, "b d", 0.0d),
                new JavaLabeledDocument(2L, "spark f g h", 1.0d),
                new JavaLabeledDocument(3L, "hadoop mapreduce", 0.0d),
                new JavaLabeledDocument(4L, "b spark who", 1.0d),
                new JavaLabeledDocument(5L, "g d a y", 0.0d),
                new JavaLabeledDocument(6L, "spark fly", 1.0d),
                new JavaLabeledDocument(7L, "was mapreduce", 0.0d),
                new JavaLabeledDocument(8L, "e spark program", 1.0d),
                new JavaLabeledDocument(9L, "a e c l", 0.0d),
                new JavaLabeledDocument(10L, "spark compile", 1.0d),
                new JavaLabeledDocument(11L, "hadoop software", 0.0d)
            ), JavaLabeledDocument.class);

            // Largest feature-vector size tried; also the HashingTF setting before tuning.
            final int maxNumFeatures = 1000;

            // Stages are declared with their concrete types so that their parameter
            // accessors (getOutputCol, numFeatures, regParam) are available below.
            Tokenizer tokenizer = new Tokenizer()
                .setInputCol("text")
                .setOutputCol("words");
            HashingTF hashingTF = new HashingTF()
                .setNumFeatures(maxNumFeatures)
                .setInputCol(tokenizer.getOutputCol())
                .setOutputCol("features");
            LogisticRegression lr = new LogisticRegression()
                .setMaxIter(10)
                .setRegParam(0.01d);
            Pipeline pipeline = new Pipeline()
                .setStages(new PipelineStage[]{tokenizer, hashingTF, lr});

            // The whole pipeline is the estimator being tuned. The grid has
            // 3 values for numFeatures x 2 values for regParam = 6 candidate settings,
            // each evaluated with 2-fold cross-validation.
            CrossValidator crossValidator = new CrossValidator()
                .setEstimator(pipeline)
                .setEvaluator(new BinaryClassificationEvaluator())
                .setEstimatorParamMaps(new ParamGridBuilder()
                    .addGrid(hashingTF.numFeatures(), new int[]{10, 100, maxNumFeatures})
                    .addGrid(lr.regParam(), new double[]{0.1d, 0.01d})
                    .build())
                .setNumFolds(2);

            // Unlabeled test documents as (id, text).
            Dataset<Row> test = spark.createDataFrame(Arrays.asList(
                new JavaDocument(4L, "spark i j k"),
                new JavaDocument(5L, "l m n"),
                new JavaDocument(6L, "mapreduce spark"),
                new JavaDocument(7L, "apache hadoop")
            ), JavaDocument.class);

            // Fit on the training data, then predict on the test documents.
            Dataset<Row> predictions = crossValidator.fit(training).transform(test);

            for (Row row : predictions.select("id", "text", "probability", "prediction").collectAsList()) {
                System.out.println("(" + row.get(0) + ", " + row.get(1) + ") --> prob=" + row.get(2)
                    + ", prediction=" + row.get(3));
            }
        } finally {
            // Stop the session even when training or prediction throws.
            spark.stop();
        }
    }
}
