package at.jku.risc.stout.nau.algo;

import at.jku.risc.stout.nau.data.EquationSystem;
import at.jku.risc.stout.nau.data.FreshnessCtx;
import at.jku.risc.stout.nau.data.NodeFactory;
import at.jku.risc.stout.nau.data.atom.Abstraction;
import at.jku.risc.stout.nau.data.atom.Atom;
import at.jku.risc.stout.nau.data.atom.FunctionApplication;
import at.jku.risc.stout.nau.data.atom.HasSort;
import at.jku.risc.stout.nau.data.atom.NominalTerm;
import at.jku.risc.stout.nau.data.atom.Permutation;
import at.jku.risc.stout.nau.data.atom.Sort;
import at.jku.risc.stout.nau.data.atom.Suspension;
import at.jku.risc.stout.nau.data.atom.Variable;
import at.jku.risc.stout.nau.util.ControlledException;
import at.jku.risc.stout.nau.util.DataStructureFactory;
import java.io.IOException;
import java.io.PrintStream;
import java.io.StringWriter;
import java.util.Collection;
import java.util.List;
import java.util.Map;
import java.util.Set;

/* loaded from: input_file:at/jku/risc/stout/nau/algo/EquivarianceSystem.class */
/**
 * Equivariance solver for nominal terms: searches for a permutation {@code pi} of atoms that
 * relates the left- and right-hand sides of every added equation.
 *
 * <p>Usage as suggested by the public API (confirm against callers): {@link #start} sets the atoms
 * the permutation may use and the freshness context, {@link #addEquation} queues term pairs,
 * {@link #compute} runs the algorithm, and {@link #clear} resets the collected state for reuse.
 *
 * <p>The computation has two phases:
 * <ol>
 *   <li>simplification: queued equations are decomposed (rules ALP-E, DEC-E, SUS-E) until only
 *       atom-to-atom constraints remain, collected in {@code atomMap};</li>
 *   <li>permutation: each constraint is either already satisfied by {@code pi} (REM-E) or a
 *       swapping is added to the head of {@code pi} (SOL-E).</li>
 * </ol>
 *
 * <p>Instances are mutable and use no synchronization.
 */
public class EquivarianceSystem {
    /** Atoms the permutation may involve; set by {@link #start} and consumed during phase 2. */
    private Set<? extends Atom> atoms;
    /** Freshness context; atoms fresh for a suspension's variable are skipped by SUS-E. */
    private FreshnessCtx nabla;
    /** Supplies the fresh atoms used to open abstractions (ALP-E). */
    private final NodeFactory factory;
    /** Permutation under construction; this same instance is returned by {@link #compute}. */
    private final Permutation pi;
    /** Equations still to be simplified in phase 1. */
    private final EquationSystem<EquivarianceProblem> problemSet;
    /** Atom constraints gathered in phase 1 (key atom on the left, value atom on the right). */
    private final Map<Atom, Atom> atomMap;
    /** Keys of {@link #atomMap} in first-insertion order; phase 2 walks this list. */
    private final List<Atom> atomList;

    /**
     * Creates a system whose equation store builds plain {@link EquivarianceProblem}s.
     *
     * @param nodeFactory source of fresh atoms for abstraction opening
     */
    public EquivarianceSystem(NodeFactory nodeFactory) {
        this(nodeFactory, new EquationSystem<EquivarianceProblem>() {
            @Override
            public EquivarianceProblem newEquation(NominalTerm left, NominalTerm right) {
                return new EquivarianceProblem(left, right);
            }
        });
    }

    /**
     * Creates a system using the given equation store.
     *
     * @param nodeFactory source of fresh atoms for abstraction opening
     * @param equationSystem store for pending equations; also used to create new equations
     */
    public EquivarianceSystem(NodeFactory nodeFactory, EquationSystem<EquivarianceProblem> equationSystem) {
        this.pi = new Permutation();
        this.atomMap = DataStructureFactory.$.newMap();
        this.atomList = DataStructureFactory.$.newList();
        this.factory = nodeFactory;
        this.problemSet = equationSystem;
    }

    /**
     * Sets the atoms the permutation may use and the freshness context.
     * Does not reset previously collected equations or constraints; use {@link #clear} for that.
     *
     * @param collection atoms to consider; stored as a set created by {@code DataStructureFactory}
     * @param freshnessCtx freshness context consulted by the SUS-E rule
     */
    public void start(Collection<? extends Atom> collection, FreshnessCtx freshnessCtx) {
        this.atoms = DataStructureFactory.$.newSet(collection);
        this.nabla = freshnessCtx;
    }

    /**
     * Resets the permutation, pending equations and collected atom constraints.
     * The atom set and freshness context from {@link #start} are left untouched.
     *
     * @return this instance, for chaining
     */
    public EquivarianceSystem clear() {
        this.pi.clear();
        this.problemSet.clear();
        this.atomMap.clear();
        this.atomList.clear();
        return this;
    }

    /**
     * Queues the equation {@code nominalTerm ~ nominalTerm2}.
     *
     * @param nominalTerm left-hand side
     * @param nominalTerm2 right-hand side
     * @param z when {@code true}, both terms are deep-copied ({@code deepCopy2()}) before queuing,
     *     because phase 1 may mutate queued equations
     */
    public void addEquation(NominalTerm nominalTerm, NominalTerm nominalTerm2, boolean z) {
        if (z) {
            this.problemSet.add(this.problemSet.newEquation(nominalTerm.deepCopy2(), nominalTerm2.deepCopy2()));
        } else {
            this.problemSet.add(this.problemSet.newEquation(nominalTerm, nominalTerm2));
        }
    }

    /**
     * Runs {@link #compute(DebugLevel, PrintStream)} silently.
     *
     * @return the permutation, or {@code null} if none exists or a {@link ControlledException}
     *     occurred (its stack trace is printed)
     */
    public Permutation compute() {
        try {
            return compute(DebugLevel.SILENT, null);
        } catch (ControlledException e) {
            e.printStackTrace();
            return null;
        }
    }

    /**
     * Runs both phases of the equivariance algorithm on the queued equations.
     *
     * <p>On a {@code null} result, leftover equations and constraints stay in this object; call
     * {@link #clear} before reuse. Phase 2 removes atoms from the set given to {@link #start}.
     *
     * @param debugLevel {@link DebugLevel#PROGRESS} prints every step to {@code printStream}
     * @param printStream debug sink; must be non-null when {@code debugLevel} is PROGRESS
     * @return the internal permutation {@code pi} (mutated by later {@link #clear} calls), or
     *     {@code null} on HEAD-CLASH, ATOM-CLASH or PERM-CLASH
     * @throws ControlledException propagated from collaborators (origin not visible in this class)
     */
    public Permutation compute(DebugLevel debugLevel, PrintStream printStream) throws ControlledException {
        debug(null, debugLevel, printStream, "Starting phase 1 (simplification) of equivariance algorithm", "");
        debug(debugLevel, printStream);
        if (!simplify(debugLevel, printStream)) {
            return null;
        }
        debug(null, debugLevel, printStream, "Starting phase 2 (permutation) of equivariance algorithm", "");
        if (!buildPermutation(debugLevel, printStream)) {
            return null;
        }
        return this.pi;
    }

    /** Phase 1: simplifies queued equations until none remain; {@code false} on any clash. */
    private boolean simplify(DebugLevel debugLevel, PrintStream printStream) throws ControlledException {
        while (!this.problemSet.isEmpty()) {
            EquivarianceProblem problem = this.problemSet.popLast();
            if (!simplifyStep(problem, debugLevel, printStream)) {
                return false;
            }
            debug(debugLevel, printStream);
        }
        return true;
    }

    /** Applies the single phase-1 rule matching {@code problem}; {@code false} on a clash. */
    private boolean simplifyStep(EquivarianceProblem problem, DebugLevel debugLevel, PrintStream printStream)
            throws ControlledException {
        NominalTerm left = problem.getLeft();
        NominalTerm right = problem.getRight();
        HasSort<? extends Sort> leftHead = left.getHead();
        HasSort<? extends Sort> rightHead = right.getHead();
        // Both sides must be the same kind of term with heads of the identical sort object.
        if (left.getClass() != right.getClass() || leftHead.getSort2() != rightHead.getSort2()) {
            return reportHeadClash(problem, debugLevel, printStream);
        }
        if (left instanceof Abstraction) {
            openAbstractions(problem, (Abstraction) left, (Abstraction) right);
            debug("ALP-E", debugLevel, printStream);
            return true;
        }
        if (left instanceof FunctionApplication) {
            return decomposeApplications(problem, (FunctionApplication) left, (FunctionApplication) right,
                    debugLevel, printStream);
        }
        if (left instanceof Suspension) {
            return decomposeSuspensions(problem, (Suspension) left, (Suspension) right, debugLevel, printStream);
        }
        // Any remaining term kind is treated as an atom.
        return putAtom(debugLevel, printStream, (Atom) left, (Atom) right);
    }

    /** ALP-E: swaps both bound atoms with one shared fresh atom and re-queues the same equation. */
    private void openAbstractions(EquivarianceProblem problem, Abstraction left, Abstraction right)
            throws ControlledException {
        NominalTerm leftBody = left.getSubTerm();
        NominalTerm rightBody = right.getSubTerm();
        Atom fresh = this.factory.obtainFreshAtom(left.getBoundAtom().getSort2());
        problem.setLeft(leftBody.swap(left.getBoundAtom(), fresh));
        problem.setRight(rightBody.swap(right.getBoundAtom(), fresh));
        this.problemSet.add(problem);
    }

    /** DEC-E: queues argument-wise equations for two applications of the same symbol. */
    private boolean decomposeApplications(EquivarianceProblem problem, FunctionApplication left,
            FunctionApplication right, DebugLevel debugLevel, PrintStream printStream) throws ControlledException {
        if (left.getFncSymb() != right.getFncSymb()) {
            return reportHeadClash(problem, debugLevel, printStream);
        }
        NominalTerm[] leftArgs = left.getArgs();
        NominalTerm[] rightArgs = right.getArgs();
        // Without this check, a shorter right side would be indexed out of bounds and a longer
        // one would have its trailing arguments silently ignored.
        if (leftArgs.length != rightArgs.length) {
            return reportHeadClash(problem, debugLevel, printStream);
        }
        // Queued in reverse, presumably so popLast() handles the first argument pair first.
        for (int i = leftArgs.length - 1; i >= 0; i--) {
            addEquation(leftArgs[i], rightArgs[i], false);
        }
        debug("DEC-E", debugLevel, printStream);
        return true;
    }

    /**
     * SUS-E: for two suspensions of the same variable, records the constraint
     * {@code leftPerm(a) -> rightPerm(a)} for every atom {@code a} that {@code nabla} does not
     * declare fresh for that variable.
     */
    private boolean decomposeSuspensions(EquivarianceProblem problem, Suspension left, Suspension right,
            DebugLevel debugLevel, PrintStream printStream) throws ControlledException {
        Variable var = left.getVar();
        if (var != right.getVar()) {
            return reportHeadClash(problem, debugLevel, printStream);
        }
        Permutation leftPerm = left.getPerm();
        Permutation rightPerm = right.getPerm();
        debug("SUS-E", debugLevel, printStream);
        for (Atom a : this.atoms) {
            if (this.nabla.contains(a, var)) {
                continue;
            }
            if (!putAtom(debugLevel, printStream, leftPerm.permute(a), rightPerm.permute(a))) {
                return false;
            }
        }
        return true;
    }

    /** Prints the HEAD-CLASH diagnostic for {@code problem} and returns {@code false}. */
    private boolean reportHeadClash(EquivarianceProblem problem, DebugLevel debugLevel, PrintStream printStream) {
        debug("HEAD-CLASH", debugLevel, printStream, "No rule applicable to " + problem);
        return false;
    }

    /**
     * Phase 2: walks the constraints in insertion order, consuming each one from {@code atomMap}.
     * A constraint already satisfied by {@code pi} removes its atom from {@code atoms} (REM-E).
     * Otherwise a swapping is added to the head of {@code pi} (SOL-E), provided the current image
     * is still available and the target can be taken from {@code atoms}.
     *
     * @return {@code false} on PERM-CLASH
     */
    private boolean buildPermutation(DebugLevel debugLevel, PrintStream printStream) throws ControlledException {
        for (Atom source : this.atomList) {
            Atom target = this.atomMap.remove(source);
            Atom image = this.pi.permute(source);
            if (image == target) {
                this.atoms.remove(image);
                debug("REM-E", debugLevel, printStream);
            } else {
                if (!this.atoms.contains(image) || !this.atoms.remove(target)) {
                    debug("PERM-CLASH", debugLevel, printStream,
                            "Permutation application not possible to " + image + EquivarianceProblem.eqSeparator + target);
                    return false;
                }
                this.pi.addSwappingHead(image, target);
                debug("SOL-E", debugLevel, printStream);
            }
            debug(debugLevel, printStream);
        }
        return true;
    }

    /**
     * Records the constraint {@code atom -> atom2}.
     *
     * <p>A new key is stored and appended to the insertion-order list. Re-adding the identical
     * constraint is accepted. A different value for an existing key is an ATOM-CLASH; the existing
     * constraint is kept.
     *
     * @param debugLevel {@link DebugLevel#PROGRESS} prints the clash diagnostic
     * @param printStream debug sink; must be non-null when {@code debugLevel} is PROGRESS
     * @param atom the constrained atom (left side)
     * @param atom2 the atom it must map to (right side)
     * @return {@code false} if {@code atom} is already constrained to a different atom
     */
    public boolean putAtom(DebugLevel debugLevel, PrintStream printStream, Atom atom, Atom atom2) {
        Atom existing = this.atomMap.get(atom);
        if (existing == null) {
            this.atomMap.put(atom, atom2);
            this.atomList.add(atom);
            return true;
        }
        if (existing == atom2) {
            return true;
        }
        debug("ATOM-CLASH", debugLevel, printStream,
                "Swapping conflict: " + atom + EquivarianceProblem.eqSeparator + existing + "; "
                        + atom + EquivarianceProblem.eqSeparator + atom2);
        return false;
    }

    /** At PROGRESS level, prints the current atoms, equations (including constraints) and pi. */
    private void debug(DebugLevel debugLevel, PrintStream printStream) {
        if (debugLevel == DebugLevel.PROGRESS) {
            printStream.println("    Atoms: " + this.atoms);
            printStream.println("Equations: " + problemToString());
            printStream.println("       Pi: " + this.pi.toString(true));
            printStream.println();
        }
    }

    /**
     * At PROGRESS level, prints an optional rule header ({@code "<rule> ==> "}) followed by each
     * message on its own line.
     */
    private void debug(String str, DebugLevel debugLevel, PrintStream printStream, String... strArr) {
        if (debugLevel != DebugLevel.PROGRESS) {
            return;
        }
        if (str != null) {
            printStream.println(str + " ==> ");
        }
        for (String line : strArr) {
            printStream.println(line);
        }
    }

    /**
     * Renders the pending equations followed by the collected atom constraints. Constraints are
     * numbered after the equations. Returns the exception text if writing fails.
     */
    private String problemToString() {
        StringWriter out = new StringWriter();
        try {
            this.problemSet.printString(out);
            int index = this.problemSet.size();
            for (Map.Entry<Atom, Atom> constraint : this.atomMap.entrySet()) {
                EquationSystem.writePrefix(out, index);
                EquivarianceProblem.printString(out, constraint.getKey(), constraint.getValue());
                index++;
            }
            return out.toString();
        } catch (IOException e) {
            return e.toString();
        }
    }

    /** Returns the string form of the current permutation. */
    @Override
    public String toString() {
        return this.pi.toString();
    }
}
