/*
 * Decompiled with CFR 0.152.
 */
package org.sosy_lab.pjbdd.zdd;

import java.util.Optional;
import java.util.concurrent.ForkJoinTask;
import java.util.function.BiFunction;
import org.sosy_lab.pjbdd.api.DD;
import org.sosy_lab.pjbdd.core.cache.Cache;
import org.sosy_lab.pjbdd.core.node.NodeManager;
import org.sosy_lab.pjbdd.util.threadpool.ParallelismManager;
import org.sosy_lab.pjbdd.zdd.ZDDAlgorithm;
import org.sosy_lab.pjbdd.zdd.ZDDSerialAlgorithm;

public class ZDDConcurrentAlgorithm<V extends DD>
extends ZDDSerialAlgorithm<V> {
    private final ParallelismManager parallelismManager;

    public ZDDConcurrentAlgorithm(Cache<Integer, Cache.CacheData> computedTable, NodeManager<V> nodeManager, ParallelismManager parallelismManager) {
        super(computedTable, nodeManager);
        this.parallelismManager = parallelismManager;
    }

    @Override
    protected V unaryShannon(V zdd, V var, ZDDAlgorithm.ZddOps op) {
        return (V)this.operationCheck(zdd, var, op).orElseGet(() -> this.asyncShannon(zdd, var, op));
    }

    private V asyncShannon(V zdd, V var, ZDDAlgorithm.ZddOps op) {
        V res = this.applyOperation(zdd, var, op);
        this.cacheBinaryItem(zdd, var, op.ordinal(), res);
        return res;
    }

    @Override
    public V union(V f1, V f2) {
        return (V)this.unionCheck(f1, f2).orElseGet(() -> this.asyncUnion(f1, f2));
    }

    private V asyncUnion(V f1, V f2) {
        V res = this.level(f1) < this.level(f2) ? this.makeNode(this.union(this.getLow(f1), f2), this.getHigh(f1), f1.getVariable()) : (this.level(f1) == this.level(f2) ? this.applyOperation(f1, f2, ZDDAlgorithm.ZddOps.UNION) : this.makeNode(this.union(f1, this.getLow(f2)), this.getHigh(f2), f2.getVariable()));
        this.cacheBinaryItem(f1, f2, ZDDAlgorithm.ZddOps.UNION.ordinal(), res);
        return res;
    }

    @Override
    public V difference(V f1, V f2) {
        return (V)this.differenceCheck(f1, f2).orElseGet(() -> this.asyncDifference(f1, f2));
    }

    private V asyncDifference(V f1, V f2) {
        V res = this.level(f1) < this.level(f2) ? this.makeNode(this.difference(this.getLow(f1), f2), this.getHigh(f1), f1.getVariable()) : (this.level(f1) == this.level(f2) ? this.applyOperation(f1, f2, ZDDAlgorithm.ZddOps.DIFF) : this.difference(f1, this.getLow(f2)));
        this.cacheBinaryItem(f1, f2, ZDDAlgorithm.ZddOps.DIFF.ordinal(), res);
        return res;
    }

    @Override
    public V intersection(V f1, V f2) {
        return (V)this.intersectsCheck(f1, f2).orElseGet(() -> this.asyncIntersect(f1, f2));
    }

    private V asyncIntersect(V f1, V f2) {
        V res = this.level(f1) < this.level(f2) ? this.intersection(this.getLow(f1), f2) : (this.level(f1) == this.level(f2) ? this.applyOperation(f1, f2, ZDDAlgorithm.ZddOps.INTSEC) : this.intersection(f1, this.getLow(f2)));
        this.cacheBinaryItem(f1, f2, ZDDAlgorithm.ZddOps.INTSEC.ordinal(), res);
        return res;
    }

    @Override
    public V product(V f1, V f2) {
        return (V)this.productCheck(f1, f2).orElseGet(() -> this.asyncProduct(f1, f2));
    }

    private V asyncProduct(V f1, V f2) {
        Object res;
        if (this.level(f1) != this.level(f2)) {
            res = this.applyOperation(f1, f2, ZDDAlgorithm.ZddOps.MUL);
        } else {
            Optional<V> optLowTask;
            Optional<V> optHighTask;
            Optional<V> optHighTmp;
            ForkJoinTask<DD> lowTask1 = null;
            ForkJoinTask<DD> highTask1 = null;
            ForkJoinTask<DD> highTask = null;
            ForkJoinTask<DD> lowTask = null;
            V lowF1 = this.getLow(f1);
            V lowF2 = this.getLow(f2);
            V highF1 = this.getHigh(f1);
            V highF2 = this.getHigh(f2);
            Optional<V> optLowTmp = this.trySerialComputation(highF1, lowF2, ZDDAlgorithm.ZddOps.MUL);
            if (optLowTmp.isEmpty()) {
                lowTask1 = this.createTask(this::asyncProduct, highF1, lowF2);
            }
            if ((optHighTmp = this.trySerialComputation(highF1, highF2, ZDDAlgorithm.ZddOps.MUL)).isEmpty()) {
                highTask1 = this.createTask(this::asyncProduct, highF1, highF2);
            }
            if ((optHighTask = this.trySerialComputation(lowF1, highF2, ZDDAlgorithm.ZddOps.MUL)).isEmpty()) {
                highTask = this.createTask(this::asyncProduct, lowF1, highF2);
            }
            if ((optLowTask = this.trySerialComputation(lowF1, lowF2, ZDDAlgorithm.ZddOps.MUL)).isEmpty()) {
                lowTask = this.createTask(this::asyncProduct, lowF1, lowF2);
            }
            DD highTmp = optHighTmp.isPresent() ? (DD)optHighTmp.get() : this.extract(highTask1);
            DD lowTmp = optLowTmp.isPresent() ? (DD)optLowTmp.get() : this.extract(lowTask1);
            DD high = optHighTask.isPresent() ? (DD)optHighTask.get() : this.extract(highTask);
            DD low = optLowTask.isPresent() ? (DD)optLowTask.get() : this.extract(lowTask);
            res = this.union(highTmp, lowTmp);
            high = this.union(res, high);
            res = this.makeNode(low, high, f1.getVariable());
        }
        this.cacheBinaryItem(f1, f2, ZDDAlgorithm.ZddOps.MUL.ordinal(), res);
        return res;
    }

    @Override
    public V division(V f1, V f2) {
        return (V)this.divisionCheck(f1, f2).orElseGet(() -> this.asyncDivision(f1, f2));
    }

    private V asyncDivision(V f1, V f2) {
        V res = this.level(f1) < this.level(f2) ? this.applyOperation(f1, f2, ZDDAlgorithm.ZddOps.DIV) : this.sequentialDivisionCase(f1, f2);
        this.cacheBinaryItem(f1, f2, ZDDAlgorithm.ZddOps.DIV.ordinal(), res);
        return res;
    }

    @Override
    public V exclude(V f1, V f2) {
        return (V)this.excludeCheck(f1, f2).orElseGet(() -> this.asyncExclude(f1, f2));
    }

    private V asyncExclude(V f1, V f2) {
        V res;
        if (this.level(f1) > this.level(f2)) {
            res = this.exclude(f1, this.getLow(f2));
        } else if (this.level(f1) < this.level(f2)) {
            res = this.applyOperation(f1, f2, ZDDAlgorithm.ZddOps.EXCLUDE);
        } else {
            Object low;
            Object high;
            if (this.followLow(this.getHigh(f2)).equals(this.base())) {
                high = this.empty();
            } else {
                V lowF2 = this.getLow(f2);
                V highF1 = this.getHigh(f1);
                V highF2 = this.getHigh(f2);
                Optional<V> lowCheck = this.excludeCheck(highF1, highF2);
                Optional<V> highCheck = this.excludeCheck(highF1, lowF2);
                if (lowCheck.isEmpty() && highCheck.isEmpty() && this.parallelismManager.canFork(this.level(f1))) {
                    ForkJoinTask<DD> lowTask = this.createTask(this::asyncExclude, highF1, highF2);
                    ForkJoinTask<DD> highTask = this.createTask(this::asyncExclude, highF1, lowF2);
                    high = this.intersection(this.extract(highTask), this.extract(lowTask));
                } else {
                    high = highCheck.orElseGet(() -> this.asyncExclude(highF1, lowF2));
                    low = lowCheck.orElseGet(() -> this.asyncExclude(highF1, highF2));
                    high = this.intersection(high, low);
                }
            }
            low = this.exclude(this.getLow(f1), this.getLow(f2));
            res = this.makeNode(low, high, f1.getVariable());
        }
        this.cacheBinaryItem(f1, f2, ZDDAlgorithm.ZddOps.EXCLUDE.ordinal(), res);
        return res;
    }

    @Override
    public V restrict(V f1, V f2) {
        return (V)this.restrictCheck(f1, f2).orElseGet(() -> this.asyncRestrict(f1, f2));
    }

    private V asyncRestrict(V f1, V f2) {
        Object res;
        if (this.level(f1) > this.level(f2)) {
            res = this.restrict(f1, this.getLow(f2));
        } else if (this.level(f1) < this.level(f2)) {
            res = this.applyOperation(f1, f2, ZDDAlgorithm.ZddOps.RESTRICT);
        } else {
            DD highRes;
            DD highTmp;
            DD lowTmp;
            Optional<DD> optHighRes;
            Optional<V> optLowRes;
            Optional<V> optHighTmp;
            ForkJoinTask<DD> lowTmpTask = null;
            ForkJoinTask<DD> highTmpTask = null;
            ForkJoinTask<DD> lowResTask = null;
            V lowF1 = this.getLow(f1);
            V lowF2 = this.getLow(f2);
            V highF1 = this.getHigh(f1);
            V highF2 = this.getHigh(f2);
            Optional<V> optLowTmp = this.trySerialComputation(highF1, lowF2, ZDDAlgorithm.ZddOps.RESTRICT);
            if (optLowTmp.isEmpty()) {
                lowTmpTask = this.createTask(this::asyncRestrict, highF1, lowF2);
            }
            if ((optHighTmp = this.trySerialComputation(highF1, highF2, ZDDAlgorithm.ZddOps.RESTRICT)).isEmpty()) {
                highTmpTask = this.createTask(this::asyncRestrict, highF1, highF2);
            }
            if ((optLowRes = this.trySerialComputation(lowF1, lowF2, ZDDAlgorithm.ZddOps.RESTRICT)).isEmpty()) {
                lowResTask = this.createTask(this::asyncRestrict, lowF1, lowF2);
            }
            if ((optHighRes = this.trySerialComputation(lowTmp = optLowTmp.isPresent() ? (DD)optLowTmp.get() : this.extract(lowTmpTask), highTmp = optHighTmp.isPresent() ? (DD)optHighTmp.get() : this.extract(highTmpTask), ZDDAlgorithm.ZddOps.UNION)).isEmpty()) {
                ForkJoinTask<DD> highResTask = this.createTask(this::asyncUnion, highTmp, lowTmp);
                highRes = this.extract(highResTask);
            } else {
                highRes = optHighRes.get();
            }
            DD lowRes = optLowRes.isPresent() ? (DD)optLowRes.get() : this.extract(lowResTask);
            res = this.makeNode(lowRes, highRes, f1.getVariable());
        }
        this.cacheBinaryItem(f1, f2, ZDDAlgorithm.ZddOps.RESTRICT.ordinal(), res);
        return res;
    }

    private V applyOperation(V f1, V f2, ZDDAlgorithm.ZddOps op) {
        V lowF1 = f1;
        V lowF2 = f2;
        V highF1 = f1;
        V highF2 = f2;
        if (this.level(f1) <= this.level(f2)) {
            lowF1 = this.getLow(f1);
            highF1 = this.getHigh(f1);
        }
        if (this.level(f2) <= this.level(f1)) {
            lowF2 = this.getLow(f2);
            highF2 = this.getHigh(f2);
        }
        int var = this.level(f1) <= this.level(f2) ? f1.getVariable() : f2.getVariable();
        Optional<V> lowCheck = this.operationCheck(lowF1, lowF2, op);
        Optional<V> highCheck = this.operationCheck(highF1, highF2, op);
        if (lowCheck.isEmpty() && highCheck.isEmpty() && this.parallelismManager.canFork(this.level(f1))) {
            ForkJoinTask<DD> lowTask = this.createTask((z1, z2) -> this.apply(z1, z2, op), lowF1, lowF2);
            ForkJoinTask<DD> highTask = this.createTask((z1, z2) -> this.apply(z1, z2, op), highF1, highF2);
            return (V)this.makeNode(this.extract(lowTask), this.extract(highTask), var);
        }
        DD low = lowCheck.isPresent() ? (DD)lowCheck.get() : this.apply(lowF1, lowF2, op);
        DD high = highCheck.isPresent() ? (DD)highCheck.get() : this.apply(highF1, highF2, op);
        return (V)this.makeNode(low, high, var);
    }

    private V apply(V f1, V f2, ZDDAlgorithm.ZddOps op) {
        switch (op) {
            case UNION: {
                return this.asyncUnion(f1, f2);
            }
            case DIFF: {
                return this.asyncDifference(f1, f2);
            }
            case INTSEC: {
                return this.asyncIntersect(f1, f2);
            }
            case MUL: {
                return this.asyncProduct(f1, f2);
            }
            case DIV: {
                return this.asyncDivision(f1, f2);
            }
            case EXCLUDE: {
                return this.asyncExclude(f1, f2);
            }
            case RESTRICT: {
                return this.asyncRestrict(f1, f2);
            }
            case SUB_SET1: 
            case CHANGE: 
            case SUB_SET0: {
                return this.asyncShannon(f1, f2, op);
            }
        }
        throw new IllegalArgumentException("Unknown Operator");
    }

    private Optional<V> trySerialComputation(V f1, V f2, ZDDAlgorithm.ZddOps op) {
        return this.operationCheck(f1, f2, op).map(Optional::of).orElseGet(() -> {
            if (!this.parallelismManager.canFork(this.level(f1))) {
                return Optional.of(this.apply(f1, f2, op));
            }
            return Optional.empty();
        });
    }

    private ForkJoinTask<V> createTask(BiFunction<V, V, V> operation, V f1, V f2) {
        this.parallelismManager.taskSupplied();
        return this.parallelismManager.getThreadPool().submit(() -> (DD)operation.apply(f1, f2));
    }

    private V extract(ForkJoinTask<V> task) {
        DD res = (DD)task.join();
        this.parallelismManager.taskDone();
        return (V)res;
    }

    @Override
    public void shutdown() {
        super.shutdown();
        this.parallelismManager.shutdown();
    }
}

