/*
 * Decompiled with CFR 0.152.
 */
package de.uni_freiburg.informatik.ultimate.smtinterpol.interpolate;

import de.uni_freiburg.informatik.ultimate.logic.AnnotatedTerm;
import de.uni_freiburg.informatik.ultimate.logic.Annotation;
import de.uni_freiburg.informatik.ultimate.logic.ApplicationTerm;
import de.uni_freiburg.informatik.ultimate.logic.FunctionSymbol;
import de.uni_freiburg.informatik.ultimate.logic.Rational;
import de.uni_freiburg.informatik.ultimate.logic.Script;
import de.uni_freiburg.informatik.ultimate.logic.Sort;
import de.uni_freiburg.informatik.ultimate.logic.Term;
import de.uni_freiburg.informatik.ultimate.logic.TermTransformer;
import de.uni_freiburg.informatik.ultimate.logic.TermVariable;
import de.uni_freiburg.informatik.ultimate.logic.Theory;
import de.uni_freiburg.informatik.ultimate.smtinterpol.LogProxy;
import de.uni_freiburg.informatik.ultimate.smtinterpol.interpolate.Interpolator;
import de.uni_freiburg.informatik.ultimate.smtinterpol.interpolate.InterpolatorAffineTerm;
import de.uni_freiburg.informatik.ultimate.smtinterpol.interpolate.InterpolatorAtomInfo;
import de.uni_freiburg.informatik.ultimate.smtinterpol.interpolate.LAInterpolator;
import de.uni_freiburg.informatik.ultimate.smtinterpol.interpolate.SymbolChecker;
import de.uni_freiburg.informatik.ultimate.smtinterpol.interpolate.SymbolCollector;
import java.util.Arrays;
import java.util.Collection;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Map;
import java.util.Objects;
import java.util.Set;

public class InterpolantChecker {
    Interpolator mInterpolator;
    Script mCheckingSolver;
    Set<FunctionSymbol> mGlobals;

    public InterpolantChecker(Interpolator interpolator, Script checkingSolver) {
        this.mInterpolator = interpolator;
        this.mCheckingSolver = checkingSolver;
    }

    private Term fixupAndLet(Term interpolant, final HashMap<TermVariable, Term> fixedEQs, final HashMap<TermVariable, Term> auxMap) {
        TermTransformer substitutor = new TermTransformer(){

            @Override
            public void convertApplicationTerm(ApplicationTerm appTerm, Term[] newArgs) {
                TermVariable tv;
                Term replacement;
                FunctionSymbol func = appTerm.getFunction();
                if (fixedEQs != null && func.isIntern() && func.getName().equals("@EQ") && (replacement = (Term)fixedEQs.get(tv = (TermVariable)appTerm.getParameters()[0])) != null) {
                    Term sharedValue = newArgs[1];
                    this.setResult(InterpolantChecker.this.mInterpolator.substitute(replacement, tv, sharedValue));
                    return;
                }
                super.convertApplicationTerm(appTerm, newArgs);
            }

            @Override
            public void convert(Term term) {
                Term replacement;
                if (LAInterpolator.isLATerm(term)) {
                    term = ((AnnotatedTerm)term).getSubterm();
                }
                if (auxMap != null && term instanceof TermVariable && (replacement = (Term)auxMap.get(term)) != null) {
                    this.setResult(replacement);
                    return;
                }
                super.convert(term);
            }
        };
        return substitutor.transform(interpolant);
    }

    private Term fixVars(Term interpolant, HashMap<TermVariable, Term> varToTerm) {
        for (TermVariable tv : interpolant.getFreeVars()) {
            Interpolator interpolator = this.mInterpolator;
            Objects.requireNonNull(interpolator);
            Interpolator.TermSubstitutor ipolator = new Interpolator.TermSubstitutor(interpolator, tv, varToTerm.get(tv));
            interpolant = ipolator.transform(interpolant);
        }
        return interpolant;
    }

    private Term purifyAndFix(Term literal, HashMap<TermVariable, Term> varToTerm, HashMap<TermVariable, Term> varToFreshTerm) {
        for (Map.Entry<TermVariable, Term> e : varToFreshTerm.entrySet()) {
            Term term = varToTerm.get(e.getKey());
            assert (term != null);
            Interpolator interpolator = this.mInterpolator;
            Objects.requireNonNull(interpolator);
            Interpolator.TermSubstitutor ipolator = new Interpolator.TermSubstitutor(interpolator, term, e.getValue());
            literal = ipolator.transform(literal);
        }
        return literal;
    }

    public void checkInductivity(Term[] literals, Term[] ipls) {
        LogProxy logger = this.mInterpolator.getLogger();
        Theory theory = this.mInterpolator.mTheory;
        int old = logger.getLoglevel();
        logger.setLoglevel(2);
        this.mCheckingSolver.push(1);
        HashMap[] auxMaps = new HashMap[ipls.length];
        HashMap<TermVariable, Term> purVarToTerm = this.mInterpolator.mPurifyDefinitions;
        HashSet<TermVariable> activeVars = new HashSet<TermVariable>();
        HashMap<TermVariable, Term> purVarToFreshTerm = new HashMap<TermVariable, Term>();
        for (Map.Entry<TermVariable, Term> e : purVarToTerm.entrySet()) {
            TermVariable tv = e.getKey();
            if (purVarToFreshTerm.containsKey(tv)) continue;
            String name = ".check" + tv.getName();
            this.mCheckingSolver.declareFun(name, new Sort[0], tv.getSort());
            Term constTerm = this.mCheckingSolver.term(name, new Term[0]);
            purVarToFreshTerm.put(tv, constTerm);
        }
        assert (purVarToTerm.size() == purVarToFreshTerm.size());
        for (int i = 0; i < ipls.length; ++i) {
            activeVars.addAll(Arrays.asList(ipls[i].getFreeVars()));
        }
        for (Term lit : literals) {
            Term atom = this.mInterpolator.getAtom(lit);
            InterpolatorAtomInfo atomTermInfo = this.mInterpolator.getAtomTermInfo(atom);
            Interpolator.LitInfo info = this.mInterpolator.getAtomOccurenceInfo(atom);
            TermVariable tv = info.mMixedVar;
            if (tv == null) continue;
            Term auxTerm = null;
            for (int part = 0; part < ipls.length; ++part) {
                Term partAuxTerm;
                String name;
                if (!info.isMixed(part)) continue;
                if (atomTermInfo.isCCEquality()) {
                    if (auxTerm == null) {
                        name = ".check." + tv.getName();
                        this.mCheckingSolver.declareFun(name, new Sort[0], tv.getSort());
                        auxTerm = this.mCheckingSolver.term(name, new Term[0]);
                    }
                    partAuxTerm = auxTerm;
                } else {
                    name = ".check" + part + "." + tv.getName();
                    this.mCheckingSolver.declareFun(name, new Sort[0], tv.getSort());
                    partAuxTerm = this.mCheckingSolver.term(name, new Term[0]);
                }
                if (auxMaps[part] == null) {
                    auxMaps[part] = new HashMap();
                }
                auxMaps[part].put(tv, partAuxTerm);
            }
        }
        for (int part = 0; part < ipls.length + 1; ++part) {
            HashMap[] fixedEQs = new HashMap[ipls.length];
            this.mCheckingSolver.push(1);
            for (Map.Entry<String, Integer> entry : this.mInterpolator.mPartitions.entrySet()) {
                if (entry.getValue() != part) continue;
                this.mCheckingSolver.assertTerm(theory.term(entry.getKey(), new Term[0]));
            }
            for (Term lit : literals) {
                InterpolatorAffineTerm at;
                Term atom = this.mInterpolator.getAtom(lit);
                boolean isNegated = atom == lit;
                InterpolatorAtomInfo atomTermInfo = this.mInterpolator.getAtomTermInfo(atom);
                Interpolator.LitInfo occInfo = this.mInterpolator.mAtomOccurenceInfos.get(atom);
                if (occInfo.contains(part)) {
                    Term purLit = this.purifyAndFix(lit, purVarToTerm, purVarToFreshTerm);
                    this.mCheckingSolver.assertTerm(theory.not(purLit));
                    continue;
                }
                if (occInfo.isBLocal(part) || occInfo.isALocalInSomeChild(part)) continue;
                if (atomTermInfo.isCCEquality()) {
                    ApplicationTerm cceq = atomTermInfo.getEquality();
                    int firstMixedChild = -1;
                    int secondMixedChild = -1;
                    int child = part - 1;
                    while (child >= this.mInterpolator.mStartOfSubtrees[part]) {
                        if (occInfo.isMixed(child)) {
                            if (firstMixedChild < 0) {
                                firstMixedChild = child;
                            } else {
                                assert (secondMixedChild < 0);
                                secondMixedChild = child;
                            }
                        }
                        child = this.mInterpolator.mStartOfSubtrees[child] - 1;
                    }
                    if (firstMixedChild < 0) {
                        assert (occInfo.isMixed(part));
                        String op = isNegated ? "@EQ" : "=";
                        int side = occInfo.getLhsOccur().isALocal(part) ? 0 : 1;
                        Term auxvar = (Term)auxMaps[part].get(occInfo.mMixedVar);
                        this.mCheckingSolver.assertTerm(theory.term(op, auxvar, cceq.getParameters()[side]));
                        continue;
                    }
                    if (!occInfo.isMixed(part)) {
                        Term auxvar = (Term)auxMaps[firstMixedChild].get(occInfo.mMixedVar);
                        if (secondMixedChild < 0) {
                            int side;
                            int n = side = occInfo.getLhsOccur().isALocal(firstMixedChild) ? 1 : 0;
                            if (isNegated) {
                                this.mCheckingSolver.assertTerm(theory.not(theory.term("@EQ", auxvar, cceq.getParameters()[side])));
                                continue;
                            }
                            this.mCheckingSolver.assertTerm(theory.term("=", auxvar, cceq.getParameters()[side]));
                            continue;
                        }
                        if (fixedEQs[secondMixedChild] == null) {
                            fixedEQs[secondMixedChild] = new HashMap();
                        }
                        fixedEQs[secondMixedChild].put(occInfo.mMixedVar, theory.not(theory.term("@EQ", auxvar, occInfo.mMixedVar)));
                        continue;
                    }
                    assert (firstMixedChild >= 0 && secondMixedChild < 0 && occInfo.isMixed(part));
                    continue;
                }
                if (atomTermInfo.isLAEquality() && isNegated) {
                    Term auxvar;
                    at = new InterpolatorAffineTerm();
                    int firstMixedChild = -1;
                    int child = part - 1;
                    while (child >= this.mInterpolator.mStartOfSubtrees[part]) {
                        if (occInfo.isMixed(child)) {
                            at.add(Rational.MONE, occInfo.getAPart(child));
                            if (firstMixedChild < 0) {
                                firstMixedChild = child;
                            } else {
                                Term auxvar2 = (Term)auxMaps[child].get(occInfo.mMixedVar);
                                at.add(Rational.ONE, auxvar2);
                                if (fixedEQs[child] == null) {
                                    fixedEQs[child] = new HashMap();
                                }
                                fixedEQs[child].put(occInfo.mMixedVar, theory.term("=", auxvar2, occInfo.mMixedVar));
                            }
                        }
                        child = this.mInterpolator.mStartOfSubtrees[child] - 1;
                    }
                    if (occInfo.isMixed(part)) {
                        assert (occInfo.isMixed(part));
                        at.add(Rational.ONE, occInfo.getAPart(part));
                        if (firstMixedChild < 0) {
                            auxvar = (Term)auxMaps[part].get(occInfo.mMixedVar);
                            Term aPart = at.toSMTLib(theory, atomTermInfo.isInt());
                            this.mCheckingSolver.assertTerm(theory.term("@EQ", auxvar, aPart));
                            continue;
                        }
                        auxvar = (Term)auxMaps[firstMixedChild].get(occInfo.mMixedVar);
                        at.negate();
                        at.add(Rational.ONE, occInfo.mMixedVar);
                        if (fixedEQs[part] == null) {
                            fixedEQs[part] = new HashMap();
                        }
                        Term replacement = theory.term("@EQ", auxvar, at.toSMTLib(theory, atomTermInfo.isInt()));
                        fixedEQs[part].put(occInfo.mMixedVar, replacement);
                        continue;
                    }
                    assert (firstMixedChild >= 0);
                    auxvar = (Term)auxMaps[firstMixedChild].get(occInfo.mMixedVar);
                    at.add(Rational.ONE, atomTermInfo.getAffineTerm());
                    at.negate();
                    Term bPart = at.toSMTLib(theory, atomTermInfo.isInt());
                    this.mCheckingSolver.assertTerm(theory.not(theory.term("@EQ", auxvar, bPart)));
                    continue;
                }
                at = new InterpolatorAffineTerm();
                int child = part - 1;
                while (child >= this.mInterpolator.mStartOfSubtrees[part]) {
                    if (occInfo.isMixed(child)) {
                        at.add(Rational.MONE, occInfo.getAPart(child));
                        at.add(Rational.ONE, (Term)auxMaps[child].get(occInfo.mMixedVar));
                    }
                    child = this.mInterpolator.mStartOfSubtrees[child] - 1;
                }
                if (occInfo.isMixed(part)) {
                    assert (occInfo.mMixedVar != null);
                    at.add(Rational.ONE, occInfo.getAPart(part));
                    at.add(Rational.MONE, (Term)auxMaps[part].get(occInfo.mMixedVar));
                } else {
                    InterpolatorAffineTerm lv = new InterpolatorAffineTerm(atomTermInfo.getAffineTerm());
                    if (isNegated) {
                        lv.add(atomTermInfo.getEpsilon().negate());
                    }
                    at.add(Rational.ONE, lv);
                }
                if (atomTermInfo.isBoundConstraint()) {
                    if (isNegated) {
                        at.negate();
                    }
                    this.mCheckingSolver.assertTerm(at.toLeq0(theory));
                    continue;
                }
                boolean isInt = at.isInt();
                Sort sort = theory.getSort(isInt ? "Int" : "Real", new Sort[0]);
                Term t = at.toSMTLib(theory, isInt);
                Term zero = Rational.ZERO.toTerm(sort);
                Term eqTerm = theory.term("=", t, zero);
                if (!occInfo.isMixed(part) && isNegated) {
                    eqTerm = theory.term("not", eqTerm);
                }
                this.mCheckingSolver.assertTerm(eqTerm);
            }
            int child = part - 1;
            while (child >= this.mInterpolator.mStartOfSubtrees[part]) {
                Term interpolant = this.fixupAndLet(ipls[child], fixedEQs[child], auxMaps[child]);
                interpolant = this.fixVars(interpolant, purVarToFreshTerm);
                this.mCheckingSolver.assertTerm(interpolant);
                child = this.mInterpolator.mStartOfSubtrees[child] - 1;
            }
            if (part < ipls.length) {
                Term interpolant = this.fixupAndLet(ipls[part], fixedEQs[part], auxMaps[part]);
                interpolant = this.fixVars(interpolant, purVarToFreshTerm);
                this.mCheckingSolver.assertTerm(theory.not(interpolant));
            }
            for (Map.Entry<TermVariable, Term> e : purVarToTerm.entrySet()) {
                ApplicationTerm t = (ApplicationTerm)e.getValue();
                Interpolator.Occurrence occ = this.mInterpolator.mFunctionSymbolOccurrenceInfos.get(t.getFunction());
                if (occ.contains(part) || !activeVars.contains(e.getKey())) {
                    // empty if block
                }
                Term tNew = this.fixVars(t, purVarToFreshTerm);
                this.mCheckingSolver.assertTerm(theory.term("=", tNew, purVarToFreshTerm.get(e.getKey())));
            }
            if (this.mCheckingSolver.checkSat() == Script.LBool.SAT) {
                throw new AssertionError();
            }
            this.mCheckingSolver.pop(1);
        }
        this.mCheckingSolver.pop(1);
        logger.setLoglevel(old);
    }

    /*
     * WARNING - Removed try catching itself - possible behaviour change.
     */
    public void assertUnpartitionedFormulas(Collection<Term> assertions, Set<String> partitions) {
        LogProxy logger = this.mInterpolator.getLogger();
        int old = logger.getLoglevel();
        try {
            logger.setLoglevel(2);
            SymbolCollector collector = new SymbolCollector();
            block3: for (Term asserted : assertions) {
                if (asserted instanceof AnnotatedTerm) {
                    AnnotatedTerm annot = (AnnotatedTerm)asserted;
                    for (Annotation an : annot.getAnnotations()) {
                        if (":named".equals(an.getKey()) && partitions.contains(an.getValue())) continue block3;
                    }
                }
                this.mCheckingSolver.assertTerm(asserted);
                collector.collect(asserted);
            }
            this.mGlobals = collector.getSymbols();
        }
        finally {
            logger.setLoglevel(old);
        }
    }

    /*
     * WARNING - Removed try catching itself - possible behaviour change.
     */
    public boolean checkFinalInterpolants(Map<String, Integer> partitions, Term[] interpolants) {
        boolean error = false;
        int numPartitions = interpolants.length + 1;
        SymbolCollector collector = new SymbolCollector();
        Set[] occs = new Set[numPartitions];
        for (int part = 0; part < numPartitions; ++part) {
            for (Map.Entry<String, Integer> entry : this.mInterpolator.mPartitions.entrySet()) {
                if (entry.getValue() != part) continue;
                collector.collect(this.mCheckingSolver.term(entry.getKey(), new Term[0]));
            }
            occs[part] = collector.getSymbols();
        }
        Map[] subOccurrences = new Map[numPartitions];
        for (int part = 0; part < numPartitions; ++part) {
            subOccurrences[part] = new HashMap();
            for (Object fsym : occs[part]) {
                subOccurrences[part].put(fsym, 1);
            }
            int child = part - 1;
            while (child >= this.mInterpolator.mStartOfSubtrees[part]) {
                Object fsym;
                fsym = subOccurrences[child].entrySet().iterator();
                while (fsym.hasNext()) {
                    Map.Entry entry = (Map.Entry)fsym.next();
                    Integer ival = (Integer)subOccurrences[part].get(entry.getKey());
                    if (ival == null) {
                        ival = 0;
                    }
                    subOccurrences[part].put((FunctionSymbol)entry.getKey(), ival + (Integer)entry.getValue());
                }
                child = this.mInterpolator.mStartOfSubtrees[child] - 1;
            }
        }
        LogProxy logger = this.mInterpolator.getLogger();
        int old = logger.getLoglevel();
        try {
            logger.setLoglevel(2);
            SymbolChecker checker = new SymbolChecker(this.mGlobals, subOccurrences[interpolants.length]);
            for (int part = 0; part < numPartitions; ++part) {
                Script.LBool res;
                this.mCheckingSolver.push(1);
                int child = part - 1;
                while (child >= this.mInterpolator.mStartOfSubtrees[part]) {
                    this.mCheckingSolver.assertTerm(interpolants[child]);
                    child = this.mInterpolator.mStartOfSubtrees[child] - 1;
                }
                for (Map.Entry<String, Integer> entry : this.mInterpolator.mPartitions.entrySet()) {
                    if (entry.getValue() != part) continue;
                    this.mCheckingSolver.assertTerm(this.mCheckingSolver.term(entry.getKey(), new Term[0]));
                }
                if (part != interpolants.length) {
                    if (checker.check(interpolants[part], subOccurrences[part])) {
                        logger.error("Symbol error in Interpolant %d: A-local: %s, B-local: %s.", part, checker.getALocals(), checker.getBLocals());
                        error = true;
                    }
                    this.mCheckingSolver.assertTerm(this.mCheckingSolver.term("not", interpolants[part]));
                }
                if ((res = this.mCheckingSolver.checkSat()) == Script.LBool.SAT) {
                    logger.error("Interpolant %d not inductive", part);
                    error = true;
                } else if (res == Script.LBool.UNKNOWN) {
                    logger.warn("Unable to check validity of interpolant: %s", this.mCheckingSolver.getInfo(":reason-unknown"));
                }
                this.mCheckingSolver.pop(1);
            }
        }
        finally {
            logger.setLoglevel(old);
        }
        return !error;
    }
}

