/*
 * Decompiled with CFR 0.152.
 */
package org.cf.simplify;

import com.google.common.primitives.Ints;
import gnu.trove.list.array.TIntArrayList;
import gnu.trove.map.TIntObjectMap;
import java.lang.reflect.Field;
import java.lang.reflect.Method;
import java.util.ArrayDeque;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.LinkedHashSet;
import java.util.LinkedList;
import java.util.List;
import java.util.ListIterator;
import java.util.Map;
import java.util.Set;
import java.util.stream.Collectors;
import javax.annotation.Nonnull;
import javax.annotation.Nullable;
import org.cf.smalivm.SideEffect;
import org.cf.smalivm.VirtualMachine;
import org.cf.smalivm.context.ExecutionContext;
import org.cf.smalivm.context.ExecutionGraph;
import org.cf.smalivm.context.ExecutionNode;
import org.cf.smalivm.context.MethodState;
import org.cf.smalivm.opcode.FillArrayDataPayloadOp;
import org.cf.smalivm.opcode.InvokeOp;
import org.cf.smalivm.opcode.NewInstanceOp;
import org.cf.smalivm.opcode.NopOp;
import org.cf.smalivm.opcode.Op;
import org.cf.smalivm.opcode.OpCreator;
import org.cf.smalivm.opcode.ReturnOp;
import org.cf.smalivm.opcode.ReturnVoidOp;
import org.cf.smalivm.opcode.SwitchPayloadOp;
import org.cf.smalivm.type.VirtualMethod;
import org.jf.dexlib2.builder.BuilderInstruction;
import org.jf.dexlib2.builder.BuilderTryBlock;
import org.jf.dexlib2.builder.ItemWithLocation;
import org.jf.dexlib2.builder.Label;
import org.jf.dexlib2.builder.MethodLocation;
import org.jf.dexlib2.builder.MutableMethodImplementation;
import org.jf.dexlib2.writer.builder.DexBuilder;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * An {@link ExecutionGraph} that can be edited in place.
 *
 * <p>Instructions are added to, removed from, or replaced in the method's {@link MutableMethodImplementation}. The
 * graph's node piles, address map, ops and execution contexts are then patched so that the graph keeps describing the
 * modified method without re-executing it from scratch.
 */
public class ExecutionGraphManipulator extends ExecutionGraph {
    private static final Logger log = LoggerFactory.getLogger(ExecutionGraphManipulator.class.getSimpleName());

    private final DexBuilder dexBuilder;
    private final MutableMethodImplementation implementation;
    private final VirtualMethod method;
    private final VirtualMachine vm;
    private final OpCreator opCreator;
    // Locations whose ops must be re-created from their (possibly changed) instructions on the next rebuild.
    private final Set<MethodLocation> recreateLocations;
    // Locations whose non-template nodes must be re-executed on the next rebuild; duplicates are tolerated.
    private final List<MethodLocation> reexecuteLocations;
    // When false, rebuildGraph() patches structure only and defers op recreation / re-execution.
    private boolean recreateOrExecuteAgain;

    public ExecutionGraphManipulator(ExecutionGraph graph, VirtualMethod method, VirtualMachine vm,
                                     DexBuilder dexBuilder) {
        super(graph, true);
        this.dexBuilder = dexBuilder;
        this.method = method;
        this.implementation = method.getImplementation();
        this.vm = vm;
        this.opCreator = ExecutionGraphManipulator.getOpCreator(vm, this.addressToLocation);
        this.recreateLocations = new HashSet<>();
        this.reexecuteLocations = new LinkedList<>();
        this.recreateOrExecuteAgain = true;
    }

    /**
     * Inserts {@code instruction} in front of the instruction currently at {@code location}, then rebuilds the graph.
     *
     * <p>The displaced location is merged into the new one using dexlib's private {@code MethodLocation.mergeInto}.
     * Presumably this moves its labels and debug items, so code that targeted {@code location} reaches the inserted
     * instruction first.
     *
     * @param location    location of the instruction to insert before
     * @param instruction instruction to insert
     */
    public void addInstruction(MethodLocation location, BuilderInstruction instruction) {
        int index = location.getIndex();
        implementation.addInstruction(index, instruction);
        MethodLocation newLocation = instruction.getLocation();
        // The instruction that used to be at 'index' has been shifted down by one.
        MethodLocation oldLocation = implementation.getInstructions().get(index + 1).getLocation();
        mergeLocationInto(oldLocation, newLocation);
        rebuildGraph();
    }

    /**
     * Inserts {@code newInstruction} in front of the instruction currently at {@code address}.
     *
     * @param address        code address of the instruction to insert before
     * @param newInstruction instruction to insert
     */
    public void addInstruction(int address, BuilderInstruction newInstruction) {
        addInstruction(getLocation(address), newInstruction);
    }

    /**
     * Returns the registers that may be reused at {@code address}.
     *
     * <p>All nodes reachable from the children of {@code address} are walked breadth-first. A register is unavailable
     * if the first of those nodes to touch it reads it. A register whose first touch is an assignment stays
     * available. Move-result nodes are deliberately not used to classify registers.
     *
     * @param address code address to query
     * @return available register numbers, in ascending order
     */
    public int[] getAvailableRegisters(int address) {
        int registerCount = getRegisterCount(address);
        int[] registers = new int[registerCount];
        for (int register = 0; register < registerCount; ++register) {
            registers[register] = register;
        }

        ArrayDeque<ExecutionNode> queue = new ArrayDeque<>(getChildren(address));
        if (queue.isEmpty()) {
            // Nothing executes after this address, so every register is free.
            assert getTemplateNode(address).getOp() instanceof ReturnOp
                    || getTemplateNode(address).getOp() instanceof ReturnVoidOp;
            return registers;
        }

        Set<Integer> registersRead = new HashSet<>();
        Set<Integer> registersAssigned = new HashSet<>();
        ExecutionNode node;
        while ((node = queue.poll()) != null) {
            if (!node.getOp().getName().startsWith("move-result")) {
                MethodState mState = node.getContext().getMethodState();
                for (int register : registers) {
                    if (registersRead.contains(register) || registersAssigned.contains(register)) {
                        // Already classified by an earlier node.
                        continue;
                    }
                    if (mState.wasRegisterRead(register)) {
                        registersRead.add(register);
                    } else if (mState.wasRegisterAssigned(register)) {
                        registersAssigned.add(register);
                    }
                }
            }
            queue.addAll(node.getChildren());
        }
        return Arrays.stream(registers).filter(r -> !registersRead.contains(r)).toArray();
    }

    /**
     * Returns the children of every node in the pile at {@code address}.
     *
     * @param address code address whose node pile is inspected
     * @return a new, mutable list of child nodes
     */
    public List<ExecutionNode> getChildren(int address) {
        List<ExecutionNode> children = new ArrayList<>();
        for (ExecutionNode node : getNodePile(address)) {
            children.addAll(node.getChildren());
        }
        return children;
    }

    public DexBuilder getDexBuilder() {
        return dexBuilder;
    }

    /**
     * Returns the instruction of the template node's op at {@code address}.
     */
    @Nullable
    public BuilderInstruction getInstruction(int address) {
        return getTemplateNode(address).getOp().getInstruction();
    }

    /**
     * Returns the distinct code addresses of the parents of the nodes at {@code address}. Nodes without a parent are
     * ignored. The order of the result is unspecified.
     */
    public int[] getParentAddresses(int address) {
        Set<Integer> parentAddresses = new HashSet<>();
        for (ExecutionNode node : getNodePile(address)) {
            ExecutionNode parent = node.getParent();
            if (parent != null) {
                parentAddresses.add(parent.getAddress());
            }
        }
        return Ints.toArray(parentAddresses);
    }

    public List<BuilderTryBlock> getTryBlocks() {
        return implementation.getTryBlocks();
    }

    @Override
    public VirtualMachine getVM() {
        return vm;
    }

    /**
     * Removes the instruction at {@code location}, drops try blocks that became empty, and rebuilds the graph.
     */
    public void removeInstruction(MethodLocation location) {
        implementation.removeInstruction(location.getIndex());
        removeEmptyTryCatchBlocks();
        rebuildGraph();
    }

    /**
     * Removes the instruction at {@code address}.
     */
    public void removeInstruction(int address) {
        removeInstruction(getLocation(address));
    }

    /**
     * Removes the instructions at all {@code addresses}.
     *
     * <p>Removal proceeds from the highest address down, so removing one instruction never shifts the address of one
     * still pending. The caller's list is neither reordered nor required to be mutable.
     *
     * @param addresses code addresses to remove
     */
    public void removeInstructions(List<Integer> addresses) {
        List<Integer> descending = new ArrayList<>(addresses);
        descending.sort(Collections.reverseOrder());
        log.debug("Removing instructions: {}", descending);
        descending.forEach(this::removeInstruction);
    }

    /**
     * Replaces the instruction at {@code insertAddress} with {@code instruction}.
     */
    public void replaceInstruction(int insertAddress, BuilderInstruction instruction) {
        replaceInstruction(insertAddress, Collections.singletonList(instruction));
    }

    /**
     * Replaces the instruction at {@code insertAddress} with {@code instructions}, in order.
     *
     * <p>Op recreation and re-execution are deferred until the final removal, so they run once for the whole
     * replacement. The deferral flag is restored even if an insertion fails.
     *
     * @param insertAddress code address of the instruction being replaced
     * @param instructions  replacement instructions; if empty, the instruction is simply removed
     */
    public void replaceInstruction(int insertAddress, List<BuilderInstruction> instructions) {
        MethodLocation replacedLocation;
        recreateOrExecuteAgain = false;
        try {
            int address = insertAddress;
            for (BuilderInstruction instruction : instructions) {
                addInstruction(address, instruction);
                address += instruction.getCodeUnits();
            }
            // The original instruction now sits right after the inserted ones.
            replacedLocation = getLocation(address);
        } finally {
            recreateOrExecuteAgain = true;
        }
        removeInstruction(replacedLocation);
    }

    /**
     * Returns the location currently mapped to {@code address}, or {@code null} if none.
     */
    public MethodLocation getLocation(int address) {
        return (MethodLocation) addressToLocation.get(address);
    }

    private int getRegisterCount(int address) {
        return getNodePile(address).get(0).getContext().getMethodState().getRegisterCount();
    }

    /**
     * Invokes dexlib's private {@code from.mergeInto(into)}. Logs and continues if the method is unavailable.
     */
    private static void mergeLocationInto(MethodLocation from, MethodLocation into) {
        try {
            Method mergeInto = MethodLocation.class.getDeclaredMethod("mergeInto", MethodLocation.class);
            mergeInto.setAccessible(true);
            mergeInto.invoke(from, into);
        } catch (Exception e) {
            log.error("Error invoking MethodLocation.mergeInto(). Wrong dexlib version?", e);
        }
    }

    /**
     * Creates a node pile for a newly inserted location.
     *
     * <p>The new pile mirrors the pile of the location that was shifted down to make room: one new node per shifted
     * node. Each new node is spliced between the shifted node and its former parent. Auto-added padding (a nop in
     * front of a payload) only gets a template node.
     *
     * @throws IllegalStateException if no registered location sits at the index just after {@code newLocation}
     */
    private void addToNodePile(MethodLocation newLocation) {
        int shiftedIndex = newLocation.getIndex() + 1;
        MethodLocation shiftedLocation = null;
        for (MethodLocation location : locationToNodePile.keySet()) {
            if (location.getIndex() == shiftedIndex) {
                shiftedLocation = location;
                break;
            }
        }
        if (shiftedLocation == null) {
            throw new IllegalStateException(
                    "No node pile at index " + shiftedIndex + " to shift for new location " + newLocation);
        }

        List<ExecutionNode> shiftedNodePile = locationToNodePile.get(shiftedLocation);
        List<ExecutionNode> newNodePile = new ArrayList<>();
        locationToNodePile.put(newLocation, newNodePile);

        Op shiftedOp = shiftedNodePile.get(0).getOp();
        Op op = opCreator.create(newLocation);
        recreateLocations.add(newLocation);
        reexecuteLocations.add(newLocation);
        boolean autoAddedPadding = op instanceof NopOp
                && (shiftedOp instanceof FillArrayDataPayloadOp || shiftedOp instanceof SwitchPayloadOp);

        for (int i = 0; i < shiftedNodePile.size(); ++i) {
            ExecutionNode newNode = new ExecutionNode(op);
            newNodePile.add(newNode);
            if (autoAddedPadding) {
                break;
            }
            if (i == 0) {
                // Index 0 is the template node; it has no execution links to splice.
                continue;
            }

            ExecutionNode shiftedNode = shiftedNodePile.get(i);
            ExecutionNode shiftedParent = shiftedNode.getParent();
            if (shiftedParent != null) {
                shiftedParent.removeChild(shiftedNode);
                reparentNode(newNode, shiftedParent);
                recreateLocations.add(shiftedParent.getOp().getLocation());
            } else {
                // A parentless execution node is the root, which only exists at address 0.
                assert 0 == newLocation.getCodeAddress();
                newNode.setContext(vm.spawnRootContext(method));
            }
            reparentNode(shiftedNode, newNode);
        }
    }

    /**
     * Attaches {@code child} to {@code parent}, giving it a fresh context spawned from the parent's.
     *
     * <p>The grandchildren's contexts are re-pointed (shallowly) at that new context, and the child's location is
     * queued for re-execution.
     */
    private void reparentNode(@Nonnull ExecutionNode child, @Nonnull ExecutionNode parent) {
        ExecutionContext newContext = parent.getContext().spawnChild();
        child.setContext(newContext);
        child.setParent(parent);
        reexecuteLocations.add(child.getOp().getLocation());
        for (ExecutionNode grandChild : child.getChildren()) {
            grandChild.getContext().setShallowParent(newContext);
        }
    }

    /**
     * Re-creates the ops of queued locations and re-executes the queued locations in code-address order, then clears
     * both queues.
     *
     * <p>Does nothing while {@code recreateOrExecuteAgain} is false.
     */
    private void recreateAndExecute() {
        if (!recreateOrExecuteAgain) {
            return;
        }
        // Locations of removed instructions have no instruction any more and need no work.
        recreateLocations.removeIf(location -> location.getInstruction() == null);
        reexecuteLocations.removeIf(location -> location.getInstruction() == null);

        for (MethodLocation location : recreateLocations) {
            Op op = opCreator.create(location);
            List<ExecutionNode> pile = locationToNodePile.get(location);
            if (op instanceof NewInstanceOp || op instanceof InvokeOp) {
                copySideEffectLevel(pile.get(0).getOp(), op);
            }
            for (ExecutionNode node : pile) {
                node.setOp(op);
            }
        }

        reexecuteLocations.sort((a, b) -> Integer.compare(a.getCodeAddress(), b.getCodeAddress()));
        // De-duplicate while keeping the address order.
        Set<MethodLocation> reexecute = new LinkedHashSet<>(reexecuteLocations);
        for (MethodLocation location : reexecute) {
            List<ExecutionNode> pile = locationToNodePile.get(location);
            // Index 0 is the template node, which is never executed.
            for (int i = 1; i < pile.size(); ++i) {
                pile.get(i).execute();
            }
        }
        recreateLocations.clear();
        reexecuteLocations.clear();
    }

    /**
     * Copies the side-effect level of {@code source} into the private {@code sideEffectLevel} field of the freshly
     * created {@code target}.
     *
     * <p>{@code target} must be a {@link NewInstanceOp} or {@link InvokeOp}. Failures are logged and ignored.
     */
    private static void copySideEffectLevel(Op source, Op target) {
        try {
            SideEffect.Level originalLevel = source.getSideEffectLevel();
            Class<?> opClass = target instanceof NewInstanceOp ? NewInstanceOp.class : InvokeOp.class;
            Field field = opClass.getDeclaredField("sideEffectLevel");
            field.setAccessible(true);
            field.set(target, originalLevel);
        } catch (Exception e) {
            log.error("Couldn't preserve side effect level for {}", target, e);
        }
    }

    /**
     * Brings the node piles and the address map in line with the implementation's current instructions, then
     * recreates and re-executes whatever was queued.
     */
    private void rebuildGraph() {
        // Snapshot: addToNodePile() adds keys to locationToNodePile while we work.
        Set<MethodLocation> staleLocations = new HashSet<>(locationToNodePile.keySet());
        Set<MethodLocation> implementationLocations = implementation.getInstructions().stream()
                .map(BuilderInstruction::getLocation)
                .collect(Collectors.toSet());

        List<MethodLocation> addedLocations = new ArrayList<>(implementationLocations);
        addedLocations.removeAll(staleLocations);
        // Highest index first: addToNodePile() inherits the pile of the location right after the new one. That
        // location must already be registered when several adjacent locations (e.g. padding) appear at once.
        addedLocations.sort((a, b) -> Integer.compare(b.getIndex(), a.getIndex()));
        addedLocations.forEach(this::addToNodePile);

        Set<MethodLocation> removedLocations = new HashSet<>(staleLocations);
        removedLocations.removeAll(implementationLocations);
        removedLocations.forEach(this::removeFromNodePile);

        TIntObjectMap<MethodLocation> newAddressToLocation =
                ExecutionGraphManipulator.buildAddressToLocation(implementation);
        addressToLocation.clear();
        addressToLocation.putAll(newAddressToLocation);
        recreateAndExecute();
    }

    /**
     * Removes try blocks whose start or end label cannot be resolved, or whose start and end share a code address.
     */
    private void removeEmptyTryCatchBlocks() {
        List<BuilderTryBlock> currentTryBlocks = implementation.getTryBlocks();
        TIntArrayList removeIndexes = new TIntArrayList();
        for (int index = 0; index < currentTryBlocks.size(); ++index) {
            BuilderTryBlock tryBlock = currentTryBlocks.get(index);
            MethodLocation start = getLocation(tryBlock.start);
            MethodLocation end = getLocation(tryBlock.end);
            if (start == null || end == null || start.getCodeAddress() == end.getCodeAddress()) {
                removeIndexes.add(index);
            }
        }
        if (removeIndexes.isEmpty()) {
            return;
        }

        List<BuilderTryBlock> tryBlocks = getMutableTryBlocks();
        if (tryBlocks == null) {
            return;
        }
        // Indexes were collected in ascending order; remove from the back so the rest stay valid.
        for (int i = removeIndexes.size() - 1; i >= 0; --i) {
            tryBlocks.remove(removeIndexes.get(i));
        }
    }

    /**
     * Reads the implementation's private {@code tryBlocks} list by reflection so that entries can be removed.
     *
     * <p>This is needed on the assumption that {@code getTryBlocks()} returns a read-only view; the assumption is not
     * visible here.
     *
     * @return the backing list, or {@code null} (after logging) if it cannot be accessed
     */
    @Nullable
    @SuppressWarnings("unchecked")
    private List<BuilderTryBlock> getMutableTryBlocks() {
        try {
            Field field = implementation.getClass().getDeclaredField("tryBlocks");
            field.setAccessible(true);
            return (List<BuilderTryBlock>) field.get(implementation);
        } catch (IllegalAccessException | IllegalArgumentException | NoSuchFieldException | SecurityException e) {
            log.error("Couldn't access MutableMethodImplementation.tryBlocks. Wrong dexlib version?", e);
            return null;
        }
    }

    /**
     * Reads the private {@code location} field of {@code label} (declared on {@link ItemWithLocation}).
     *
     * @return the label's location, or {@code null} if unplaced or if reflection fails
     */
    @Nullable
    private MethodLocation getLocation(Label label) {
        try {
            Field f = ItemWithLocation.class.getDeclaredField("location");
            f.setAccessible(true);
            return (MethodLocation) f.get(label);
        } catch (Exception e) {
            log.error("Couldn't get label location.", e);
            return null;
        }
    }

    /**
     * Drops the node pile of a removed location and splices its nodes out of the execution tree.
     *
     * <p>Each removed node's children are re-attached to the removed node's parent. A payload child (fill-array-data
     * or switch payload) cannot be re-attached that way. Instead, its own children move to that parent, and the
     * payload node is removed from its pile.
     */
    private void removeFromNodePile(MethodLocation location) {
        List<ExecutionNode> nodePile = locationToNodePile.remove(location);
        // A list, not a map keyed by location: several nodes may each own a payload child at the same location.
        List<ExecutionNode> orphanedPayloadNodes = new ArrayList<>();
        for (ExecutionNode executionNode : nodePile) {
            ExecutionNode parentNode = executionNode.getParent();
            if (parentNode == null) {
                continue;
            }
            parentNode.removeChild(executionNode);
            recreateLocations.add(parentNode.getOp().getLocation());

            for (ExecutionNode childNode : executionNode.getChildren()) {
                Op childOp = childNode.getOp();
                boolean pseudoChild = childOp instanceof FillArrayDataPayloadOp || childOp instanceof SwitchPayloadOp;
                if (!pseudoChild) {
                    reparentNode(childNode, parentNode);
                    continue;
                }
                for (ExecutionNode grandChildNode : childNode.getChildren()) {
                    reparentNode(grandChildNode, parentNode);
                }
                orphanedPayloadNodes.add(childNode);
            }
        }

        for (ExecutionNode payloadNode : orphanedPayloadNodes) {
            List<ExecutionNode> payloadPile = locationToNodePile.get(payloadNode.getOp().getLocation());
            // The payload's own location may already have been dropped in this same rebuild.
            if (payloadPile != null) {
                payloadPile.remove(payloadNode);
            }
        }
    }
}

