package ca.pfv.spmf.algorithms.classifiers.adt;

import ca.pfv.spmf.algorithms.ArraysAlgos;
import ca.pfv.spmf.algorithms.classifiers.data.Dataset;
import ca.pfv.spmf.algorithms.classifiers.data.Instance;
import ca.pfv.spmf.algorithms.classifiers.general.Rule;
import ca.pfv.spmf.algorithms.classifiers.general.RuleClassifier;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
import java.util.Iterator;
import java.util.List;
import java.util.Map;

/* loaded from: input_file:ca/pfv/spmf/algorithms/classifiers/adt/ClassifierADT.class */
/**
 * Rule-based classifier named "ADT" that organises mined {@link RuleADT} rules into a tree of
 * {@link ADNode}s, prunes it using pessimistic error estimates, and keeps the surviving rules whose
 * merit reaches a minimum threshold.
 *
 * <p>Construction pipeline:
 * <ol>
 *   <li>rank a copy of the candidate rules (confidence desc, support desc, size asc, antecedent
 *       items asc, class asc);</li>
 *   <li>drop rules whose antecedent contains (or equals) that of a higher-ranked kept rule;</li>
 *   <li>assign every training instance to the first kept rule that matches it, counting hits and
 *       misses;</li>
 *   <li>build a tree rooted at a default rule predicting the most frequent training class;</li>
 *   <li>prune the tree bottom-up and flatten it into {@link #rules}.</li>
 * </ol>
 */
public class ClassifierADT extends RuleClassifier implements Serializable {
    private static final long serialVersionUID = 8240202223112688265L;

    /** Training data used to compute rule coverage and the default class. */
    private final Dataset training;

    /** Rules whose merit is below this value are not kept in the final rule list. */
    private final double minMerit;

    /**
     * Builds the classifier from candidate rules.
     *
     * <p>The caller's list is not reordered; ranking is done on a private copy. The {@link RuleADT}
     * objects themselves are mutated (coverage, hit and miss counters are updated).
     *
     * @param candidateRules rules mined for the training data
     * @param minMerit minimum merit a rule needs to be kept in the final rule list
     * @param dataset training dataset
     */
    public ClassifierADT(List<RuleADT> candidateRules, double minMerit, Dataset dataset) {
        super("ADT");
        this.training = dataset;
        this.minMerit = minMerit;

        // Rank a copy so the caller's list is left untouched (List.sort is stable, like
        // Collections.sort).
        List<RuleADT> ranked = new ArrayList<>(candidateRules);
        ranked.sort(ClassifierADT::compareRules);

        List<RuleADT> nonRedundant = removeRedundant(ranked);
        assignTrainingInstances(nonRedundant);

        ADNode root = buildTree(nonRedundant);
        prune(root);
        this.rules = transformTreeToRules(root);
    }

    /**
     * Ranking used before redundancy removal: higher confidence first, then higher support, then
     * shorter rules, then lexicographically smaller antecedents, then smaller class value.
     */
    private static int compareRules(RuleADT first, RuleADT second) {
        int order = Double.compare(second.getConfidence(), first.getConfidence());
        if (order != 0) {
            return order;
        }
        order = Double.compare(second.getSupportRule(), first.getSupportRule());
        if (order != 0) {
            return order;
        }
        order = Integer.compare(first.size(), second.size());
        if (order != 0) {
            return order;
        }
        // Sizes are equal here, so indexing both antecedents up to first.size() is safe.
        for (int i = 0; i < first.size(); i++) {
            order = Integer.compare(
                    first.getAntecedent().get(i).shortValue(),
                    second.getAntecedent().get(i).shortValue());
            if (order != 0) {
                return order;
            }
        }
        return Integer.compare(first.getKlass(), second.getKlass());
    }

    /**
     * Assigns each training instance to the first rule (in ranked order) that matches it and
     * records a hit or a miss on that rule. Instances matched by no rule are left unassigned.
     */
    private void assignTrainingInstances(List<RuleADT> rankedRules) {
        List<Instance> instances = training.getInstances();
        for (int index = 0; index < instances.size(); index++) {
            Instance instance = instances.get(index);
            Short[] items = instance.getItems();
            for (RuleADT rule : rankedRules) {
                if (rule.matching(items)) {
                    recordCoverage(rule, Integer.valueOf(index), instance);
                    break;
                }
            }
        }
    }

    /** Marks {@code instance} as covered by {@code rule} and counts it as a hit or a miss. */
    private static void recordCoverage(RuleADT rule, Integer instanceIndex, Instance instance) {
        rule.addCoveredInstance(instanceIndex);
        if (rule.getKlass() == instance.getKlass().shortValue()) {
            rule.incrementHits();
        } else {
            rule.incrementMisses();
        }
    }

    /**
     * Builds the rule tree. The root holds the default rule. Rules are inserted from the
     * lowest-ranked to the highest-ranked. Each rule descends through {@link ADNode#isChild} for as
     * long as it returns a node, and is attached as a new child of the last node reached.
     */
    private ADNode buildTree(List<RuleADT> rankedRules) {
        ADNode root = new ADNode(extractDefaultRule());
        for (int i = rankedRules.size() - 1; i >= 0; i--) {
            RuleADT rule = rankedRules.get(i);
            ADNode parent = root;
            for (ADNode next = parent.isChild(rule); next != null; next = parent.isChild(rule)) {
                parent = next;
            }
            ADNode node = new ADNode(rule);
            node.parent = parent;
            parent.childs.add(node);
        }
        return root;
    }

    /**
     * Flattens the tree into a rule list. Children are visited last-to-first, and each child's
     * subtree is emitted before the node itself. Nodes whose merit is below {@link #minMerit} are
     * skipped, but their descendants are still visited.
     */
    private List<Rule> transformTreeToRules(ADNode node) {
        List<Rule> result = new ArrayList<>();
        for (int i = node.childs.size() - 1; i >= 0; i--) {
            result.addAll(transformTreeToRules(node.childs.get(i)));
        }
        if (node.rule.getMerit() >= minMerit) {
            result.add(node.rule);
        }
        return result;
    }

    /**
     * Prunes the tree bottom-up. After pruning its children, a node is collapsed when a merged
     * copy (its own rule plus the instances covered by its direct children) has a strictly lower
     * pessimistic error estimate than the node's own estimate plus its direct children's estimates.
     * Collapsing clears the node's children and replaces the node's rule with the merged rule.
     */
    private void prune(ADNode node) {
        if (node == null || node.childs.isEmpty()) {
            return;
        }
        for (ADNode child : node.childs) {
            prune(child);
        }
        // NOTE(review): this assumes ADNode(ADNode) copies the rule. If it shared the RuleADT,
        // the merge below would also mutate node.rule before its estimate is read. The original
        // evaluation order is kept either way — confirm in ADNode.
        ADNode merged = new ADNode(node);
        double mergedError = calculatePessimisticErrorEstimate(merged);

        double subtreeError = node.rule.getPessimisticErrorEstimate();
        for (ADNode child : node.childs) {
            subtreeError += child.rule.getPessimisticErrorEstimate();
        }

        if (mergedError < subtreeError) {
            node.childs.clear();
            node.rule = merged.rule;
        }
    }

    /**
     * Adds, to {@code node}'s rule, every instance covered by one of its direct children that the
     * node's rule also matches, updating hits and misses. Returns the rule's resulting pessimistic
     * error estimate.
     */
    private double calculatePessimisticErrorEstimate(ADNode node) {
        for (ADNode child : node.childs) {
            for (Integer instanceIndex : child.rule.getCoveredInstances()) {
                Instance instance = training.getInstances().get(instanceIndex.intValue());
                if (node.rule.matching(instance.getItems())) {
                    recordCoverage(node.rule, instanceIndex, instance);
                }
            }
        }
        return node.rule.getPessimisticErrorEstimate();
    }

    /**
     * Creates the default rule, which predicts the class with the highest training frequency. On
     * ties, the first maximal entry returned by the map's iteration is used. Throws
     * {@link java.util.NoSuchElementException} (from {@link Collections#max}) if the class
     * frequency map is empty.
     */
    private RuleADT extractDefaultRule() {
        short majorityClass = Collections.max(
                training.getMapClassToFrequency().entrySet(),
                Map.Entry.comparingByValue()).getKey().shortValue();
        return new RuleADT(majorityClass);
    }

    /**
     * Keeps rules in ranked order. A rule is discarded when
     * {@code ArraysAlgos.containsOrEquals(candidate, kept)} holds for its antecedent and the
     * antecedent of any rule already kept; presumably this means the candidate is at least as
     * specific as a higher-ranked kept rule — confirm against {@code ArraysAlgos}.
     */
    private List<RuleADT> removeRedundant(List<RuleADT> rankedRules) {
        List<RuleADT> kept = new ArrayList<>();
        for (RuleADT candidate : rankedRules) {
            boolean redundant = false;
            for (RuleADT keptRule : kept) {
                if (ArraysAlgos.containsOrEquals(candidate.getAntecedent(), keptRule.getAntecedent())) {
                    redundant = true;
                    break;
                }
            }
            if (!redundant) {
                kept.add(candidate);
            }
        }
        return kept;
    }
}
