package org.apache.spark.ml.classification;

import java.io.IOException;
import java.util.Objects;
import org.apache.spark.ml.Predictor;
import org.apache.spark.ml.ann.FeedForwardTopology$;
import org.apache.spark.ml.ann.FeedForwardTrainer;
import org.apache.spark.ml.classification.MultilayerPerceptronParams;
import org.apache.spark.ml.linalg.Vector;
import org.apache.spark.ml.param.DoubleParam;
import org.apache.spark.ml.param.IntArrayParam;
import org.apache.spark.ml.param.IntParam;
import org.apache.spark.ml.param.LongParam;
import org.apache.spark.ml.param.Param;
import org.apache.spark.ml.param.ParamMap;
import org.apache.spark.ml.param.ParamValidators$;
import org.apache.spark.ml.param.shared.HasMaxIter;
import org.apache.spark.ml.param.shared.HasSeed;
import org.apache.spark.ml.param.shared.HasSolver;
import org.apache.spark.ml.param.shared.HasStepSize;
import org.apache.spark.ml.param.shared.HasTol;
import org.apache.spark.ml.util.DefaultParamsWritable;
import org.apache.spark.ml.util.Identifiable$;
import org.apache.spark.ml.util.Instrumentation;
import org.apache.spark.ml.util.Instrumentation$;
import org.apache.spark.ml.util.MLReader;
import org.apache.spark.ml.util.MLWritable;
import org.apache.spark.ml.util.MLWriter;
import org.apache.spark.rdd.RDD;
import org.apache.spark.sql.Dataset;
import scala.Function1;
import scala.Predef$;
import scala.StringContext;
import scala.Tuple2;
import scala.reflect.ClassTag$;
import scala.reflect.ScalaSignature;
import scala.runtime.BoxesRunTime;

/* compiled from: MultilayerPerceptronClassifier.scala */
@ScalaSignature(bytes = "\u0006\u0001\u0005Ef\u0001B\u0001\u0003\u00015\u0011a$T;mi&d\u0017-_3s!\u0016\u00148-\u001a9ue>t7\t\\1tg&4\u0017.\u001a:\u000b\u0005\r!\u0011AD2mCN\u001c\u0018NZ5dCRLwN\u001c\u0006\u0003\u000b\u0019\t!!\u001c7\u000b\u0005\u001dA\u0011!B:qCJ\\'BA\u0005\u000b\u0003\u0019\t\u0007/Y2iK*\t1\"A\u0002pe\u001e\u001c\u0001a\u0005\u0003\u0001\u001du\u0001\u0003#B\b\u0011%aQR\"\u0001\u0003\n\u0005E!!!\u0003)sK\u0012L7\r^8s!\t\u0019b#D\u0001\u0015\u0015\t)B!\u0001\u0004mS:\fGnZ\u0005\u0003/Q\u0011aAV3di>\u0014\bCA\r\u0001\u001b\u0005\u0011\u0001CA\r\u001c\u0013\ta\"AA\u0014Nk2$\u0018\u000e\\1zKJ\u0004VM]2faR\u0014xN\\\"mCN\u001c\u0018NZ5dCRLwN\\'pI\u0016d\u0007CA\r\u001f\u0013\ty\"A\u0001\u000eNk2$\u0018\u000e\\1zKJ\u0004VM]2faR\u0014xN\u001c)be\u0006l7\u000f\u0005\u0002\"I5\t!E\u0003\u0002$\t\u0005!Q\u000f^5m\u0013\t)#EA\u000bEK\u001a\fW\u000f\u001c;QCJ\fWn],sSR\f'\r\\3\t\u0011\u001d\u0002!Q1A\u0005B!\n1!^5e+\u0005I\u0003C\u0001\u00161\u001d\tYc&D\u0001-\u0015\u0005i\u0013!B:dC2\f\u0017BA\u0018-\u0003\u0019\u0001&/\u001a3fM&\u0011\u0011G\r\u0002\u0007'R\u0014\u0018N\\4\u000b\u0005=b\u0003f\u0001\u00145uA\u0011Q\u0007O\u0007\u0002m)\u0011qGB\u0001\u000bC:tw\u000e^1uS>t\u0017BA\u001d7\u0005\u0015\u0019\u0016N\\2fC\u0005Y\u0014!B\u0019/k9\u0002\u0004\u0002C\u001f\u0001\u0005\u0003\u0005\u000b\u0011B\u0015\u0002\tULG\r\t\u0015\u0004yQR\u0004\"\u0002!\u0001\t\u0003\t\u0015A\u0002\u001fj]&$h\b\u0006\u0002\u0019\u0005\")qe\u0010a\u0001S!\u001a!\t\u000e\u001e)\u0007}\"$\bC\u0003A\u0001\u0011\u0005a\tF\u0001\u0019Q\r)EG\u000f\u0005\u0006\u0013\u0002!\tAS\u0001\ng\u0016$H*Y=feN$\"a\u0013'\u000e\u0003\u0001AQ!\u0014%A\u00029\u000bQA^1mk\u0016\u00042aK(R\u0013\t\u0001FFA\u0003BeJ\f\u0017\u0010\u0005\u0002,%&\u00111\u000b\f\u0002\u0004\u0013:$\bf\u0001%5u!)a\u000b\u0001C\u0001/\u0006a1/\u001a;CY>\u001c7nU5{KR\u00111\n\u0017\u0005\u0006\u001bV\u0003\r!\u0015\u0015\u0004+RR\u0004\"B.\u0001\t\u0003a\u0016!C:fiN{GN^3s)\tYU\fC\u0003N5\u0002\u0007\u0011\u0006K\u0002[i}\u
000b\u0013\u0001Y\u0001\u0006e9\u0002d\u0006\r\u0005\u0006E\u0002!\taY\u0001\u000bg\u0016$X*\u0019=Ji\u0016\u0014HCA&e\u0011\u0015i\u0015\r1\u0001RQ\r\tGG\u000f\u0005\u0006O\u0002!\t\u0001[\u0001\u0007g\u0016$Hk\u001c7\u0015\u0005-K\u0007\"B'g\u0001\u0004Q\u0007CA\u0016l\u0013\taGF\u0001\u0004E_V\u0014G.\u001a\u0015\u0004MRR\u0004\"B8\u0001\t\u0003\u0001\u0018aB:fiN+W\r\u001a\u000b\u0003\u0017FDQ!\u00148A\u0002I\u0004\"aK:\n\u0005Qd#\u0001\u0002'p]\u001eD3A\u001c\u001b;\u0011\u00159\b\u0001\"\u0001y\u0003E\u0019X\r^%oSRL\u0017\r\\,fS\u001eDGo\u001d\u000b\u0003\u0017fDQ!\u0014<A\u0002IA3A\u001e\u001b`\u0011\u0015a\b\u0001\"\u0001~\u0003-\u0019X\r^*uKB\u001c\u0016N_3\u0015\u0005-s\b\"B'|\u0001\u0004Q\u0007fA>5?\"9\u00111\u0001\u0001\u0005B\u0005\u0015\u0011\u0001B2paf$2\u0001GA\u0004\u0011!\tI!!\u0001A\u0002\u0005-\u0011!B3yiJ\f\u0007\u0003BA\u0007\u0003'i!!a\u0004\u000b\u0007\u0005EA!A\u0003qCJ\fW.\u0003\u0003\u0002\u0016\u0005=!\u0001\u0003)be\u0006lW*\u00199)\t\u0005\u0005AG\u000f\u0005\b\u00037\u0001A\u0011KA\u000f\u0003\u0015!(/Y5o)\rQ\u0012q\u0004\u0005\t\u0003C\tI\u00021\u0001\u0002$\u00059A-\u0019;bg\u0016$\b\u0007BA\u0013\u0003k\u0001b!a\n\u0002.\u0005ERBAA\u0015\u0015\r\tYCB\u0001\u0004gFd\u0017\u0002BA\u0018\u0003S\u0011q\u0001R1uCN,G\u000f\u0005\u0003\u00024\u0005UB\u0002\u0001\u0003\r\u0003o\ty\"!A\u0001\u0002\u000b\u0005\u0011\u0011\b\u0002\u0004?\u0012\n\u0014\u0003BA\u001e\u0003\u0003\u00022aKA\u001f\u0013\r\ty\u0004\f\u0002\b\u001d>$\b.\u001b8h!\rY\u00131I\u0005\u0004\u0003\u000bb#aA!os\"\u001a\u0001\u0001\u000e\u001e\b\u000f\u0005-#\u0001#\u0001\u0002N\u0005qR*\u001e7uS2\f\u00170\u001a:QKJ\u001cW\r\u001d;s_:\u001cE.Y:tS\u001aLWM\u001d\t\u00043\u0005=cAB\u0001\u0003\u0011\u0003\t\tf\u0005\u0005\u0002P\u0005M\u0013\u0011LA0!\rY\u0013QK\u0005\u0004\u0003/b#AB!osJ+g\r\u0005\u0003\"\u00037B\u0012bAA/E\t)B)\u001a4bk2$\b+\u0019:b[N\u0014V-\u00193bE2,\u0007cA\u0016\u0002b%\u0019\u00111\r\u0017\u0003\u0019M+'/[1mSj\f'\r\\3\t\u000f\u0001\u000by\u0005\"\u0001\u
0002hQ\u0011\u0011Q\n\u0005\f\u0003W\nyE1A\u0005\u0002\t\ti'A\u0003M\u0005\u001a;5+\u0006\u0002\u0002pA!\u0011\u0011OA>\u001b\t\t\u0019H\u0003\u0003\u0002v\u0005]\u0014\u0001\u00027b]\u001eT!!!\u001f\u0002\t)\fg/Y\u0005\u0004c\u0005M\u0004\"CA@\u0003\u001f\u0002\u000b\u0011BA8\u0003\u0019a%IR$TA!Y\u00111QA(\u0005\u0004%\tAAA7\u0003\t9E\tC\u0005\u0002\b\u0006=\u0003\u0015!\u0003\u0002p\u0005\u0019q\t\u0012\u0011\t\u0017\u0005-\u0015q\nb\u0001\n\u0003\u0011\u0011QR\u0001\u0011gV\u0004\bo\u001c:uK\u0012\u001cv\u000e\u001c<feN,\"!a$\u0011\t-z\u0015q\u000e\u0005\n\u0003'\u000by\u0005)A\u0005\u0003\u001f\u000b\u0011c];qa>\u0014H/\u001a3T_24XM]:!\u0011!\t9*a\u0014\u0005B\u0005e\u0015\u0001\u00027pC\u0012$2\u0001GAN\u0011\u001d\ti*!&A\u0002%\nA\u0001]1uQ\"\"\u0011Q\u0013\u001b`\u0011)\t\u0019+a\u0014\u0002\u0002\u0013%\u0011QU\u0001\fe\u0016\fGMU3t_24X\r\u0006\u0002\u0002(B!\u0011\u0011OAU\u0013\u0011\tY+a\u001d\u0003\r=\u0013'.Z2uQ\u0011\ty\u0005N0)\t\u0005%Cg\u0018")
/* loaded from: input_file:org/apache/spark/ml/classification/MultilayerPerceptronClassifier.class */
/*
 * Estimator that trains a feed-forward multilayer perceptron classifier.
 *
 * train(Dataset) builds a FeedForwardTopology.multiLayerPerceptron from the configured layers()
 * (the first entry is used as the input size, the last as the number of classes), optimizes its
 * weights with the selected solver() and wraps the learned weights in a
 * MultilayerPerceptronClassificationModel.
 *
 * NOTE(review): this is a Java rendering of a Scala-compiled class; the "..._setter_$..._$eq"
 * methods and "Cclass" delegations appear to be the Scala 2.11 trait encoding and are driven
 * from the constructor / trait initializers.
 */
public class MultilayerPerceptronClassifier extends Predictor<Vector, MultilayerPerceptronClassifier, MultilayerPerceptronClassificationModel> implements MultilayerPerceptronParams, DefaultParamsWritable {
    private final String uid;

    // Not final: each Param field below is assigned through one of the trait setter callbacks
    // in this class (some called directly from the constructor, the rest presumably from the
    // trait initializers it invokes).
    private IntArrayParam layers;
    private IntParam blockSize;
    private Param<String> solver;
    private Param<Vector> initialWeights;
    private DoubleParam stepSize;
    private DoubleParam tol;
    private IntParam maxIter;
    private LongParam seed;

    /** Returns the reader used to load persisted instances (delegates to the companion object). */
    public static MLReader<MultilayerPerceptronClassifier> read() {
        return MultilayerPerceptronClassifier$.MODULE$.read();
    }

    /**
     * Loads a persisted classifier (delegates to the companion object).
     *
     * @param str path the estimator was saved to
     */
    public static MultilayerPerceptronClassifier load(String str) {
        return MultilayerPerceptronClassifier$.MODULE$.load(str);
    }

    /** Returns a writer for this estimator, using the default params-based persistence. */
    @Override // org.apache.spark.ml.util.DefaultParamsWritable, org.apache.spark.ml.util.MLWritable
    public MLWriter write() {
        return DefaultParamsWritable.Cclass.write(this);
    }

    /**
     * Saves this estimator to {@code str}.
     *
     * @throws IOException as declared by {@link MLWritable}
     */
    @Override // org.apache.spark.ml.util.MLWritable
    public void save(String str) throws IOException {
        MLWritable.Cclass.save(this, str);
    }

    // ---- MultilayerPerceptronParams accessors and trait setter hooks ----

    @Override // org.apache.spark.ml.classification.MultilayerPerceptronParams
    public final IntArrayParam layers() {
        return this.layers;
    }

    @Override // org.apache.spark.ml.classification.MultilayerPerceptronParams
    public final IntParam blockSize() {
        return this.blockSize;
    }

    @Override // org.apache.spark.ml.classification.MultilayerPerceptronParams, org.apache.spark.ml.param.shared.HasSolver
    public final Param<String> solver() {
        return this.solver;
    }

    @Override // org.apache.spark.ml.classification.MultilayerPerceptronParams
    public final Param<Vector> initialWeights() {
        return this.initialWeights;
    }

    @Override // org.apache.spark.ml.classification.MultilayerPerceptronParams
    public final void org$apache$spark$ml$classification$MultilayerPerceptronParams$_setter_$layers_$eq(IntArrayParam intArrayParam) {
        this.layers = intArrayParam;
    }

    @Override // org.apache.spark.ml.classification.MultilayerPerceptronParams
    public final void org$apache$spark$ml$classification$MultilayerPerceptronParams$_setter_$blockSize_$eq(IntParam intParam) {
        this.blockSize = intParam;
    }

    @Override // org.apache.spark.ml.classification.MultilayerPerceptronParams
    public final void org$apache$spark$ml$classification$MultilayerPerceptronParams$_setter_$solver_$eq(Param param) {
        this.solver = param;
    }

    @Override // org.apache.spark.ml.classification.MultilayerPerceptronParams
    public final void org$apache$spark$ml$classification$MultilayerPerceptronParams$_setter_$initialWeights_$eq(Param param) {
        this.initialWeights = param;
    }

    @Override // org.apache.spark.ml.classification.MultilayerPerceptronParams
    public final int[] getLayers() {
        return MultilayerPerceptronParams.Cclass.getLayers(this);
    }

    @Override // org.apache.spark.ml.classification.MultilayerPerceptronParams
    public final int getBlockSize() {
        return MultilayerPerceptronParams.Cclass.getBlockSize(this);
    }

    @Override // org.apache.spark.ml.classification.MultilayerPerceptronParams
    public final Vector getInitialWeights() {
        return MultilayerPerceptronParams.Cclass.getInitialWeights(this);
    }

    // ---- HasSolver ----

    /**
     * Deliberately empty: {@link #solver()} is populated through the
     * MultilayerPerceptronParams solver setter, so the Param the constructor passes here is
     * discarded.
     */
    @Override // org.apache.spark.ml.param.shared.HasSolver
    public void org$apache$spark$ml$param$shared$HasSolver$_setter_$solver_$eq(Param param) {
    }

    @Override // org.apache.spark.ml.param.shared.HasSolver
    public final String getSolver() {
        return HasSolver.Cclass.getSolver(this);
    }

    // ---- HasStepSize ----

    @Override // org.apache.spark.ml.param.shared.HasStepSize
    public final DoubleParam stepSize() {
        return this.stepSize;
    }

    @Override // org.apache.spark.ml.param.shared.HasStepSize
    public final void org$apache$spark$ml$param$shared$HasStepSize$_setter_$stepSize_$eq(DoubleParam doubleParam) {
        this.stepSize = doubleParam;
    }

    @Override // org.apache.spark.ml.param.shared.HasStepSize
    public final double getStepSize() {
        return HasStepSize.Cclass.getStepSize(this);
    }

    // ---- HasTol ----

    @Override // org.apache.spark.ml.param.shared.HasTol
    public final DoubleParam tol() {
        return this.tol;
    }

    @Override // org.apache.spark.ml.param.shared.HasTol
    public final void org$apache$spark$ml$param$shared$HasTol$_setter_$tol_$eq(DoubleParam doubleParam) {
        this.tol = doubleParam;
    }

    @Override // org.apache.spark.ml.param.shared.HasTol
    public final double getTol() {
        return HasTol.Cclass.getTol(this);
    }

    // ---- HasMaxIter ----

    @Override // org.apache.spark.ml.param.shared.HasMaxIter
    public final IntParam maxIter() {
        return this.maxIter;
    }

    @Override // org.apache.spark.ml.param.shared.HasMaxIter
    public final void org$apache$spark$ml$param$shared$HasMaxIter$_setter_$maxIter_$eq(IntParam intParam) {
        this.maxIter = intParam;
    }

    @Override // org.apache.spark.ml.param.shared.HasMaxIter
    public final int getMaxIter() {
        return HasMaxIter.Cclass.getMaxIter(this);
    }

    // ---- HasSeed ----

    @Override // org.apache.spark.ml.param.shared.HasSeed
    public final LongParam seed() {
        return this.seed;
    }

    @Override // org.apache.spark.ml.param.shared.HasSeed
    public final void org$apache$spark$ml$param$shared$HasSeed$_setter_$seed_$eq(LongParam longParam) {
        this.seed = longParam;
    }

    @Override // org.apache.spark.ml.param.shared.HasSeed
    public final long getSeed() {
        return HasSeed.Cclass.getSeed(this);
    }

    /** Returns the unique identifier this estimator was constructed with. */
    @Override // org.apache.spark.ml.util.Identifiable
    public String uid() {
        return this.uid;
    }

    // ---- Fluent setters: each stores the value via Params.set and returns the result cast to
    // ---- this type so calls can be chained.

    /** Sets {@link #layers()}: layer sizes from input (first) to output/classes (last). */
    public MultilayerPerceptronClassifier setLayers(int[] iArr) {
        return (MultilayerPerceptronClassifier) set(layers(), iArr);
    }

    /** Sets {@link #blockSize()}, which {@link #train(Dataset)} passes to the trainer as its stack size. */
    public MultilayerPerceptronClassifier setBlockSize(int i) {
        return (MultilayerPerceptronClassifier) set(blockSize(), BoxesRunTime.boxToInteger(i));
    }

    /** Sets {@link #solver()}; {@link #train(Dataset)} accepts only the LBFGS or GD solver names. */
    public MultilayerPerceptronClassifier setSolver(String str) {
        return (MultilayerPerceptronClassifier) set(solver(), str);
    }

    /** Sets {@link #maxIter()}, the optimizer's iteration count. */
    public MultilayerPerceptronClassifier setMaxIter(int i) {
        return (MultilayerPerceptronClassifier) set(maxIter(), BoxesRunTime.boxToInteger(i));
    }

    /** Sets {@link #tol()}, the optimizer's convergence tolerance. */
    public MultilayerPerceptronClassifier setTol(double d) {
        return (MultilayerPerceptronClassifier) set(tol(), BoxesRunTime.boxToDouble(d));
    }

    /** Sets {@link #seed()}, used to initialize weights when no initial weights are defined. */
    public MultilayerPerceptronClassifier setSeed(long j) {
        return (MultilayerPerceptronClassifier) set(seed(), BoxesRunTime.boxToLong(j));
    }

    /** Sets {@link #initialWeights()}; when defined these take precedence over {@link #seed()}. */
    public MultilayerPerceptronClassifier setInitialWeights(Vector vector) {
        return (MultilayerPerceptronClassifier) set(initialWeights(), vector);
    }

    /** Sets {@link #stepSize()}, used only by the GD solver. */
    public MultilayerPerceptronClassifier setStepSize(double d) {
        return (MultilayerPerceptronClassifier) set(stepSize(), BoxesRunTime.boxToDouble(d));
    }

    /** Returns a copy of this estimator with {@code paramMap} applied on top of its params. */
    @Override // org.apache.spark.ml.Predictor, org.apache.spark.ml.Estimator, org.apache.spark.ml.PipelineStage, org.apache.spark.ml.param.Params
    public MultilayerPerceptronClassifier copy(ParamMap paramMap) {
        return (MultilayerPerceptronClassifier) defaultCopy(paramMap);
    }

    /**
     * Fits a multilayer perceptron on {@code dataset}.
     *
     * <p>The input layer size is the first entry of {@link #layers()} and the number of classes
     * is its last entry. If {@link #initialWeights()} is defined those weights are used as the
     * starting point; otherwise the trainer is seeded with {@link #seed()}. {@link #blockSize()}
     * is passed to the trainer as its stack size.
     *
     * @param dataset training data holding the label and features columns
     * @return the trained model, created with this estimator's uid and layer sizes
     * @throws IllegalArgumentException if {@link #solver()} is neither the LBFGS nor the GD solver
     */
    @Override // org.apache.spark.ml.Predictor
    public MultilayerPerceptronClassificationModel train(Dataset<?> dataset) {
        Instrumentation instr = Instrumentation$.MODULE$.create(this, dataset);
        instr.logParams(Predef$.MODULE$.wrapRefArray(new Param<?>[]{
            labelCol(), featuresCol(), predictionCol(), layers(), maxIter(), tol(),
            blockSize(), solver(), stepSize(), seed()}));

        int[] layerSizes = $(layers());
        int numClasses = BoxesRunTime.unboxToInt(Predef$.MODULE$.intArrayOps(layerSizes).last());
        instr.logNumClasses(numClasses);
        int numFeatures = BoxesRunTime.unboxToInt(Predef$.MODULE$.intArrayOps(layerSizes).head());
        instr.logNumFeatures(numFeatures);

        // Label encoding lives in the Scala-generated closure; it is given numClasses, presumably
        // to build a one-hot target vector per label — confirm against the Scala source.
        RDD<Tuple2<Vector, Vector>> data = extractLabeledPoints(dataset).map(
            new MultilayerPerceptronClassifier$$anonfun$3(this, numClasses),
            ClassTag$.MODULE$.apply(Tuple2.class));

        FeedForwardTrainer trainer = new FeedForwardTrainer(
            FeedForwardTopology$.MODULE$.multiLayerPerceptron(layerSizes, true),
            numFeatures,
            numClasses);
        if (isDefined(initialWeights())) {
            trainer.setWeights($(initialWeights()));
        } else {
            trainer.setSeed(BoxesRunTime.unboxToLong($(seed())));
        }

        String solverName = $(solver());
        if (Objects.equals(solverName, MultilayerPerceptronClassifier$.MODULE$.LBFGS())) {
            trainer.LBFGSOptimizer()
                .setConvergenceTol(BoxesRunTime.unboxToDouble($(tol())))
                .setNumIterations(BoxesRunTime.unboxToInt($(maxIter())));
        } else if (Objects.equals(solverName, MultilayerPerceptronClassifier$.MODULE$.GD())) {
            trainer.SGDOptimizer()
                .setNumIterations(BoxesRunTime.unboxToInt($(maxIter())))
                .setConvergenceTol(BoxesRunTime.unboxToDouble($(tol())))
                .setStepSize(BoxesRunTime.unboxToDouble($(stepSize())));
        } else {
            // Report the configured value; interpolating solver() would print the Param itself.
            throw new IllegalArgumentException(
                "The solver " + solverName + " is not supported by MultilayerPerceptronClassifier.");
        }
        trainer.setStackSize(BoxesRunTime.unboxToInt($(blockSize())));

        MultilayerPerceptronClassificationModel model = new MultilayerPerceptronClassificationModel(
            uid(), layerSizes, trainer.train(data).weights());
        instr.logSuccess(model);
        return model;
    }

    /**
     * Creates a classifier with the given unique identifier and registers its params.
     *
     * @param str unique identifier for this estimator
     */
    public MultilayerPerceptronClassifier(String str) {
        // Assigned first: the Params built below take this instance as their parent and
        // presumably read its uid.
        this.uid = str;
        HasSeed.Cclass.$init$(this);
        org$apache$spark$ml$param$shared$HasMaxIter$_setter_$maxIter_$eq(new IntParam(this, "maxIter", "maximum number of iterations (>= 0)", (Function1<Object, Object>) ParamValidators$.MODULE$.gtEq(0.0d)));
        org$apache$spark$ml$param$shared$HasTol$_setter_$tol_$eq(new DoubleParam(this, "tol", "the convergence tolerance for iterative algorithms (>= 0)", (Function1<Object, Object>) ParamValidators$.MODULE$.gtEq(0.0d)));
        org$apache$spark$ml$param$shared$HasStepSize$_setter_$stepSize_$eq(new DoubleParam(this, "stepSize", "Step size to be used for each iteration of optimization (> 0)", (Function1<Object, Object>) ParamValidators$.MODULE$.gt(0.0d)));
        // The HasSolver setter is a no-op in this class; solver is assigned by the
        // MultilayerPerceptronParams initializer that follows.
        org$apache$spark$ml$param$shared$HasSolver$_setter_$solver_$eq(new Param<String>(this, "solver", "the solver algorithm for optimization"));
        MultilayerPerceptronParams.Cclass.$init$(this);
        MLWritable.Cclass.$init$(this);
        DefaultParamsWritable.Cclass.$init$(this);
    }

    /** Creates a classifier with a random uid prefixed with {@code "mlpc"}. */
    public MultilayerPerceptronClassifier() {
        this(Identifiable$.MODULE$.randomUID("mlpc"));
    }
}
