/*
 * Decompiled with CFR 0.152.
 */
package org.apache.spark.ml.optim.aggregator;

import java.io.Serializable;
import java.lang.invoke.LambdaMetafactory;
import org.apache.spark.broadcast.Broadcast;
import org.apache.spark.internal.Logging;
import org.apache.spark.ml.feature.InstanceBlock;
import org.apache.spark.ml.impl.Utils$;
import org.apache.spark.ml.linalg.BLAS$;
import org.apache.spark.ml.linalg.DenseMatrix;
import org.apache.spark.ml.linalg.DenseMatrix$;
import org.apache.spark.ml.linalg.DenseVector;
import org.apache.spark.ml.linalg.DenseVector$;
import org.apache.spark.ml.linalg.Matrices$;
import org.apache.spark.ml.linalg.Matrix;
import org.apache.spark.ml.linalg.SparseMatrix;
import org.apache.spark.ml.linalg.Vector;
import org.apache.spark.ml.linalg.Vectors$;
import org.apache.spark.ml.optim.aggregator.DifferentiableLossAggregator;
import org.slf4j.Logger;
import scala.Array$;
import scala.Function0;
import scala.Function1;
import scala.Function3;
import scala.MatchError;
import scala.Option;
import scala.Predef$;
import scala.Tuple2;
import scala.collection.mutable.ArrayOps;
import scala.math.Numeric;
import scala.math.package$;
import scala.reflect.ClassTag$;
import scala.reflect.ScalaSignature;
import scala.runtime.BoxedUnit;
import scala.runtime.BoxesRunTime;
import scala.runtime.java8.JFunction0;
import scala.runtime.java8.JFunction1;

@ScalaSignature(bytes="\u0006\u0001\u0005\u0015a!B\n\u0015\u0001a\u0001\u0003\u0002\u0003\u001d\u0001\u0005\u0003\u0005\u000b\u0011\u0002\u001e\t\u0011u\u0002!\u0011!Q\u0001\niB\u0001B\u0010\u0001\u0003\u0002\u0003\u0006Ia\u0010\u0005\t\u0005\u0002\u0011\t\u0011)A\u0005\u007f!A1\t\u0001B\u0001B\u0003%A\tC\u0003Q\u0001\u0011\u0005\u0011\u000bC\u0004Y\u0001\t\u0007I\u0011B-\t\ri\u0003\u0001\u0015!\u0003;\u0011\u001dY\u0006A1A\u0005\neCa\u0001\u0018\u0001!\u0002\u0013Q\u0004bB/\u0001\u0005\u0004%\t&\u0017\u0005\u0007=\u0002\u0001\u000b\u0011\u0002\u001e\t\u0011}\u0003\u0001R1A\u0005\n\u0001D\u0001b\u001b\u0001\t\u0006\u0004%I\u0001\u001c\u0005\t]\u0002A)\u0019!C\u0005_\")A\u000f\u0001C\u0001k\")\u0011\u0010\u0001C\u0005u\"1q\u0010\u0001C\u0005\u0003\u0003\u0011qC\u00117pG.dunZ5ti&\u001c\u0017iZ4sK\u001e\fGo\u001c:\u000b\u0005U1\u0012AC1hOJ,w-\u0019;pe*\u0011q\u0003G\u0001\u0006_B$\u0018.\u001c\u0006\u00033i\t!!\u001c7\u000b\u0005ma\u0012!B:qCJ\\'BA\u000f\u001f\u0003\u0019\t\u0007/Y2iK*\tq$A\u0002pe\u001e\u001cB\u0001A\u0011(eA\u0011!%J\u0007\u0002G)\tA%A\u0003tG\u0006d\u0017-\u0003\u0002'G\t1\u0011I\\=SK\u001a\u0004B\u0001K\u0015,c5\tA#\u0003\u0002+)\taB)\u001b4gKJ,g\u000e^5bE2,Gj\\:t\u0003\u001e<'/Z4bi>\u0014\bC\u0001\u00170\u001b\u0005i#B\u0001\u0018\u0019\u0003\u001d1W-\u0019;ve\u0016L!\u0001M\u0017\u0003\u001b%s7\u000f^1oG\u0016\u0014En\\2l!\tA\u0003\u0001\u0005\u00024m5\tAG\u0003\u000265\u0005A\u0011N\u001c;fe:\fG.\u0003\u00028i\t9Aj\\4hS:<\u0017a\u00038v[\u001a+\u0017\r^;sKN\u001c\u0001\u0001\u0005\u0002#w%\u0011Ah\t\u0002\u0004\u0013:$\u0018A\u00038v[\u000ec\u0017m]:fg\u0006aa-\u001b;J]R,'oY3qiB\u0011!\u0005Q\u0005\u0003\u0003\u000e\u0012qAQ8pY\u0016\fg.A\u0006nk2$\u0018N\\8nS\u0006d\u0017A\u00042d\u0007>,gMZ5dS\u0016tGo\u001d\t\u0004\u000b\"SU\"\u0001$\u000b\u0005\u001dS\u0012!\u00032s_\u0006$7-Y:u\u0013\tIeIA\u0005Ce>\fGmY1tiB\u00111JT\u0007\u0002\u0019*\u0011Q\nG\u0001\u0007Y&t\u0017\r\\4\n\u0005=c%A\u0002,fGR|'/\u0001\u0004=S:LGO\u0010\u000b\u0006%R+fk\u0016
\u000b\u0003cMCQa\u0011\u0004A\u0002\u0011CQ\u0001\u000f\u0004A\u0002iBQ!\u0010\u0004A\u0002iBQA\u0010\u0004A\u0002}BQA\u0011\u0004A\u0002}\n\u0001D\\;n\r\u0016\fG/\u001e:fgBcWo]%oi\u0016\u00148-\u001a9u+\u0005Q\u0014!\u00078v[\u001a+\u0017\r^;sKN\u0004F.^:J]R,'oY3qi\u0002\nqbY8fM\u001aL7-[3oiNK'0Z\u0001\u0011G>,gMZ5dS\u0016tGoU5{K\u0002\n1\u0001Z5n\u0003\u0011!\u0017.\u001c\u0011\u0002#\r|WM\u001a4jG&,g\u000e^:BeJ\f\u00170F\u0001b!\r\u0011#\rZ\u0005\u0003G\u000e\u0012Q!\u0011:sCf\u0004\"AI3\n\u0005\u0019\u001c#A\u0002#pk\ndW\r\u000b\u0002\u000eQB\u0011!%[\u0005\u0003U\u000e\u0012\u0011\u0002\u001e:b]NLWM\u001c;\u0002\u0019\tLg.\u0019:z\u0019&tW-\u0019:\u0016\u0003)C#A\u00045\u0002#5,H\u000e^5o_6L\u0017\r\u001c'j]\u0016\f'/F\u0001q!\tY\u0015/\u0003\u0002s\u0019\nYA)\u001a8tK6\u000bGO]5yQ\ty\u0001.A\u0002bI\u0012$\"A^<\u000e\u0003\u0001AQ\u0001\u001f\tA\u0002-\nQA\u00197pG.\f1CY5oCJLX\u000b\u001d3bi\u0016Le\u000e\u00157bG\u0016$\"a\u001f@\u0011\u0005\tb\u0018BA?$\u0005\u0011)f.\u001b;\t\u000ba\f\u0002\u0019A\u0016\u000215,H\u000e^5o_6L\u0017\r\\+qI\u0006$X-\u00138QY\u0006\u001cW\rF\u0002|\u0003\u0007AQ\u0001\u001f\nA\u0002-\u0002")
/**
 * Aggregates the logistic loss and its gradient over {@link InstanceBlock}s, for either binary
 * or multinomial (softmax) logistic regression, using coefficients read from a broadcast vector.
 *
 * <p>This file was produced by the CFR decompiler from Scala bytecode (see the file header and
 * the ScalaSignature annotation). Several method bodies contain decompiler artifacts that are
 * not valid Java as written: untyped locals such as {@code var2_1}, {@code ** GOTO lbl-1000}
 * jumps, lambdas that reference an undefined {@code $this}, raw
 * {@code LambdaMetafactory.altMetafactory(...)} calls and {@code (double)false}. Read those
 * bodies as an approximation of the bytecode, not as compilable source.
 *
 * <p>Coefficient layout, as read by the code below:
 * <ul>
 *   <li>binary: {@code numFeatures} linear coefficients, followed by a single intercept (the
 *       last element) when {@code fitIntercept} is true;</li>
 *   <li>multinomial: {@code numClasses * numFeatures} linear coefficients, followed by
 *       {@code numClasses} intercepts (starting at offset {@code numClasses * numFeatures})
 *       when {@code fitIntercept} is true.</li>
 * </ul>
 */
public class BlockLogisticAggregator
implements DifferentiableLossAggregator<InstanceBlock, BlockLogisticAggregator>,
Logging {
    // Transient caches derived lazily from bcCoefficients. Bits 1, 2 and 4 of bitmap$trans$0
    // record whether coefficientsArray, binaryLinear and multinomialLinear (respectively)
    // have been computed.
    private transient double[] coefficientsArray;
    private transient Vector binaryLinear;
    private transient DenseMatrix multinomialLinear;
    private final int numFeatures;
    private final int numClasses;
    private final boolean fitIntercept;
    // true routes add() to multinomialUpdateInPlace; false routes it to binaryUpdateInPlace.
    private final boolean multinomial;
    // Broadcast coefficient vector; coefficientsArray() requires it to be a DenseVector.
    private final Broadcast<Vector> bcCoefficients;
    // numFeatures + 1 when fitIntercept is true, otherwise numFeatures (set in the constructor).
    private final int numFeaturesPlusIntercept;
    // Size of the broadcast coefficient vector.
    private final int coefficientSize;
    // Value returned by dim(); set equal to coefficientSize in the constructor.
    private final int dim;
    // Backing field for the Logging trait's logger.
    private transient Logger org$apache$spark$internal$Logging$$log_;
    // Running totals that add() accumulates into; exposed through weightSum(), lossSum()
    // and gradientSumArray().
    private double weightSum;
    private double lossSum;
    private double[] gradientSumArray;
    // Initialization flags for the lazy fields: bitmap$trans$0 for the transient caches above,
    // bitmap$0 for gradientSumArray.
    private volatile transient byte bitmap$trans$0;
    private volatile boolean bitmap$0;

    // ---- Logging trait forwarders: each method delegates to the matching static
    // ---- implementation on Logging, passing this instance.
    public String logName() {
        return Logging.logName$((Logging)this);
    }

    public Logger log() {
        return Logging.log$((Logging)this);
    }

    public void logInfo(Function0<String> msg) {
        Logging.logInfo$((Logging)this, msg);
    }

    public void logDebug(Function0<String> msg) {
        Logging.logDebug$((Logging)this, msg);
    }

    public void logTrace(Function0<String> msg) {
        Logging.logTrace$((Logging)this, msg);
    }

    public void logWarning(Function0<String> msg) {
        Logging.logWarning$((Logging)this, msg);
    }

    public void logError(Function0<String> msg) {
        Logging.logError$((Logging)this, msg);
    }

    public void logInfo(Function0<String> msg, Throwable throwable) {
        Logging.logInfo$((Logging)this, msg, (Throwable)throwable);
    }

    public void logDebug(Function0<String> msg, Throwable throwable) {
        Logging.logDebug$((Logging)this, msg, (Throwable)throwable);
    }

    public void logTrace(Function0<String> msg, Throwable throwable) {
        Logging.logTrace$((Logging)this, msg, (Throwable)throwable);
    }

    public void logWarning(Function0<String> msg, Throwable throwable) {
        Logging.logWarning$((Logging)this, msg, (Throwable)throwable);
    }

    public void logError(Function0<String> msg, Throwable throwable) {
        Logging.logError$((Logging)this, msg, (Throwable)throwable);
    }

    public boolean isTraceEnabled() {
        return Logging.isTraceEnabled$((Logging)this);
    }

    public void initializeLogIfNecessary(boolean isInterpreter) {
        Logging.initializeLogIfNecessary$((Logging)this, (boolean)isInterpreter);
    }

    public boolean initializeLogIfNecessary(boolean isInterpreter, boolean silent) {
        return Logging.initializeLogIfNecessary$((Logging)this, (boolean)isInterpreter, (boolean)silent);
    }

    public boolean initializeLogIfNecessary$default$2() {
        return Logging.initializeLogIfNecessary$default$2$((Logging)this);
    }

    public void initializeForcefully(boolean isInterpreter, boolean silent) {
        Logging.initializeForcefully$((Logging)this, (boolean)isInterpreter, (boolean)silent);
    }

    // ---- DifferentiableLossAggregator forwarders: delegate to the interface's static
    // ---- implementations, passing this instance.
    @Override
    public DifferentiableLossAggregator merge(DifferentiableLossAggregator other) {
        return DifferentiableLossAggregator.merge$(this, other);
    }

    @Override
    public Vector gradient() {
        return DifferentiableLossAggregator.gradient$(this);
    }

    @Override
    public double weight() {
        return DifferentiableLossAggregator.weight$(this);
    }

    @Override
    public double loss() {
        return DifferentiableLossAggregator.loss$(this);
    }

    // Getter/setter pair backing the Logging trait's logger field.
    public Logger org$apache$spark$internal$Logging$$log_() {
        return this.org$apache$spark$internal$Logging$$log_;
    }

    public void org$apache$spark$internal$Logging$$log__$eq(Logger x$1) {
        this.org$apache$spark$internal$Logging$$log_ = x$1;
    }

    // Accessors for the running weight and loss totals.
    @Override
    public double weightSum() {
        return this.weightSum;
    }

    @Override
    public void weightSum_$eq(double x$1) {
        this.weightSum = x$1;
    }

    @Override
    public double lossSum() {
        return this.lossSum;
    }

    @Override
    public void lossSum_$eq(double x$1) {
        this.lossSum = x$1;
    }

    /**
     * Initializes gradientSumArray on first use, while holding this instance's monitor.
     * The initial value comes from DifferentiableLossAggregator.gradientSumArray$ (its
     * length is decided there, not in this file). The volatile bitmap$0 flag is re-checked
     * inside the lock.
     */
    private double[] gradientSumArray$lzycompute() {
        BlockLogisticAggregator blockLogisticAggregator = this;
        synchronized (blockLogisticAggregator) {
            if (!this.bitmap$0) {
                this.gradientSumArray = DifferentiableLossAggregator.gradientSumArray$(this);
                this.bitmap$0 = true;
            }
        }
        return this.gradientSumArray;
    }

    /** Returns the gradient accumulator, initializing it lazily on first access. */
    @Override
    public double[] gradientSumArray() {
        return !this.bitmap$0 ? this.gradientSumArray$lzycompute() : this.gradientSumArray;
    }

    private int numFeaturesPlusIntercept() {
        return this.numFeaturesPlusIntercept;
    }

    private int coefficientSize() {
        return this.coefficientSize;
    }

    @Override
    public int dim() {
        return this.dim;
    }

    /**
     * Extracts the raw values array of the broadcast coefficient vector, under this instance's
     * monitor, and caches it (bit 1 of bitmap$trans$0).
     *
     * @throws IllegalArgumentException if the broadcast vector is not a DenseVector
     */
    private double[] coefficientsArray$lzycompute() {
        BlockLogisticAggregator blockLogisticAggregator = this;
        synchronized (blockLogisticAggregator) {
            if ((byte)(this.bitmap$trans$0 & 1) == 0) {
                double[] values;
                DenseVector denseVector;
                Option option;
                Vector vector = (Vector)this.bcCoefficients.value();
                if (!(vector instanceof DenseVector) || (option = DenseVector$.MODULE$.unapply(denseVector = (DenseVector)vector)).isEmpty()) {
                    // NOTE(review): the message ends with ".)", which looks like a stray parenthesis.
                    throw new IllegalArgumentException(new StringBuilder(55).append("coefficients only supports dense vector but ").append("got type ").append(this.bcCoefficients.value().getClass()).append(".)").toString());
                }
                double[] dArray = values = (double[])option.get();
                this.coefficientsArray = dArray;
                this.bitmap$trans$0 = (byte)(this.bitmap$trans$0 | 1);
            }
        }
        return this.coefficientsArray;
    }

    /** Returns the cached coefficient values, computing them on first access. */
    private double[] coefficientsArray() {
        return (byte)(this.bitmap$trans$0 & 1) == 0 ? this.coefficientsArray$lzycompute() : this.coefficientsArray;
    }

    /**
     * Lazily builds the linear part of the binary coefficients (bit 2 of bitmap$trans$0).
     * <ul>
     *   <li>multinomial=false, fitIntercept=true: the first numFeatures coefficients (the
     *       trailing intercept is dropped);</li>
     *   <li>multinomial=false, fitIntercept=false: all coefficients;</li>
     *   <li>multinomial=true: null (unused on that path).</li>
     * </ul>
     * The body below is partially unstructured decompiler output.
     */
    /*
     * Unable to fully structure code
     */
    private Vector binaryLinear$lzycompute() {
        block5: {
            var2_1 = this;
            synchronized (var2_1) {
                block7: {
                    block6: {
                        if ((byte)(this.bitmap$trans$0 & 2) != 0) break block5;
                        // Pattern match on the tuple (multinomial, fitIntercept).
                        var3_2 = new Tuple2.mcZZ.sp(this.multinomial, this.fitIntercept);
                        if (var3_2 == null) break block6;
                        var4_3 = var3_2._1$mcZ$sp();
                        var5_4 = var3_2._2$mcZ$sp();
                        if (var4_3 || !var5_4) break block6;
                        // (false, true): keep only the numFeatures linear coefficients.
                        var1_5 = Vectors$.MODULE$.dense((double[])new ArrayOps.ofDouble(Predef$.MODULE$.doubleArrayOps(this.coefficientsArray())).take(this.numFeatures));
                        break block7;
                    }
                    if (var3_2 == null) ** GOTO lbl-1000
                    var6_6 = var3_2._1$mcZ$sp();
                    var7_7 = var3_2._2$mcZ$sp();
                    // (false, false): every coefficient is linear.
                    if (!var6_6 && !var7_7) {
                        var1_5 = Vectors$.MODULE$.dense(this.coefficientsArray());
                    } else lbl-1000:
                    // 2 sources

                    {
                        var1_5 = null;
                    }
                }
                this.binaryLinear = var1_5;
                this.bitmap$trans$0 = (byte)(this.bitmap$trans$0 | 2);
            }
        }
        return this.binaryLinear;
    }

    /** Returns the cached binary linear coefficients, computing them on first access. */
    private Vector binaryLinear() {
        return (byte)(this.bitmap$trans$0 & 2) == 0 ? this.binaryLinear$lzycompute() : this.binaryLinear;
    }

    /**
     * Lazily builds the linear part of the multinomial coefficients as a
     * numClasses x numFeatures matrix (bit 4 of bitmap$trans$0).
     * <ul>
     *   <li>multinomial=true, fitIntercept=true: built from the first
     *       numClasses * numFeatures coefficients (the intercepts are dropped);</li>
     *   <li>multinomial=true, fitIntercept=false: built from all coefficients;</li>
     *   <li>multinomial=false: null (unused on that path).</li>
     * </ul>
     * The element order follows Matrices.dense's array layout, which is not shown here.
     * The body below is partially unstructured decompiler output.
     */
    /*
     * Unable to fully structure code
     */
    private DenseMatrix multinomialLinear$lzycompute() {
        block5: {
            var2_1 = this;
            synchronized (var2_1) {
                block7: {
                    block6: {
                        if ((byte)(this.bitmap$trans$0 & 4) != 0) break block5;
                        // Pattern match on the tuple (multinomial, fitIntercept).
                        var3_2 = new Tuple2.mcZZ.sp(this.multinomial, this.fitIntercept);
                        if (var3_2 == null) break block6;
                        var4_3 = var3_2._1$mcZ$sp();
                        var5_4 = var3_2._2$mcZ$sp();
                        if (!var4_3 || !var5_4) break block6;
                        // (true, true): exclude the trailing numClasses intercepts.
                        var1_5 = Matrices$.MODULE$.dense(this.numClasses, this.numFeatures, (double[])new ArrayOps.ofDouble(Predef$.MODULE$.doubleArrayOps(this.coefficientsArray())).take(this.numClasses * this.numFeatures)).toDense();
                        break block7;
                    }
                    if (var3_2 == null) ** GOTO lbl-1000
                    var6_6 = var3_2._1$mcZ$sp();
                    var7_7 = var3_2._2$mcZ$sp();
                    // (true, false): every coefficient is linear.
                    if (var6_6 && !var7_7) {
                        var1_5 = Matrices$.MODULE$.dense(this.numClasses, this.numFeatures, this.coefficientsArray()).toDense();
                    } else lbl-1000:
                    // 2 sources

                    {
                        var1_5 = null;
                    }
                }
                this.multinomialLinear = var1_5;
                this.bitmap$trans$0 = (byte)(this.bitmap$trans$0 | 4);
            }
        }
        return this.multinomialLinear;
    }

    /** Returns the cached multinomial linear coefficients, computing them on first access. */
    private DenseMatrix multinomialLinear() {
        return (byte)(this.bitmap$trans$0 & 4) == 0 ? this.multinomialLinear$lzycompute() : this.multinomialLinear;
    }

    /**
     * Adds one block of instances to this aggregator.
     *
     * <p>Requires that the block matrix is transposed, that its feature count equals
     * numFeatures, and that every weight is non-negative. The checks use Predef.require.
     * If every weight in the block is zero, the block is skipped: weightSum, lossSum and
     * gradientSumArray are not modified.
     *
     * @param block the instances to add
     * @return this aggregator
     */
    @Override
    public BlockLogisticAggregator add(InstanceBlock block) {
        Predef$.MODULE$.require(block.matrix().isTransposed());
        Predef$.MODULE$.require(this.numFeatures == block.numFeatures(), (Function0 & Serializable & scala.Serializable)() -> new StringBuilder(66).append("Dimensions mismatch when adding new ").append("instance. Expecting ").append($this.numFeatures).append(" but got ").append(block.numFeatures()).append(".").toString());
        Predef$.MODULE$.require(block.weightIter().forall((Function1)(JFunction1.mcZD.sp & Serializable & scala.Serializable)x$1 -> x$1 >= 0.0), (Function0 & Serializable & scala.Serializable)() -> new StringBuilder(34).append("instance weights ").append(block.weightIter().mkString("[", ",", "]")).append(" has to be >= 0.0").toString());
        if (block.weightIter().forall((Function1)(JFunction1.mcZD.sp & Serializable & scala.Serializable)x$2 -> x$2 == 0.0)) {
            return this;
        }
        if (this.multinomial) {
            this.multinomialUpdateInPlace(block);
        } else {
            this.binaryUpdateInPlace(block);
        }
        return this;
    }

    /**
     * Binary logistic update for one block. Adds this block's loss, total weight and gradient
     * to lossSum, weightSum and gradientSumArray.
     *
     * <p>Steps, in the order performed below:
     * <ol>
     *   <li>{@code vec} starts as the intercept (the last coefficient) repeated once per row
     *       when fitIntercept is true, otherwise zeros. The gemv call then sets
     *       {@code vec = -(X * binaryLinear) - vec}, i.e. the negated margin of every row.</li>
     *   <li>For each row with positive weight:
     *       <ul>
     *         <li>the loss adds {@code weight * log1pExp(margin)} when label &gt; 0, otherwise
     *             {@code weight * (log1pExp(margin) - margin)};</li>
     *         <li>{@code vec[i]} is overwritten with the multiplier
     *             {@code weight * (1 / (1 + exp(margin)) - label)}.</li>
     *       </ul>
     *       Rows with zero weight get multiplier 0. The formulas presumably assume labels in
     *       {0, 1}; confirm against the callers.</li>
     *   <li>lossSum and weightSum are updated. If every multiplier is 0, the method then
     *       returns without touching gradientSumArray.</li>
     *   <li>Otherwise {@code X^T * vec} is added to the first numFeatures entries of
     *       gradientSumArray. When fitIntercept is true, the sum of {@code vec} is also added
     *       to entry numFeatures.</li>
     * </ol>
     *
     * @param block the instance block (add() has already required its matrix to be transposed)
     */
    /*
     * Unable to fully structure code
     */
    private void binaryUpdateInPlace(InstanceBlock block) {
        block9: {
            block8: {
                block7: {
                    size = block.size();
                    // Per-row intercept (last coefficient) when fitting an intercept, else zeros.
                    vec = this.fitIntercept != false ? Vectors$.MODULE$.dense((double[])Array$.MODULE$.fill(size, (Function0)(JFunction0.mcD.sp & Serializable & scala.Serializable)LambdaMetafactory.altMetafactory(null, null, null, ()D, $anonfun$binaryUpdateInPlace$3(org.apache.spark.ml.optim.aggregator.BlockLogisticAggregator ), ()D)((BlockLogisticAggregator)this), ClassTag$.MODULE$.Double())).toDense() : Vectors$.MODULE$.zeros(size).toDense();
                    // vec := -1.0 * X * binaryLinear + (-1.0) * vec  ->  negated margins.
                    BLAS$.MODULE$.gemv(-1.0, block.matrix(), this.binaryLinear(), -1.0, vec);
                    localLossSum = 0.0;
                    for (i = 0; i < size; ++i) {
                        weight = block.getWeight().apply$mcDI$sp(i);
                        // "(double)false" is the decompiler's rendering of a zero constant (presumably 0.0).
                        if (weight > (double)false) {
                            label = block.getLabel(i);
                            margin = vec.apply(i);
                            localLossSum = label > (double)false ? (localLossSum += weight * Utils$.MODULE$.log1pExp(margin)) : (localLossSum += weight * (Utils$.MODULE$.log1pExp(margin) - margin));
                            // Reuse vec to hold the per-row gradient multiplier.
                            vec.values()[i] = multiplier = weight * (1.0 / (1.0 + package$.MODULE$.exp(margin)) - label);
                            continue;
                        }
                        vec.values()[i] = 0.0;
                    }
                    this.lossSum_$eq(this.lossSum() + localLossSum);
                    this.weightSum_$eq(this.weightSum() + BoxesRunTime.unboxToDouble((Object)block.weightIter().sum((Numeric)Numeric.DoubleIsFractional$.MODULE$)));
                    // All multipliers zero: the gradient contribution is zero, so skip it.
                    if (new ArrayOps.ofDouble(Predef$.MODULE$.doubleArrayOps(vec.values())).forall((Function1)(JFunction1.mcZD.sp & Serializable & scala.Serializable)LambdaMetafactory.altMetafactory(null, null, null, (D)Z, $anonfun$binaryUpdateInPlace$4(double ), (D)Z)())) {
                        return;
                    }
                    var16_10 = false;
                    var17_11 = null;
                    var18_12 = block.matrix();
                    if (!(var18_12 instanceof DenseMatrix)) break block7;
                    var19_13 = (DenseMatrix)var18_12;
                    // dgemv("N") reads the values as a numCols x numRows column-major array
                    // (lda = numCols). Given the transposed storage required in add(), this is
                    // expected to accumulate X^T * vec into the first numFeatures entries of
                    // gradientSumArray (beta = 1.0). This relies on DenseMatrix's isTransposed
                    // layout.
                    BLAS$.MODULE$.nativeBLAS().dgemv("N", var19_13.numCols(), var19_13.numRows(), 1.0, var19_13.values(), var19_13.numCols(), vec.values(), 1, 1.0, this.gradientSumArray(), 1);
                    var2_14 = BoxedUnit.UNIT;
                    break block8;
                }
                // NOTE(decompiler): the label lbl-1000 is not emitted in this method. Presumably
                // it leads to the IllegalArgumentException below for matrices that are neither
                // dense nor sparse.
                if (!(var18_12 instanceof SparseMatrix)) ** GOTO lbl-1000
                var16_10 = true;
                var17_11 = (SparseMatrix)var18_12;
                if (this.fitIntercept) {
                    // gradientSumArray also holds the intercept slot at index numFeatures (see the
                    // final statement). So X^T * vec goes into a numFeatures-long temporary
                    // first, then is added with daxpy.
                    linearGradSumVec = Vectors$.MODULE$.zeros(this.numFeatures).toDense();
                    BLAS$.MODULE$.gemv(1.0, (Matrix)var17_11.transpose(), (Vector)vec, 0.0, linearGradSumVec);
                    BLAS$.MODULE$.getBLAS(this.numFeatures).daxpy(this.numFeatures, 1.0, linearGradSumVec.values(), 1, this.gradientSumArray(), 1);
                    var2_15 = BoxedUnit.UNIT;
                } else if (var16_10 && !this.fitIntercept) {
                    // No intercept: accumulate X^T * vec directly (beta = 1.0) into a
                    // DenseVector that wraps gradientSumArray.
                    gradSumVec = new DenseVector(this.gradientSumArray());
                    BLAS$.MODULE$.gemv(1.0, (Matrix)var17_11.transpose(), (Vector)vec, 1.0, gradSumVec);
                    var2_16 = BoxedUnit.UNIT;
                } else {
                    // NOTE(review): multinomialUpdateInPlace throws MatchError for the same case.
                    throw new IllegalArgumentException(new StringBuilder(21).append("Unknown matrix type ").append(var18_12.getClass()).append(".").toString());
                }
            }
            if (!this.fitIntercept) break block9;
            // Intercept gradient: the sum of the per-row multipliers.
            this.gradientSumArray()[this.numFeatures] = this.gradientSumArray()[this.numFeatures] + BoxesRunTime.unboxToDouble((Object)new ArrayOps.ofDouble(Predef$.MODULE$.doubleArrayOps(vec.values())).sum((Numeric)Numeric.DoubleIsFractional$.MODULE$));
        }
    }

    /**
     * Multinomial (softmax) update for one block. Adds this block's loss, total weight and
     * gradient to lossSum, weightSum and gradientSumArray.
     *
     * <p>Steps:
     * <ul>
     *   <li>{@code mat} (size x numClasses) is pre-filled with each class's intercept when
     *       fitIntercept is true. gemm then adds {@code X * multinomialLinear^T}, giving the
     *       per-row, per-class margins.</li>
     *   <li>For each row with positive weight:
     *       <ul>
     *         <li>if the row's largest margin is positive, it is subtracted from all margins
     *             before exponentiating, and added back into the loss (a numerical-stability
     *             shift);</li>
     *         <li>the loss adds
     *             {@code weight * (log(sum_j exp(margin_j)) - marginOfLabel)};</li>
     *         <li>row i of {@code mat} is overwritten with
     *             {@code weight * (softmax_j - [label == j])}.</li>
     *       </ul>
     *       Rows with zero weight are set to 0.</li>
     *   <li>{@code mat^T * X} is accumulated into gradientSumArray at index
     *       {@code feature * numClasses + class}. When fitIntercept is true, the per-class
     *       intercept gradients are added starting at offset {@code numClasses * numFeatures}.</li>
     * </ul>
     *
     * @param block the instance block (add() has already required its matrix to be transposed)
     */
    private void multinomialUpdateInPlace(InstanceBlock block) {
        block15: {
            double[] interceptGradSumArr;
            int size = block.size();
            DenseMatrix mat = DenseMatrix$.MODULE$.zeros(size, this.numClasses);
            if (this.fitIntercept) {
                // Intercepts are stored after the numClasses * numFeatures linear coefficients.
                double[] localCoefficientsArray = this.coefficientsArray();
                int offset = this.numClasses * this.numFeatures;
                for (int j2 = 0; j2 < this.numClasses; ++j2) {
                    double intercept = localCoefficientsArray[offset + j2];
                    for (int i2 = 0; i2 < size; ++i2) {
                        mat.update(i2, j2, intercept);
                    }
                }
            }
            // mat := X * multinomialLinear^T + mat (intercepts, if any).
            BLAS$.MODULE$.gemm(1.0, block.matrix(), this.multinomialLinear().transpose(), 1.0, mat);
            double localLossSum = 0.0;
            double[] tmp = (double[])Array$.MODULE$.ofDim(this.numClasses, ClassTag$.MODULE$.Double());
            double[] dArray = interceptGradSumArr = this.fitIntercept ? (double[])Array$.MODULE$.ofDim(this.numClasses, ClassTag$.MODULE$.Double()) : null;
            for (int i3 = 0; i3 < size; ++i3) {
                double weight = block.getWeight().apply$mcDI$sp(i3);
                if (weight > 0.0) {
                    int j3;
                    double label = block.getLabel(i3);
                    double maxMargin = Double.NEGATIVE_INFINITY;
                    for (j3 = 0; j3 < this.numClasses; ++j3) {
                        tmp[j3] = mat.apply(i3, j3);
                        maxMargin = package$.MODULE$.max(maxMargin, tmp[j3]);
                    }
                    // Assumes label is an integral class index in [0, numClasses) -- TODO confirm.
                    double marginOfLabel = tmp[(int)label];
                    double sum = 0.0;
                    for (j3 = 0; j3 < this.numClasses; ++j3) {
                        if (maxMargin > 0.0) {
                            int n = j3;
                            tmp[n] = tmp[n] - maxMargin;
                        }
                        double exp = package$.MODULE$.exp(tmp[j3]);
                        sum += exp;
                        tmp[j3] = exp;
                    }
                    // tmp[j3] / sum is the softmax probability of class j3.
                    for (j3 = 0; j3 < this.numClasses; ++j3) {
                        double multiplier = weight * (tmp[j3] / sum - (label == (double)j3 ? 1.0 : 0.0));
                        mat.update(i3, j3, multiplier);
                        if (!this.fitIntercept) continue;
                        int n = j3;
                        interceptGradSumArr[n] = interceptGradSumArr[n] + multiplier;
                    }
                    // Add maxMargin back when the margins were shifted above.
                    if (maxMargin > 0.0) {
                        localLossSum += weight * (package$.MODULE$.log(sum) - marginOfLabel + maxMargin);
                        continue;
                    }
                    localLossSum += weight * (package$.MODULE$.log(sum) - marginOfLabel);
                    continue;
                }
                for (int j4 = 0; j4 < this.numClasses; ++j4) {
                    mat.update(i3, j4, 0.0);
                }
            }
            this.lossSum_$eq(this.lossSum() + localLossSum);
            this.weightSum_$eq(this.weightSum() + BoxesRunTime.unboxToDouble((Object)block.weightIter().sum((Numeric)Numeric.DoubleIsFractional$.MODULE$)));
            Matrix matrix = block.matrix();
            if (matrix instanceof DenseMatrix) {
                DenseMatrix denseMatrix = (DenseMatrix)matrix;
                // gradientSumArray (viewed as numClasses x numFeatures, ldc = numClasses) +=
                // mat^T * X. The "T" on the block values with ldb = numFeatures relies on the
                // transposed storage required in add().
                BLAS$.MODULE$.nativeBLAS().dgemm("T", "T", this.numClasses, this.numFeatures, size, 1.0, mat.values(), size, denseMatrix.values(), this.numFeatures, 1.0, this.gradientSumArray(), this.numClasses);
                BoxedUnit boxedUnit = BoxedUnit.UNIT;
            } else if (matrix instanceof SparseMatrix) {
                SparseMatrix sparseMatrix = (SparseMatrix)matrix;
                // linearGradSumMat (numFeatures x numClasses) = X^T * mat. Each active entry
                // (i, j) is then added at gradientSumArray[i * numClasses + j].
                DenseMatrix linearGradSumMat = DenseMatrix$.MODULE$.zeros(this.numFeatures, this.numClasses);
                BLAS$.MODULE$.gemm(1.0, (Matrix)sparseMatrix.transpose(), mat, 0.0, linearGradSumMat);
                linearGradSumMat.foreachActive((Function3 & Serializable & scala.Serializable)(i, j, v) -> {
                    BlockLogisticAggregator.$anonfun$multinomialUpdateInPlace$4(this, BoxesRunTime.unboxToInt((Object)i), BoxesRunTime.unboxToInt((Object)j), BoxesRunTime.unboxToDouble((Object)v));
                    return BoxedUnit.UNIT;
                });
                BoxedUnit boxedUnit = BoxedUnit.UNIT;
            } else {
                // NOTE(review): binaryUpdateInPlace throws IllegalArgumentException for the same case.
                throw new MatchError((Object)matrix);
            }
            if (!this.fitIntercept) break block15;
            // Add the per-class intercept gradients at offset numClasses * numFeatures.
            BLAS$.MODULE$.getBLAS(this.numClasses).daxpy(this.numClasses, 1.0, interceptGradSumArr, 0, 1, this.gradientSumArray(), this.numClasses * this.numFeatures, 1);
        }
    }

    /** Lambda body: returns the last coefficient (the binary intercept when fitIntercept). */
    public static final /* synthetic */ double $anonfun$binaryUpdateInPlace$3(BlockLogisticAggregator $this) {
        return BoxesRunTime.unboxToDouble((Object)new ArrayOps.ofDouble(Predef$.MODULE$.doubleArrayOps($this.coefficientsArray())).last());
    }

    /** Lambda body: tests whether a multiplier is exactly zero. */
    public static final /* synthetic */ boolean $anonfun$binaryUpdateInPlace$4(double x$3) {
        return x$3 == 0.0;
    }

    /** Lambda body: adds v to gradientSumArray[i * numClasses + j]. */
    public static final /* synthetic */ void $anonfun$multinomialUpdateInPlace$4(BlockLogisticAggregator $this, int i, int j, double v) {
        int n = i * $this.numClasses + j;
        $this.gradientSumArray()[n] = $this.gradientSumArray()[n] + v;
    }

    /**
     * Creates an aggregator and validates the coefficient vector's size.
     *
     * @param numFeatures number of features per instance
     * @param numClasses number of classes; must be 1 or 2 when multinomial is false
     * @param fitIntercept whether the coefficient vector includes intercept term(s)
     * @param multinomial whether to use the multinomial (softmax) loss instead of the binary loss
     * @param bcCoefficients broadcast coefficient vector. For binary, its size must equal
     *     numFeaturesPlusIntercept. For multinomial, coefficientSize / numFeaturesPlusIntercept
     *     must equal numClasses.
     */
    public BlockLogisticAggregator(int numFeatures, int numClasses, boolean fitIntercept, boolean multinomial, Broadcast<Vector> bcCoefficients) {
        this.numFeatures = numFeatures;
        this.numClasses = numClasses;
        this.fitIntercept = fitIntercept;
        this.multinomial = multinomial;
        this.bcCoefficients = bcCoefficients;
        DifferentiableLossAggregator.$init$(this);
        Logging.$init$((Logging)this);
        if (multinomial && numClasses <= 2) {
            // NOTE(review): adjacent literals lack separating spaces, so the logged text reads
            // "...is applied, theresult..." and "...When regularizationis applied...".
            this.logInfo((Function0<String>)(Function0 & Serializable & scala.Serializable)() -> new StringBuilder(324).append("Multinomial logistic regression for binary classification yields separate ").append("coefficients for positive and negative classes. When no regularization is applied, the").append("result will be effectively the same as binary logistic regression. When regularization").append("is applied, multinomial loss will produce a result different from binary loss.").toString());
        }
        this.numFeaturesPlusIntercept = fitIntercept ? numFeatures + 1 : numFeatures;
        this.coefficientSize = ((Vector)bcCoefficients.value()).size();
        this.dim = this.coefficientSize();
        if (multinomial) {
            // NOTE(review): integer division means a coefficientSize that is not an exact
            // multiple of numFeaturesPlusIntercept can still pass this check.
            Predef$.MODULE$.require(numClasses == this.coefficientSize() / this.numFeaturesPlusIntercept(), (Function0 & Serializable & scala.Serializable)() -> new StringBuilder(46).append("The number of ").append("coefficients should be ").append($this.numClasses * this.numFeaturesPlusIntercept()).append(" but was ").append(this.coefficientSize()).toString());
        } else {
            Predef$.MODULE$.require(this.coefficientSize() == this.numFeaturesPlusIntercept(), (Function0 & Serializable & scala.Serializable)() -> new StringBuilder(31).append("Expected ").append(this.numFeaturesPlusIntercept()).append(" ").append("coefficients but got ").append(this.coefficientSize()).toString());
            Predef$.MODULE$.require(numClasses == 1 || numClasses == 2, (Function0 & Serializable & scala.Serializable)() -> new StringBuilder(68).append("Binary logistic aggregator requires numClasses ").append("in {1, 2} but found ").append($this.numClasses).append(".").toString());
        }
    }
}

