/*
 * Decompiled with CFR 0.152.
 */
package de.uni_freiburg.informatik.ultimate.util.datastructures;

import de.uni_freiburg.informatik.ultimate.util.datastructures.CrossProducts;
import de.uni_freiburg.informatik.ultimate.util.datastructures.DataStructureUtils;
import de.uni_freiburg.informatik.ultimate.util.datastructures.EqualityStatus;
import de.uni_freiburg.informatik.ultimate.util.datastructures.ImmutableSet;
import de.uni_freiburg.informatik.ultimate.util.datastructures.UnionFind;
import de.uni_freiburg.informatik.ultimate.util.datastructures.relation.HashRelation;
import de.uni_freiburg.informatik.ultimate.util.datastructures.relation.Triple;
import java.util.Collection;
import java.util.Comparator;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.function.Function;

/**
 * Tracks equalities and disequalities between elements of type {@code E}.
 * <p>
 * Equalities are stored in a {@link UnionFind}; disequalities are stored in a {@link HashRelation}
 * whose entries are always pairs of equivalence-class representatives (this is what
 * {@link #sanityCheck()} verifies). Querying two known elements yields one of three answers
 * ({@link EqualityStatus#EQUAL}, {@link EqualityStatus#NOT_EQUAL}, {@link EqualityStatus#UNKNOWN}).
 * <p>
 * Once a contradicting constraint has been reported the relation is flagged as inconsistent; from
 * then on {@link #reportEquality}, {@link #reportDisequality} and {@link #getEqualityStatus} throw
 * an {@link IllegalStateException}.
 *
 * @param <E>
 *            element type
 */
public class ThreeValuedEquivalenceRelation<E> {
    /** Partition of all known elements into equivalence classes. */
    private final UnionFind<E> mUnionFind;
    /** Disequalities, stored between representatives only (in one direction per pair). */
    private final HashRelation<E, E> mDisequalities;
    /** Set once an equality and a disequality contradicting each other have been reported. */
    private boolean mIsInconsistent;

    /**
     * Creates an empty, consistent relation without an element comparator.
     */
    public ThreeValuedEquivalenceRelation() {
        mUnionFind = new UnionFind<>();
        mDisequalities = new HashRelation<>();
        mIsInconsistent = false;
    }

    /**
     * Creates an empty, consistent relation whose union-find is constructed with the given comparator.
     *
     * @param elementComparator
     *            comparator handed to the underlying {@link UnionFind}; must not be null (asserted)
     */
    public ThreeValuedEquivalenceRelation(final Comparator<E> elementComparator) {
        assert (elementComparator != null) : "use other constructor in this case!";
        mUnionFind = new UnionFind<E>(elementComparator);
        mDisequalities = new HashRelation<>();
        mIsInconsistent = false;
    }

    /**
     * Copy constructor. Clones the partition, copies the disequalities and takes over the
     * inconsistency flag of the given relation.
     *
     * @param tver
     *            relation to copy
     */
    public ThreeValuedEquivalenceRelation(final ThreeValuedEquivalenceRelation<E> tver) {
        mUnionFind = tver.mUnionFind.clone();
        mDisequalities = new HashRelation<>(tver.mDisequalities);
        mIsInconsistent = tver.mIsInconsistent;
        assert (sanityCheck());
    }

    /**
     * Creates a consistent relation from copies of the given partition and disequalities.
     *
     * @param newPartition
     *            partition to clone
     * @param newDisequalities
     *            disequalities to copy; expected to relate representatives of {@code newPartition}
     *            only (checked by an assertion via {@link #sanityCheck()})
     */
    public ThreeValuedEquivalenceRelation(final UnionFind<E> newPartition,
            final HashRelation<E, E> newDisequalities) {
        mUnionFind = newPartition.clone();
        mDisequalities = new HashRelation<>(newDisequalities);
        mIsInconsistent = false;
        assert (sanityCheck());
    }

    /**
     * Makes the given element known to this relation if it is not known yet.
     *
     * @param elem
     *            element to add
     * @return true iff the element was not known before (the union-find had no representative for it)
     */
    public boolean addElement(final E elem) {
        if (mUnionFind.find(elem) != null) {
            return false;
        }
        mUnionFind.findAndConstructEquivalenceClassIfNeeded(elem);
        return true;
    }

    /**
     * Removes an element from the relation and re-targets disequalities that referred to it.
     *
     * @param elem
     *            element to remove
     * @param newRepChoice
     *            preferred new representative for the class of {@code elem}, or null to let this
     *            method pick one; if non-null it must be in the same class as {@code elem} (asserted)
     * @return the representative of the class {@code elem} belonged to after the removal, or null if
     *         {@code elem} was the only member of its class
     */
    public E removeElement(final E elem, final E newRepChoice) {
        assert (newRepChoice == null || getRepresentative(elem) == getRepresentative(newRepChoice));
        final E rep = mUnionFind.find(elem);
        // Snapshot the class before removal so that we can still pick a successor afterwards.
        final Set<E> equivalenceClassCopy = new HashSet<E>(mUnionFind.getEquivalenceClassMembers(elem));
        mUnionFind.remove(elem, newRepChoice);

        if (rep != elem) {
            // elem was not the representative, so no disequality mentions it.
            assert (getRepresentative(rep) == rep);
            return rep;
        }

        equivalenceClassCopy.remove(elem);
        if (equivalenceClassCopy.isEmpty()) {
            // elem was alone in its class: drop every disequality that mentions it.
            mDisequalities.removeDomainElement(elem);
            mDisequalities.removeRangeElement(elem);
            assert (sanityCheck());
            return null;
        }

        assert (rep == elem);
        // elem was the representative of a larger class: move its disequalities to the new representative.
        final E newRep;
        if (newRepChoice == null) {
            newRep = mUnionFind.find(equivalenceClassCopy.iterator().next());
        } else {
            newRep = newRepChoice;
        }
        assert (newRep != null);
        mDisequalities.replaceDomainElement(elem, newRep);
        mDisequalities.replaceRangeElement(elem, newRep);
        assert (sanityCheck());
        assert (getRepresentative(newRep) == newRep) : "the returned element must be a representative, " + newRep
                + " is not its own rep, but " + getRepresentative(newRep) + " is.";
        return newRep;
    }

    /**
     * Records that two known elements are equal.
     *
     * @param elem1
     *            first element
     * @param elem2
     *            second element
     * @return false if the elements already had the same representative, true otherwise (including
     *         the case where the equality contradicts a known disequality and the relation becomes
     *         inconsistent)
     * @throws IllegalStateException
     *             if the relation is already inconsistent
     * @throws IllegalArgumentException
     *             if one of the elements is unknown
     */
    public boolean reportEquality(final E elem1, final E elem2) {
        if (mIsInconsistent) {
            throw new IllegalStateException();
        }
        final E oldRep1 = mUnionFind.find(elem1);
        if (oldRep1 == null) {
            throw new IllegalArgumentException("unknown element " + elem1);
        }
        final E oldRep2 = mUnionFind.find(elem2);
        if (oldRep2 == null) {
            throw new IllegalArgumentException("unknown element " + elem2);
        }
        if (oldRep1 == oldRep2) {
            return false;
        }
        if (getEqualityStatus(elem1, elem2) == EqualityStatus.NOT_EQUAL) {
            reportInconsistency();
            return true;
        }

        mUnionFind.union(elem1, elem2);
        // Exactly one of the old representatives survived the union; redirect the disequalities
        // of the other one to it.
        if (isRepresentative(oldRep1)) {
            assert (mUnionFind.find(elem2) == oldRep1);
            mDisequalities.replaceDomainElement(oldRep2, oldRep1);
            mDisequalities.replaceRangeElement(oldRep2, oldRep1);
        } else {
            assert (mUnionFind.find(elem1) == oldRep2);
            mDisequalities.replaceDomainElement(oldRep1, oldRep2);
            mDisequalities.replaceRangeElement(oldRep1, oldRep2);
        }
        assert (sanityCheck());
        return true;
    }

    /** Flags this relation as inconsistent. */
    private void reportInconsistency() {
        mIsInconsistent = true;
    }

    /**
     * Records that two known elements are not equal.
     *
     * @param elem1
     *            first element
     * @param elem2
     *            second element
     * @return false if the disequality was already known, true otherwise (including the case where
     *         both elements are in the same class and the relation becomes inconsistent)
     * @throws IllegalStateException
     *             if the relation is already inconsistent
     * @throws IllegalArgumentException
     *             if one of the elements is unknown
     */
    public boolean reportDisequality(final E elem1, final E elem2) {
        if (mIsInconsistent) {
            throw new IllegalStateException();
        }
        final E rep1 = mUnionFind.find(elem1);
        if (rep1 == null) {
            throw new IllegalArgumentException("unknown element " + elem1);
        }
        final E rep2 = mUnionFind.find(elem2);
        if (rep2 == null) {
            throw new IllegalArgumentException("unknown element " + elem2);
        }
        if (getEqualityStatus(rep1, rep2) == EqualityStatus.NOT_EQUAL) {
            return false;
        }
        if (rep1 == rep2) {
            reportInconsistency();
            return true;
        }
        mDisequalities.addPair(rep1, rep2);
        assert (sanityCheck());
        return true;
    }

    /**
     * Delegates to {@code UnionFind.findAndConstructEquivalenceClassIfNeeded}.
     *
     * @param elem
     *            element to look up (and add if needed)
     * @return the value returned by the union-find for {@code elem}
     */
    public E getRepresentativeAndAddElementIfNeeded(final E elem) {
        return mUnionFind.findAndConstructEquivalenceClassIfNeeded(elem);
    }

    /**
     * @param elem
     *            element to look up
     * @return the representative the union-find reports for {@code elem}; other methods of this
     *         class treat a null result as "element unknown"
     */
    public E getRepresentative(final E elem) {
        return mUnionFind.find(elem);
    }

    /**
     * @param elem
     *            a known element
     * @return true iff {@code elem} is (by reference) the representative of its own class
     * @throws IllegalArgumentException
     *             if {@code elem} is not contained in {@link #getAllElements()}
     */
    public boolean isRepresentative(final E elem) {
        if (!getAllElements().contains(elem)) {
            throw new IllegalArgumentException("only call this for elements that are present!");
        }
        return getRepresentative(elem) == elem;
    }

    /**
     * @return true iff contradicting constraints have been reported to this relation
     */
    public boolean isInconsistent() {
        return mIsInconsistent;
    }

    /**
     * Determines what is known about the equality of two known elements.
     *
     * @param elem1
     *            first element
     * @param elem2
     *            second element
     * @return {@link EqualityStatus#EQUAL} if both have equal representatives,
     *         {@link EqualityStatus#NOT_EQUAL} if a disequality between their representatives is
     *         stored (in either direction), {@link EqualityStatus#UNKNOWN} otherwise
     * @throws IllegalStateException
     *             if the relation is inconsistent
     * @throws IllegalArgumentException
     *             if one of the elements is unknown
     */
    public EqualityStatus getEqualityStatus(final E elem1, final E elem2) {
        if (mIsInconsistent) {
            throw new IllegalStateException(
                    "Cannot get equality status from inconsistent " + getClass().getSimpleName());
        }
        final E elem1Rep = mUnionFind.find(elem1);
        if (elem1Rep == null) {
            throw new IllegalArgumentException("Unknown element: " + elem1);
        }
        final E elem2Rep = mUnionFind.find(elem2);
        if (elem2Rep == null) {
            throw new IllegalArgumentException("Unknown element: " + elem2);
        }
        if (elem1Rep.equals(elem2Rep)) {
            return EqualityStatus.EQUAL;
        }
        // Disequalities are stored in one direction only, so both orientations must be checked.
        if (mDisequalities.containsPair(elem1Rep, elem2Rep) || mDisequalities.containsPair(elem2Rep, elem1Rep)) {
            return EqualityStatus.NOT_EQUAL;
        }
        return EqualityStatus.UNKNOWN;
    }

    /**
     * Checks the class invariants: the union-find passes its own sanity check, no stored disequality
     * contains null, and every stored disequality relates two representatives.
     *
     * @return true iff all invariants hold
     */
    public boolean sanityCheck() {
        if (!mUnionFind.sanityCheck()) {
            return false;
        }
        // Null check for all pairs first; isRepresentative below must not be called on null.
        for (final Map.Entry<E, E> en : mDisequalities.getSetOfPairs()) {
            if (en.getKey() == null || en.getValue() == null) {
                return false;
            }
        }
        for (final Map.Entry<E, E> en : mDisequalities.getSetOfPairs()) {
            if (!isRepresentative(en.getKey()) || !isRepresentative(en.getValue())) {
                return false;
            }
        }
        return true;
    }

    /**
     * @return all representatives as reported by the underlying union-find
     */
    public Collection<E> getAllRepresentatives() {
        return mUnionFind.getAllRepresentatives();
    }

    /**
     * @return all equivalence classes as reported by the underlying union-find
     */
    public Collection<Set<E>> getAllEquivalenceClasses() {
        return mUnionFind.getAllEquivalenceClasses();
    }

    /**
     * @return "False" for an inconsistent relation, "True" for a relation without any constraint,
     *         otherwise a two-line listing of the equivalences and the non-equivalences
     */
    @Override
    public String toString() {
        // Inconsistency must be checked first: an inconsistent relation may have no recorded
        // constraints at all and must not be rendered as "True".
        if (isInconsistent()) {
            return "False";
        }
        if (isTautological()) {
            return "True";
        }
        final StringBuilder sb = new StringBuilder();
        sb.append("Equivalences: ");
        sb.append(mUnionFind);
        sb.append("\n");
        sb.append("Non-Equivalences: ");
        sb.append(mDisequalities);
        return sb.toString();
    }

    /**
     * @return all elements known to the underlying union-find
     */
    public Set<E> getAllElements() {
        return mUnionFind.getAllElements();
    }

    /**
     * Collects every representative for which a disequality with the given representative is stored,
     * regardless of the direction in which the pair is stored.
     *
     * @param rep
     *            a representative (asserted)
     * @return a fresh set of the representatives known to be unequal to {@code rep}
     */
    public Set<E> getRepresentativesUnequalTo(final E rep) {
        assert (isRepresentative(rep));
        final Set<E> result = new HashSet<E>();
        // pairs stored as (rep, x)
        result.addAll(mDisequalities.getImage(rep));
        // pairs stored as (x, rep)
        for (final E domEl : mDisequalities.getDomain()) {
            if (mDisequalities.getImage(domEl).contains(rep)) {
                result.add(domEl);
            }
        }
        return result;
    }

    /**
     * @param elem
     *            element whose class is requested
     * @return the members of the equivalence class of {@code elem} as reported by the union-find
     */
    public Set<E> getEquivalenceClass(final E elem) {
        return mUnionFind.getEquivalenceClassMembers(elem);
    }

    /**
     * Computes a new relation whose partition is obtained by intersecting the partition blocks of
     * this and {@code other}, and which keeps a disequality between two new representatives only if
     * both operands report {@link EqualityStatus#NOT_EQUAL} for them.
     * <p>
     * NOTE(review): relies on {@link #getEqualityStatus}, so it throws if either operand is
     * inconsistent or does not know one of the new representatives.
     *
     * @param other
     *            second operand
     * @return the resulting relation (a new object)
     */
    public ThreeValuedEquivalenceRelation<E> join(final ThreeValuedEquivalenceRelation<E> other) {
        final UnionFind<E> newPartition =
                UnionFind.intersectPartitionBlocks(mUnionFind, other.mUnionFind).getFirst();
        return new ThreeValuedEquivalenceRelation<E>(newPartition,
                xJoinDisequalities(this, other, newPartition, true));
    }

    /**
     * Computes a new relation whose partition is obtained by uniting the partition blocks of this and
     * {@code other}, and which keeps a disequality between two new representatives if at least one
     * operand reports {@link EqualityStatus#NOT_EQUAL} for them.
     * <p>
     * NOTE(review): the result is always constructed as consistent; contradictions between the
     * united partition and the operands' disequalities are not detected here.
     *
     * @param other
     *            second operand
     * @return the resulting relation (a new object)
     */
    public ThreeValuedEquivalenceRelation<E> meet(final ThreeValuedEquivalenceRelation<E> other) {
        final UnionFind<E> newPartition = UnionFind.unionPartitionBlocks(mUnionFind, other.mUnionFind);
        return new ThreeValuedEquivalenceRelation<E>(newPartition,
                xJoinDisequalities(this, other, newPartition, false));
    }

    /**
     * Exposes the raw result of {@code UnionFind.intersectPartitionBlocks} for the partitions of this
     * and {@code other}. Its first component is the partition used by {@link #join}.
     *
     * @param other
     *            second operand
     * @return the triple returned by the union-find
     */
    public Triple<UnionFind<E>, HashRelation<E, E>, HashRelation<E, E>> joinPartitions(
            final ThreeValuedEquivalenceRelation<E> other) {
        return UnionFind.intersectPartitionBlocks(mUnionFind, other.mUnionFind);
    }

    /**
     * Computes the disequalities between representatives of {@code newElemPartition} from the
     * disequalities of two relations.
     *
     * @param tver1
     *            first relation
     * @param tver2
     *            second relation
     * @param newElemPartition
     *            partition whose representatives are paired up
     * @param conjoin
     *            if true a pair is kept only if both relations say NOT_EQUAL, if false it is kept if
     *            at least one of them does
     * @return a fresh relation holding the selected pairs
     */
    private static <E> HashRelation<E, E> xJoinDisequalities(final ThreeValuedEquivalenceRelation<E> tver1,
            final ThreeValuedEquivalenceRelation<E> tver2, final UnionFind<E> newElemPartition,
            final boolean conjoin) {
        final HashRelation<E, E> result = new HashRelation<>();
        for (final Map.Entry<E, E> pair : CrossProducts
                .binarySelectiveCrossProduct(newElemPartition.getAllRepresentatives(), false, false)) {
            final E left = pair.getKey();
            final E right = pair.getValue();
            final boolean addDisequality;
            if (conjoin) {
                addDisequality = tver1.getEqualityStatus(left, right) == EqualityStatus.NOT_EQUAL
                        && tver2.getEqualityStatus(left, right) == EqualityStatus.NOT_EQUAL;
            } else {
                addDisequality = tver1.getEqualityStatus(left, right) == EqualityStatus.NOT_EQUAL
                        || tver2.getEqualityStatus(left, right) == EqualityStatus.NOT_EQUAL;
            }
            if (addDisequality) {
                result.addPair(left, right);
            }
        }
        return result;
    }

    /**
     * Builds a set of equalities that chains the members of every equivalence class: each member
     * (except the first one in iteration order) is mapped to the member iterated just before it.
     * Singleton classes contribute nothing.
     *
     * @return a fresh map of supporting equalities
     */
    public Map<E, E> getSupportingEqualities() {
        final Map<E, E> result = new HashMap<E, E>();
        for (final Set<E> eqc : mUnionFind.getAllEquivalenceClasses()) {
            E lastElement = null;
            for (final E e : eqc) {
                if (lastElement != null) {
                    result.put(e, lastElement);
                }
                lastElement = e;
            }
        }
        return result;
    }

    /**
     * @return a copy of the stored disequalities (pairs of representatives)
     */
    public HashRelation<E, E> getDisequalities() {
        assert (mDisequalities.getSetOfPairs().stream().noneMatch(pr -> pr.getValue() == null));
        return new HashRelation<>(mDisequalities);
    }

    /**
     * @return true iff this relation is consistent and carries no constraint at all, i.e. there is no
     *         supporting equality and no disequality
     */
    public boolean isTautological() {
        // An inconsistent relation may have no recorded constraints; it is still not tautological.
        return !mIsInconsistent && getSupportingEqualities().isEmpty() && mDisequalities.isEmpty();
    }

    /**
     * Replaces every element by its image under the given function, both in the union-find and in
     * the stored disequalities.
     *
     * @param transformer
     *            function applied to every element
     */
    public void transformElements(final Function<E, E> transformer) {
        mUnionFind.transformElements(transformer);

        final HashRelation<E, E> oldDisequalities = new HashRelation<>(mDisequalities);
        // Remove all old pairs before adding any transformed pair. Interleaving removal and addition
        // would delete a freshly added pair whenever it coincides with an old pair that is visited later.
        for (final Map.Entry<E, E> pair : oldDisequalities) {
            mDisequalities.removePair(pair.getKey(), pair.getValue());
        }
        for (final Map.Entry<E, E> pair : oldDisequalities) {
            mDisequalities.addPair(transformer.apply(pair.getKey()), transformer.apply(pair.getValue()));
        }
        assert (sanityCheck());
    }

    /**
     * Builds a new relation that keeps the complete equivalence class of every known element of
     * {@code elems}, and every disequality for which at least one side's equivalence class has an
     * element in common with {@code elems}. Elements that occur only in such a disequality are added
     * to the new partition as needed.
     *
     * @param elems
     *            elements of interest
     * @return the filtered relation (a new object)
     */
    public ThreeValuedEquivalenceRelation<E> filterAndKeepOnlyConstraintsThatIntersectWith(final Set<E> elems) {
        final UnionFind<E> newUf = createEmptyUnionFindWithSameComparator();
        for (final E elem : elems) {
            // skip elements already covered by a previously added class and elements unknown to us
            if (newUf.find(elem) != null || mUnionFind.find(elem) == null) {
                continue;
            }
            final ImmutableSet<E> elemEqc = mUnionFind.getEquivalenceClassMembers(elem);
            newUf.addEquivalenceClass(elemEqc, mUnionFind.find(elem));
        }

        final HashRelation<E, E> newDisequalities = new HashRelation<E, E>();
        for (final Map.Entry<E, E> deq : mDisequalities.getSetOfPairs()) {
            final boolean lhsIntersects =
                    DataStructureUtils.getSomeCommonElement(getEquivalenceClass(deq.getKey()), elems).isPresent();
            if (lhsIntersects || DataStructureUtils
                    .getSomeCommonElement(getEquivalenceClass(deq.getValue()), elems).isPresent()) {
                newDisequalities.addPair(newUf.findAndConstructEquivalenceClassIfNeeded(deq.getKey()),
                        newUf.findAndConstructEquivalenceClassIfNeeded(deq.getValue()));
            }
        }
        return new ThreeValuedEquivalenceRelation<E>(newUf, newDisequalities);
    }

    /**
     * Builds a new relation that mentions only elements of {@code elems}: every equivalence class is
     * intersected with {@code elems}, and a disequality is kept only if the classes of both sides
     * have an element in common with {@code elems}.
     *
     * @param elems
     *            elements to project to
     * @return the projected relation (a new object)
     */
    public ThreeValuedEquivalenceRelation<E> projectTo(final Set<E> elems) {
        final UnionFind<E> newUf = createEmptyUnionFindWithSameComparator();
        for (final E elem : elems) {
            // skip elements already covered by a previously added class and elements unknown to us
            if (newUf.find(elem) != null || mUnionFind.find(elem) == null) {
                continue;
            }
            final ImmutableSet<E> elemEqc = mUnionFind.getEquivalenceClassMembers(elem);
            newUf.addEquivalenceClass(ImmutableSet.of(DataStructureUtils.intersection(elemEqc, elems)));
        }

        final HashRelation<E, E> newDisequalities = new HashRelation<E, E>();
        for (final Map.Entry<E, E> deq : mDisequalities.getSetOfPairs()) {
            final Optional<E> lhsRep =
                    DataStructureUtils.getSomeCommonElement(elems, getEquivalenceClass(deq.getKey()));
            if (!lhsRep.isPresent()) {
                continue;
            }
            final Optional<E> rhsRep =
                    DataStructureUtils.getSomeCommonElement(elems, getEquivalenceClass(deq.getValue()));
            if (!rhsRep.isPresent()) {
                continue;
            }
            // the surviving members may have a different representative in the new partition
            newDisequalities.addPair(newUf.find(lhsRep.get()), newUf.find(rhsRep.get()));
        }
        return new ThreeValuedEquivalenceRelation<E>(newUf, newDisequalities);
    }

    /**
     * Creates an empty union-find that uses the element comparator of {@link #mUnionFind} if there is one.
     */
    private UnionFind<E> createEmptyUnionFindWithSameComparator() {
        if (mUnionFind.getElementComparator() != null) {
            return new UnionFind<E>(mUnionFind.getElementComparator());
        }
        return new UnionFind<E>();
    }

    /**
     * @param elem
     *            a known element
     * @return true iff the class of {@code elem} has more than one member, or {@code elem} itself
     *         occurs on either side of a stored disequality
     * @throws IllegalArgumentException
     *             if {@code elem} is unknown
     */
    public boolean isConstrained(final E elem) {
        if (mUnionFind.find(elem) == null) {
            throw new IllegalArgumentException();
        }
        if (getEquivalenceClass(elem).size() > 1) {
            return true;
        }
        // From here on elem is alone in its class, so disequalities (which are stored between
        // representatives) would mention elem itself.
        if (mDisequalities.getImage(elem).size() > 0) {
            return true;
        }
        for (final Map.Entry<E, E> en : mDisequalities.getSetOfPairs()) {
            if (en.getValue().equals(elem)) {
                return true;
            }
        }
        return false;
    }

    /**
     * Removes the stored pair between the two given elements in both directions.
     * <p>
     * NOTE(review): the arguments are used as given and are not mapped to their representatives;
     * since disequalities are stored between representatives, callers presumably have to pass
     * representatives for this to have an effect.
     *
     * @param elem1
     *            first element
     * @param elem2
     *            second element
     */
    public void removeDisequality(final E elem1, final E elem2) {
        mDisequalities.removePair(elem1, elem2);
        mDisequalities.removePair(elem2, elem1);
    }
}

