package ca.pfv.spmf.algorithms.sequenceprediction.ipredict.predictor.CPT.CPTPlus;

import ca.pfv.spmf.algorithms.sequenceprediction.ipredict.database.Item;
import ca.pfv.spmf.algorithms.sequenceprediction.ipredict.database.Sequence;
import ca.pfv.spmf.algorithms.sequenceprediction.ipredict.predictor.Paramable;
import ca.pfv.spmf.algorithms.sequenceprediction.ipredict.predictor.Predictor;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;

/* loaded from: input_file:ca/pfv/spmf/algorithms/sequenceprediction/ipredict/predictor/CPT/CPTPlus/CPTPlusPredictor.class */
public class CPTPlusPredictor extends Predictor implements Serializable {
    /** Root node of the prediction tree built by {@link #Train(List)}. */
    public PredictionTree Root;
    /** Lookup table: training-sequence index -> tree node on which that (encoded) sequence ends. */
    public Map<Integer, PredictionTree> LT;
    /** Inverted index: item value -> bit vector with bit {@code i} set when training sequence {@code i} contains it. */
    public Map<Integer, Bitvector> II;
    protected CPTHelper helper;
    /** Number of nodes currently in the prediction tree (reduced by {@link #pathCollapse()}). */
    protected long nodeNumber;
    /** Default for the "CCF" parameter (frequent-itemset encoding during training). */
    private boolean CCF;
    /** Default for the "CBS" parameter (single-path collapsing after training). */
    private boolean CBS;
    /** Encoder mapping item ids to lists of items and back. */
    public Encoder encoder;
    protected boolean seqEncoding;
    public Paramable parameters;
    private String TAG;
    /** Count table produced by the most recent {@link #Predict(Sequence)} call; {@code null} before the first call. */
    private Map<Integer, Float> lastCountTable;

    /** Creates a predictor tagged "CPT+" with empty structures and default parameters. */
    public CPTPlusPredictor() {
        this.CCF = false;
        this.CBS = true;
        this.TAG = "CPT+";
        this.lastCountTable = null;
        this.Root = new PredictionTree();
        this.LT = new HashMap<>();
        this.II = new HashMap<>();
        this.nodeNumber = 0L;
        this.parameters = new Paramable();
        this.seqEncoding = false;
        // Created last because the helper receives a reference to this (fully initialised) predictor.
        this.helper = new CPTHelper(this);
    }

    /**
     * Creates a predictor with a custom tag.
     *
     * @param str the tag returned by {@link #getTAG()}
     */
    public CPTPlusPredictor(String str) {
        this();
        this.TAG = str;
    }

    /**
     * Creates a predictor with a custom tag and a parameter string.
     *
     * @param str  the tag returned by {@link #getTAG()}
     * @param str2 parameter string forwarded to {@link Paramable#setParameter}
     */
    public CPTPlusPredictor(String str, String str2) {
        this(str);
        this.parameters.setParameter(str2);
    }

    /** Returns the tag identifying this predictor. */
    @Override
    public String getTAG() {
        return this.TAG;
    }

    /**
     * Rebuilds the prediction tree, lookup table and inverted index from the training sequences.
     *
     * <p>When the "CCF" parameter is enabled, frequent itemsets mined by {@code FIFRaw} are first
     * registered in a fresh encoder. When "splitMethod" &gt; 0 each sequence is truncated with
     * {@code keepLastItems(sequence, splitLength)}. When "CBS" is enabled the tree is compacted with
     * {@link #pathCollapse()} afterwards.
     *
     * @param list training sequences; their position in the list is the sequence index
     * @return always {@code Boolean.TRUE}
     */
    @Override
    public Boolean Train(List<Sequence> list) {
        this.Root = new PredictionTree();
        this.LT = new HashMap<>();
        this.II = new HashMap<>();
        this.encoder = new Encoder();
        this.helper.setEncoded(this.encoder);
        this.nodeNumber = 0L;

        if (this.parameters.paramBoolOrDefault("CCF", this.CCF)) {
            Iterator<List<Item>> itemsets = new FIFRaw().findFrequentItemsets(list,
                    this.parameters.paramInt("CCFmin").intValue(),
                    this.parameters.paramInt("CCFmax").intValue(),
                    this.parameters.paramInt("CCFsup").intValue()).iterator();
            while (itemsets.hasNext()) {
                this.encoder.addEntry(itemsets.next());
            }
        }

        int sequenceId = 0;
        for (Sequence original : list) {
            Sequence sequence = original;
            if (this.parameters.paramInt("splitMethod").intValue() > 0) {
                sequence = this.helper.keepLastItems(sequence, this.parameters.paramInt("splitLength").intValue());
            }
            Sequence encoded = this.encoder.encode(new Sequence(sequence));
            PredictionTree node = this.Root;
            for (Item item : encoded.getItems()) {
                // Index every underlying (decoded) item so the inverted index stays per-item.
                for (Item decoded : this.encoder.getEntry(item.val.intValue())) {
                    this.II.computeIfAbsent(decoded.val, key -> new Bitvector()).setBit(sequenceId);
                }
                if (!node.hasChild(item)) {
                    node.addChild(item);
                    this.nodeNumber++;
                }
                node = node.getChild(item);
            }
            this.LT.put(Integer.valueOf(sequenceId), node);
            sequenceId++;
        }

        if (this.parameters.paramBoolOrDefault("CBS", this.CBS)) {
            pathCollapse();
        }
        return Boolean.TRUE;
    }

    /**
     * Predicts the next item(s) for a sequence and remembers the count table that produced it.
     *
     * @param sequence the prefix to predict from; items unknown to the model are removed first
     * @return the best sequence reported by the count table ({@code getBestSequence(1)})
     */
    @Override
    public Sequence Predict(Sequence sequence) {
        CountTable countTable = predictionByActiveNoiseReduction(this.helper.removeUnseenItems(sequence));
        Sequence best = countTable.getBestSequence(1);
        this.lastCountTable = countTable.getTable();
        return best;
    }

    /**
     * Fills a count table from the sequence and from progressively "denoised" variants of it.
     *
     * <p>Variants are explored breadth-first: each variant drops one item chosen by
     * {@link #getNoise(Sequence, double)}. Exploration stops once the number of updates that yielded a
     * non-empty best sequence reaches {@code 1 + floor(size * minPredictionRatio)}, or when no variants
     * remain.
     *
     * @param sequence the (already filtered) prefix
     * @return the populated count table
     */
    protected CountTable predictionByActiveNoiseReduction(Sequence sequence) {
        HashSet<Sequence> visited = new HashSet<>();
        LinkedList<Sequence> queue = new LinkedList<>();
        queue.add(sequence);

        int minUpdates = 1 + ((int) (sequence.size() * this.parameters.paramDouble("minPredictionRatio").doubleValue()));
        double noiseRatio = this.parameters.paramDouble("noiseRatio").doubleValue();
        int initialSize = sequence.size();

        CountTable countTable = new CountTable(this.helper);
        countTable.update((Item[]) sequence.getItems().toArray(new Item[0]), initialSize);
        int updates = countTable.getBestSequence(1).size() > 0 ? 1 : 0;

        while (updates < minUpdates) {
            Sequence current = queue.poll();
            if (current == null) {
                break;
            }
            if (!visited.add(current)) {
                continue; // already expanded
            }
            for (Item noisyItem : getNoise(current, noiseRatio)) {
                Sequence candidate = current.m84clone();
                removeFirstOccurrence(candidate.getItems(), noisyItem);
                if (candidate.size() > 1) {
                    queue.add(candidate);
                }
                if (countTable.update((Item[]) candidate.getItems().toArray(new Item[0]), initialSize) > 0
                        && countTable.getBestSequence(1).size() > 0) {
                    updates++;
                }
            }
        }
        return countTable;
    }

    /** Removes the first element of {@code items} that {@code equals} {@code target}, if any. */
    private static void removeFirstOccurrence(List<Item> items, Item target) {
        for (int idx = 0; idx < items.size(); idx++) {
            if (items.get(idx).equals(target)) {
                items.remove(idx);
                return;
            }
        }
    }

    /** Number of training sequences containing {@code item}, read from the inverted index. */
    private int support(Item item) {
        return this.II.get(item.val).cardinality();
    }

    /**
     * Selects the items treated as noise in a sequence.
     *
     * <p>With {@code count = floor(size * d)} &gt; 0, returns the {@code count} items with the lowest
     * support (capped at the sequence length). Otherwise returns a single item with the lowest support.
     * An empty sequence yields an empty list.
     *
     * @param sequence the sequence to inspect; every item is expected to be present in {@link #II}
     * @param d        the noise ratio
     * @return the items to try removing
     */
    protected List<Item> getNoise(Sequence sequence, double d) {
        List<Item> items = sequence.getItems();
        int noiseCount = (int) Math.floor(sequence.size() * d);
        if (noiseCount > 0) {
            // Stable sort, most-supported first, so the tail holds the least-supported items.
            List<Item> bySupportDesc = items.stream()
                    .sorted((a, b) -> Integer.compare(support(b), support(a)))
                    .collect(Collectors.toList());
            int total = bySupportDesc.size();
            // Clamp: a noise ratio above 1 would otherwise produce a negative subList start.
            return bySupportDesc.subList(total - Math.min(noiseCount, total), total);
        }

        Item leastSupported = null;
        int minSupport = Integer.MAX_VALUE;
        for (Item item : items) {
            int itemSupport = support(item);
            if (itemSupport < minSupport) {
                minSupport = itemSupport;
                leastSupported = item;
            }
        }
        List<Item> noise = new ArrayList<>();
        if (leastSupported != null) {
            noise.add(new Item(leastSupported.val));
        }
        return noise;
    }

    /**
     * Collapses each non-branching chain that ends on a leaf into that single leaf node.
     *
     * <p>For every lookup-table node without children, the method climbs towards the root while the
     * current node has at most one child, never merging the root itself. If more than one node was
     * climbed, the leaf receives a new encoder id for the concatenated items of the chain. It is then
     * re-attached directly to the node where the climb stopped, replacing the chain's top node there.
     * {@link #nodeNumber} is reduced by the number of nodes removed.
     */
    protected void pathCollapse() {
        int nodesSaved = 0;
        for (PredictionTree leaf : this.LT.values()) {
            if (leaf.getChildren().size() != 0) {
                continue; // only chains ending on a leaf can be collapsed
            }
            List<PredictionTree> pathNodes = new ArrayList<>();
            List<Item> pathItems = new ArrayList<>();
            PredictionTree cur = leaf;
            while (cur != null && cur != this.Root && cur.getChildren().size() <= 1) {
                // Ancestors' items go in front so the merged itemset keeps sequence order.
                List<Item> merged = new ArrayList<>(this.encoder.getEntry(cur.Item.val.intValue()));
                merged.addAll(pathItems);
                pathItems = merged;
                pathNodes.add(cur);
                cur = cur.Parent;
            }

            int pathLength = pathNodes.size();
            if (cur == null || pathLength <= 1) {
                continue; // nothing to merge, or the parent chain was broken
            }

            PredictionTree top = pathNodes.get(pathLength - 1);
            // Detach the intermediate nodes; the leaf now stands for the whole chain.
            for (PredictionTree node : pathNodes) {
                node.getChildren().clear();
            }
            leaf.Item = new Item(this.encoder.getIdorAdd(pathItems));
            leaf.Parent = cur;
            cur.removeChild(top.Item);
            cur.addChild(leaf);
            nodesSaved += pathLength - 1;
        }
        this.nodeNumber -= nodesSaved;
    }

    /** Returns the number of nodes in the prediction tree. */
    @Override
    public long size() {
        return this.nodeNumber;
    }

    /**
     * Returns a rough memory estimate in bytes.
     *
     * <p>The estimate counts 3 * 4 bytes per tree node; per indexed item, one bit per training sequence
     * rounded up to whole bytes plus 4; and 2 * 4 bytes per lookup-table entry.
     */
    @Override
    public float memoryUsage() {
        long treeBytes = this.nodeNumber * 3L * 4L;
        // Divide as double so Math.ceil actually rounds partial bytes up.
        double invertedIndexBytes = this.II.size() * (Math.ceil(this.LT.size() / 8.0d) + 4.0d);
        long lookupTableBytes = this.LT.size() * 2L * 4L;
        return (float) (treeBytes + invertedIndexBytes + lookupTableBytes);
    }

    /**
     * Returns the count table computed by the last {@link #Predict(Sequence)} call.
     *
     * @return the table, or {@code null} if no prediction has been made yet
     */
    public Map<Integer, Float> getCountTable() {
        return this.lastCountTable;
    }
}
