package ca.pfv.spmf.algorithms.classifiers.logisticregression;

import ca.pfv.spmf.tools.MemoryLogger;
import java.util.Arrays;
import java.util.List;

/**
 * Binary logistic regression classifier trained with full-batch gradient ascent on the
 * log-likelihood.
 *
 * <p>The model estimates the probability of the positive class as
 * {@code sigmoid(bias + sum_i weights[i] * x[i])}. Training performs at most
 * {@code iterationCount} full passes over the data and stops early when the summed bias
 * gradient falls strictly inside {@code (-learningRate / 3, learningRate / 3)}.
 *
 * <p>Instances of this class are mutable (training overwrites the model and statistics
 * fields) and are not synchronized.
 */
public class AlgoBinaryLogisticRegression {
    /** Intercept term; set to a random value in [0, 1) at the start of {@link #train}. */
    double bias;
    /** One coefficient per feature; {@code null} until {@link #train} has been called. */
    double[] weights = null;
    /** Maximum number of gradient steps performed by {@link #train}. */
    int iterationCount = 1000;
    /** Gradient step size; also used to derive the early-stopping tolerance. */
    double learningRate = 0.1d;
    /** Number of iterations actually performed by the last call to {@link #train}. */
    int totalNumberIterations = 0;
    /** Wall-clock duration of the last training run, in milliseconds. */
    long totalTime = 0;
    /** Peak memory reported by MemoryLogger for the last training run (printed as mb). */
    double totalMemory = 0.0d;

    /**
     * Classifies an instance.
     *
     * @param instanceContinuous instance whose {@code values} array is the feature vector
     * @return {@code true} if the predicted positive-class probability is greater than 0.5
     * @throws IllegalStateException if the model has not been trained yet
     */
    public boolean predictBoolean(InstanceContinuous instanceContinuous) {
        return probability(instanceContinuous.values) > 0.5d;
    }

    /**
     * Returns the predicted probability that an instance belongs to the positive class.
     *
     * @param instanceContinuous instance whose {@code values} array is the feature vector
     * @return the sigmoid of the linear score, a value in [0, 1]
     * @throws IllegalStateException if the model has not been trained yet
     */
    public double predictDouble(InstanceContinuous instanceContinuous) {
        return probability(instanceContinuous.values);
    }

    /**
     * Sets the maximum number of training iterations.
     *
     * @param i the iteration cap used by subsequent calls to {@link #train}
     */
    public void setIterationCount(int i) {
        this.iterationCount = i;
    }

    /**
     * Sets the gradient step size.
     *
     * @param d the learning rate used by subsequent calls to {@link #train}
     */
    public void setLearningRate(double d) {
        this.learningRate = d;
    }

    /**
     * Logistic function {@code 1 / (1 + e^-z)}.
     *
     * @param z the linear score
     * @return the logistic of {@code z}
     */
    double sigmoid(double z) {
        // Math.exp is the dedicated (and more accurate) routine for e^x; the previous
        // Math.pow with a hand-typed copy of e was equivalent in intent only.
        return 1.0d / (1.0d + Math.exp(-z));
    }

    /**
     * Computes the model's positive-class probability for a feature vector.
     *
     * @param x feature values; only the first {@code weights.length} entries are used
     * @return sigmoid(bias + dot(weights, x))
     * @throws IllegalStateException if {@link #train} has not been called yet
     */
    private double probability(double[] x) {
        if (this.weights == null) {
            throw new IllegalStateException("The model must be trained before making predictions");
        }
        double score = this.bias;
        for (int j = 0; j < this.weights.length; j++) {
            score += x[j] * this.weights[j];
        }
        return sigmoid(score);
    }

    /**
     * Fits the model with full-batch gradient ascent on the log-likelihood.
     *
     * <p>The weights start at zero and the bias at a random value. Each iteration accumulates
     * the gradient over the whole training set, then applies it scaled by
     * {@code learningRate / list.size()}. Afterwards {@code totalNumberIterations},
     * {@code totalTime} and {@code totalMemory} describe this run.
     *
     * <p>NOTE(review): the early-stopping test looks only at the (unnormalized) bias gradient,
     * not at the weight gradients; this is the original algorithm's criterion and is kept as-is.
     *
     * @param list training instances; the length of the first instance's {@code values}
     *     defines the number of features
     * @param list2 class labels, index-aligned with {@code list} ({@code true} = positive class)
     * @throws IllegalArgumentException if {@code list} is null or empty, if {@code list2} is
     *     null or has a different size, or if an instance has fewer feature values than the
     *     first one
     */
    public void train(List<InstanceContinuous> list, List<Boolean> list2) {
        validateTrainingData(list, list2);

        this.totalNumberIterations = 0;
        long startTime = System.currentTimeMillis();
        MemoryLogger.getInstance().reset();

        int featureCount = list.get(0).values.length;
        int sampleCount = list.size();
        this.weights = new double[featureCount];
        this.bias = Math.random();

        double[] gradient = new double[featureCount];
        // Averaging over the training set is folded into the step size.
        double stepSize = this.learningRate / sampleCount;
        double tolerance = this.learningRate / 3.0d;

        for (int iteration = 0; iteration < this.iterationCount; iteration++) {
            Arrays.fill(gradient, 0.0d);
            double biasGradient = 0.0d;
            for (int n = 0; n < sampleCount; n++) {
                double[] x = list.get(n).values;
                double label = list2.get(n).booleanValue() ? 1.0d : 0.0d;
                // Gradient of the log-likelihood is (label - prediction) * x; accumulate it.
                double error = probability(x) - label;
                for (int j = 0; j < featureCount; j++) {
                    gradient[j] -= error * x[j];
                }
                biasGradient -= error;
            }
            for (int j = 0; j < featureCount; j++) {
                this.weights[j] += stepSize * gradient[j];
            }
            this.bias += stepSize * biasGradient;
            this.totalNumberIterations++;
            // Same as the strict two-sided bound (-tolerance, tolerance); NaN never stops early.
            if (Math.abs(biasGradient) < tolerance) {
                break;
            }
        }

        MemoryLogger.getInstance().checkMemory();
        this.totalTime = System.currentTimeMillis() - startTime;
        this.totalMemory = MemoryLogger.getInstance().getMaxMemory();
    }

    /**
     * Rejects training inputs that would otherwise fail with an opaque index error (or be
     * silently mis-paired) in the middle of training.
     */
    private static void validateTrainingData(List<InstanceContinuous> instances, List<Boolean> labels) {
        if (instances == null || instances.isEmpty()) {
            throw new IllegalArgumentException("The training set must contain at least one instance");
        }
        if (labels == null) {
            throw new IllegalArgumentException("The list of labels must not be null");
        }
        if (labels.size() != instances.size()) {
            throw new IllegalArgumentException("Expected " + instances.size()
                    + " labels (one per instance) but got " + labels.size());
        }
        int featureCount = instances.get(0).values.length;
        for (int n = 1; n < instances.size(); n++) {
            double[] values = instances.get(n).values;
            if (values == null || values.length < featureCount) {
                throw new IllegalArgumentException("Instance " + n + " has "
                        + (values == null ? 0 : values.length) + " feature values; expected at least "
                        + featureCount);
            }
        }
    }

    /** Prints the iteration count, running time and peak memory of the last training run. */
    public void printStats() {
        System.out.println("=============  BinaryLogisticRegression v.2.53 - STATS =============");
        System.out.println(" Stopped at " + this.totalNumberIterations + " iterations.");
        System.out.println(" Total time ~ " + this.totalTime + " ms");
        System.out.println(" Maximum memory usage : " + this.totalMemory + " mb");
        System.out.println("===================================================");
    }
}
