package ca.pfv.spmf.algorithms.clustering.kmeans;

import ca.pfv.spmf.algorithms.clustering.distanceFunctions.DistanceFunction;
import ca.pfv.spmf.algorithms.clustering.instancereader.AlgoInstanceFileReader;
import ca.pfv.spmf.patterns.cluster.ClusterWithMean;
import ca.pfv.spmf.patterns.cluster.ClustersEvaluation;
import ca.pfv.spmf.patterns.cluster.DoubleArray;
import ca.pfv.spmf.tools.MemoryLogger;
import java.io.BufferedWriter;
import java.io.FileWriter;
import java.io.IOException;
import java.util.ArrayList;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Random;

/* loaded from: input_file:ca/pfv/spmf/algorithms/clustering/kmeans/AlgoKMeans.class */
/**
 * K-Means clustering over instances loaded through {@link AlgoInstanceFileReader}.
 *
 * <p>Initial centroids are copies of randomly chosen distinct instances. Each
 * iteration moves every instance to the cluster whose mean is closest according
 * to the supplied {@link DistanceFunction}, then recomputes the cluster means.
 * Iteration stops as soon as a full pass changes no assignment.
 *
 * <p>Instances of this class hold mutable run state (clusters, timestamps,
 * iteration counter) and are not written to be shared between threads.
 */
public class AlgoKMeans {
    /** Random generator used to pick the initial centroids (seeded with the class-load time). */
    protected static final Random random = new Random(System.currentTimeMillis());

    /** Wall-clock time (ms) at which the last {@code runAlgorithm} call started. */
    protected long startTimestamp;

    /** Wall-clock time (ms) at which the last {@code runAlgorithm} call finished. */
    protected long endTimestamp;

    /** Number of assignment passes performed since the last {@code runAlgorithm} call. */
    long iterationCount;

    /** Clusters produced by the last run; {@code null} until a run has completed. */
    protected List<ClusterWithMean> clusters = null;

    /** Distance function used by the last run. */
    protected DistanceFunction distanceFunction = null;

    /** Attribute names reported by the instance reader during the last run. */
    private List<String> attributeNames = null;

    /** When {@code true}, the algorithm traces its progress on standard output. */
    boolean DEBUG_MODE = false;

    /**
     * Reads the instances and clusters them.
     *
     * <p>When {@code i == 1}, or when the input holds exactly one instance, a
     * single cluster containing every instance is returned without iterating.
     * Otherwise the number of clusters is capped at the number of instances and
     * the work is delegated to {@link #applyAlgorithm}.
     *
     * @param str              first argument forwarded to
     *                         {@code AlgoInstanceFileReader.runAlgorithm} (presumably the input
     *                         file path - confirm against the reader)
     * @param i                requested number of clusters (k)
     * @param distanceFunction distance used to compare instances with cluster means
     * @param str2             second argument forwarded to
     *                         {@code AlgoInstanceFileReader.runAlgorithm} (presumably the value
     *                         separator - confirm against the reader)
     * @return the clusters found (also kept in {@link #clusters})
     * @throws IOException              propagated from the instance reader
     * @throws NumberFormatException    propagated from the instance reader
     * @throws IllegalArgumentException if {@code i < 1} and the input holds more than one instance
     */
    public List<ClusterWithMean> runAlgorithm(String str, int i, DistanceFunction distanceFunction, String str2) throws NumberFormatException, IOException {
        this.startTimestamp = System.currentTimeMillis();
        this.iterationCount = 0L;
        this.distanceFunction = distanceFunction;

        AlgoInstanceFileReader reader = new AlgoInstanceFileReader();
        List<DoubleArray> instances = reader.runAlgorithm(str, str2);
        this.attributeNames = reader.getAttributeNames();

        // Smallest and largest attribute value over the whole data set. The seeds
        // are +/- infinity so that negative-only or very large data is handled.
        double minValue = Double.POSITIVE_INFINITY;
        double maxValue = Double.NEGATIVE_INFINITY;
        for (DoubleArray instance : instances) {
            for (double value : instance.data) {
                if (value < minValue) {
                    minValue = value;
                }
                if (value > maxValue) {
                    maxValue = value;
                }
            }
        }

        // NOTE: an empty input still fails here with IndexOutOfBoundsException, as before.
        int vectorSize = instances.get(0).data.length;

        if (i == 1 || instances.size() == 1) {
            // Trivial cases: everything belongs to one cluster.
            List<ClusterWithMean> single = new ArrayList<>();
            single.add(buildSingleCluster(instances, vectorSize));
            this.clusters = single;
        } else {
            // Never ask for more clusters than there are instances.
            int k = Math.min(i, instances.size());
            applyAlgorithm(k, distanceFunction, instances, minValue, maxValue, vectorSize);
        }

        MemoryLogger.getInstance().checkMemory();
        this.endTimestamp = System.currentTimeMillis();
        return this.clusters;
    }

    /**
     * Runs the clustering and stores the result in {@link #clusters}.
     *
     * @param i                number of clusters (k)
     * @param distanceFunction distance used to compare instances with cluster means
     * @param list             instances to cluster
     * @param d                smallest attribute value in {@code list}
     * @param d2               largest attribute value in {@code list}
     * @param i2               number of attributes per instance
     */
    void applyAlgorithm(int i, DistanceFunction distanceFunction, List<DoubleArray> list, double d, double d2, int i2) {
        this.clusters = applyKMeans(i, distanceFunction, list, d, d2, i2);
    }

    /**
     * Core K-Means loop.
     *
     * @param i                number of clusters (k); capped at {@code list.size()}
     * @param distanceFunction distance used to compare instances with cluster means
     * @param list             instances to cluster
     * @param d                smallest attribute value; not used by this implementation
     * @param d2               largest attribute value; not used by this implementation
     * @param i2               number of attributes per instance
     * @return the clusters found; empty when {@code list} is empty
     * @throws IllegalArgumentException if {@code i < 1} and {@code list} does not hold exactly one instance
     */
    public List<ClusterWithMean> applyKMeans(int i, DistanceFunction distanceFunction, List<DoubleArray> list, double d, double d2, int i2) {
        List<ClusterWithMean> result = new ArrayList<>();

        if (list.size() == 1) {
            // A single instance forms its own cluster, with its mean set.
            result.add(buildSingleCluster(list, i2));
            return result;
        }
        if (i < 1) {
            // Without any centroid no instance could be assigned (previously a bare NPE).
            throw new IllegalArgumentException("The number of clusters must be at least 1 but was " + i);
        }

        initializeCentroids(list, i, i2, result);

        if (this.DEBUG_MODE) {
            System.out.println("==== INPUT DATA =====");
            for (int index = 0; index < list.size(); index++) {
                System.out.println("Instance " + index + ": " + list.get(index));
            }
            System.out.println("==== INITIAL CENTROIDS =====");
            for (int index = 0; index < result.size(); index++) {
                System.out.println("Centroid " + index + ": " + result.get(index));
            }
        }

        while (true) {
            this.iterationCount++;
            if (this.DEBUG_MODE) {
                System.out.println("Iteration " + this.iterationCount);
            }

            boolean changed = reassignInstances(list, result, distanceFunction);
            MemoryLogger.getInstance().checkMemory();
            if (!changed) {
                // Stable assignment: the means are already up to date.
                break;
            }

            for (ClusterWithMean cluster : result) {
                cluster.recomputeClusterMean();
                if (this.DEBUG_MODE) {
                    System.out.println("Cluster mean: " + cluster.getMean());
                }
            }
        }

        if (this.DEBUG_MODE) {
            System.out.println("Check : " + verifyClusterAssignments(result, list, distanceFunction));
        }
        return result;
    }

    /**
     * Checks that every instance is stored in the cluster whose mean is closest to it.
     *
     * @param list             clusters to check
     * @param list2            instances that were clustered
     * @param distanceFunction distance used to compare instances with cluster means
     * @return {@code true} if every instance sits in its closest cluster; {@code false}
     *         otherwise, including when an instance has no closest cluster at all
     */
    public boolean verifyClusterAssignments(List<ClusterWithMean> list, List<DoubleArray> list2, DistanceFunction distanceFunction) {
        for (DoubleArray instance : list2) {
            ClusterWithMean nearest = null;
            double smallestDistance = Double.MAX_VALUE;
            for (ClusterWithMean cluster : list) {
                double distance = distanceFunction.calculateDistance(instance, cluster.getMean());
                if (distance < smallestDistance) {
                    smallestDistance = distance;
                    nearest = cluster;
                }
            }
            // No cluster available (e.g. empty cluster list) means the instance is unassigned.
            if (nearest == null || !nearest.getVectors().contains(instance)) {
                return false;
            }
        }
        return true;
    }

    /**
     * Builds one cluster holding all the given instances, with its mean computed.
     * The mean is installed before it is recomputed so that the computed value
     * is the one that remains.
     */
    private static ClusterWithMean buildSingleCluster(List<DoubleArray> instances, int vectorSize) {
        ClusterWithMean cluster = new ClusterWithMean(vectorSize);
        for (DoubleArray instance : instances) {
            cluster.addVector(instance);
        }
        cluster.setMean(new DoubleArray(new double[vectorSize]));
        cluster.recomputeClusterMean();
        return cluster;
    }

    /**
     * Performs one assignment pass: moves each instance to the cluster with the closest mean.
     * Ties are resolved in favour of the earliest cluster in {@code candidates}.
     *
     * @return {@code true} if at least one instance changed cluster
     */
    private boolean reassignInstances(List<DoubleArray> instances, List<ClusterWithMean> candidates, DistanceFunction distanceFunction) {
        boolean changed = false;
        for (DoubleArray instance : instances) {
            ClusterWithMean nearest = null;
            ClusterWithMean current = null;
            double smallestDistance = Double.MAX_VALUE;
            for (ClusterWithMean candidate : candidates) {
                double distance = distanceFunction.calculateDistance(candidate.getMean(), instance);
                if (distance < smallestDistance) {
                    nearest = candidate;
                    smallestDistance = distance;
                }
                if (candidate.contains(instance)) {
                    current = candidate;
                }
            }
            if (current != nearest) {
                if (current != null) {
                    current.remove(instance);
                }
                nearest.addVector(instance);
                if (this.DEBUG_MODE) {
                    System.out.println(" Instance " + instance + " is assigned to cluster  " + nearest.getMean());
                }
                changed = true;
            }
        }
        return changed;
    }

    /**
     * Creates the initial clusters, each centred on a distinct randomly chosen instance.
     *
     * @param list  instances to choose from
     * @param i     requested number of centroids; capped at {@code list.size()} because
     *              only distinct instances are picked
     * @param i2    number of attributes per instance (unused; each cluster is sized from its centroid)
     * @param list2 receives the newly created clusters
     */
    private void initializeCentroids(List<DoubleArray> list, int i, int i2, List<ClusterWithMean> list2) {
        // Capping avoids an endless loop when more centroids are requested than instances exist.
        int wanted = Math.min(i, list.size());
        boolean[] taken = new boolean[list.size()];
        List<DoubleArray> seeds = new ArrayList<>();
        while (seeds.size() < wanted) {
            int candidate = random.nextInt(list.size());
            if (!taken[candidate]) {
                taken[candidate] = true;
                seeds.add(list.get(candidate));
            }
        }
        for (DoubleArray seed : seeds) {
            ClusterWithMean cluster = new ClusterWithMean(seed.data.length);
            // Use a copy as the mean so the input instance is never shared with the
            // cluster's mutable mean (guards against in-place mean updates).
            cluster.setMean(new DoubleArray(seed.data.clone()));
            list2.add(cluster);
        }
    }

    /**
     * Writes the attribute definitions followed by every non-empty cluster, one per line.
     * Requires a previous successful call to {@code runAlgorithm}.
     *
     * @param str path of the output file
     * @throws IOException if the file cannot be written
     */
    public void saveToFile(String str) throws IOException {
        // try-with-resources closes the writer even when a write fails.
        try (BufferedWriter writer = new BufferedWriter(new FileWriter(str))) {
            for (String attributeName : this.attributeNames) {
                writer.write("@ATTRIBUTEDEF=" + attributeName);
                writer.newLine();
            }
            boolean firstCluster = true;
            for (ClusterWithMean cluster : this.clusters) {
                if (cluster.getVectors().size() < 1) {
                    continue;
                }
                // Separator goes before each cluster but the first, so skipped empty
                // clusters never leave a dangling blank line at the end of the file.
                if (!firstCluster) {
                    writer.newLine();
                }
                writer.write(cluster.toString());
                firstCluster = false;
            }
        }
    }

    /** Prints timing, quality and memory statistics of the last run on standard output. */
    public void printStatistics() {
        System.out.println("========== KMEANS - SPMF 2.09 - STATS ============");
        System.out.println(" Distance function: " + this.distanceFunction.getName());
        System.out.println(" Total time ~: " + (this.endTimestamp - this.startTimestamp) + " ms");
        System.out.println(" SSE (Sum of Squared Errors) (lower is better) : " + ClustersEvaluation.calculateSSE(this.clusters, this.distanceFunction));
        System.out.println(" Max memory:" + MemoryLogger.getInstance().getMaxMemory() + " mb ");
        System.out.println(" Iteration count: " + this.iterationCount);
        System.out.println("=====================================");
    }
}
