/*
 * Decompiled with CFR 0.152.
 */
package gadgetinspector;

import gadgetinspector.ImplementationFinder;
import gadgetinspector.config.GIConfig;
import gadgetinspector.config.JavaDeserializationConfig;
import gadgetinspector.data.ClassReference;
import gadgetinspector.data.DataLoader;
import gadgetinspector.data.GraphCall;
import gadgetinspector.data.InheritanceDeriver;
import gadgetinspector.data.InheritanceMap;
import gadgetinspector.data.MethodReference;
import gadgetinspector.data.Source;
import java.io.BufferedWriter;
import java.io.IOException;
import java.io.OutputStream;
import java.io.OutputStreamWriter;
import java.io.Writer;
import java.nio.charset.StandardCharsets;
import java.nio.file.Files;
import java.nio.file.OpenOption;
import java.nio.file.Paths;
import java.util.ArrayDeque;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Deque;
import java.util.HashMap;
import java.util.HashSet;
import java.util.LinkedHashSet;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Searches a previously computed call graph for chains of method calls that
 * lead from a known source method to a method recognised by
 * {@link #isSink(MethodReference.Handle, int, InheritanceMap)}.
 *
 * <p>Inputs are read from {@code callgraph.dat} and {@code sources.dat} (plus
 * whatever {@code DataLoader.loadMethods()} and {@code InheritanceMap.load()}
 * read); results are written to {@code methodimpl.dat} and
 * {@code gadget-chains.txt}. All paths are relative to the working directory.
 */
public class GadgetChainDiscovery {
    private static final Logger LOGGER = LoggerFactory.getLogger(GadgetChainDiscovery.class);

    /** Method names that count as sinks when declared on a {@code groovy/lang/MetaClass} subtype. */
    private static final List<String> GROOVY_META_CLASS_SINK_METHODS =
            Arrays.asList("invokeMethod", "invokeConstructor", "invokeStaticMethod");

    private final GIConfig config;

    /**
     * @param config configuration that supplies the {@link ImplementationFinder} used during the search
     */
    public GadgetChainDiscovery(GIConfig config) {
        this.config = config;
    }

    /**
     * Runs the full discovery: loads the inputs, dumps the method-implementation
     * map to {@code methodimpl.dat}, walks the call graph first-in-first-out
     * starting at every source, and writes every chain that ends in a sink to
     * {@code gadget-chains.txt}.
     *
     * @throws Exception if loading any input or writing any output fails
     */
    public void discover() throws Exception {
        Map<MethodReference.Handle, MethodReference> methodMap = DataLoader.loadMethods();
        InheritanceMap inheritanceMap = InheritanceMap.load();
        Map<MethodReference.Handle, Set<MethodReference.Handle>> methodImplMap =
                InheritanceDeriver.getAllMethodImplementations(inheritanceMap, methodMap);
        ImplementationFinder implementationFinder =
                this.config.getImplementationFinder(methodMap, methodImplMap, inheritanceMap);

        writeMethodImplementations(methodImplMap);

        Map<MethodReference.Handle, Set<GraphCall>> graphCallMap = loadCallGraph();

        // Seed the work queue with one single-link chain per distinct (method, tainted arg) source.
        Set<GadgetChainLink> exploredMethods = new HashSet<>();
        Deque<GadgetChain> methodsToExplore = new ArrayDeque<>();
        for (Source source : DataLoader.loadData(Paths.get("sources.dat"), new Source.Factory())) {
            GadgetChainLink srcLink = new GadgetChainLink(source.getSourceMethod(), source.getTaintedArgIndex());
            if (exploredMethods.add(srcLink)) {
                methodsToExplore.addLast(new GadgetChain(Arrays.asList(srcLink)));
            }
        }

        // LinkedHashSet + value-based GadgetChain equality: identical chains are reported once,
        // and the report is written in discovery order rather than identity-hash order.
        Set<GadgetChain> discoveredGadgets = new LinkedHashSet<>();
        long iteration = 0L;
        while (!methodsToExplore.isEmpty()) {
            if (iteration % 1000L == 0L) {
                LOGGER.info("Iteration {}, Search space: {}", iteration, methodsToExplore.size());
            }
            ++iteration;

            GadgetChain chain = methodsToExplore.removeFirst();
            GadgetChainLink lastLink = chain.links.get(chain.links.size() - 1);
            Set<GraphCall> methodCalls = graphCallMap.get(lastLink.method);
            if (methodCalls == null) {
                continue;
            }
            for (GraphCall graphCall : methodCalls) {
                // Only follow calls that pass on the argument tracked by the last link.
                if (graphCall.getCallerArgIndex() != lastLink.taintedArgIndex) {
                    continue;
                }
                Set<MethodReference.Handle> allImpls =
                        implementationFinder.getImplementations(graphCall.getTargetMethod());
                for (MethodReference.Handle methodImpl : allImpls) {
                    GadgetChainLink newLink = new GadgetChainLink(methodImpl, graphCall.getTargetArgIndex());
                    if (exploredMethods.contains(newLink)) {
                        continue;
                    }
                    GadgetChain newChain = new GadgetChain(chain, newLink);
                    if (this.isSink(methodImpl, graphCall.getTargetArgIndex(), inheritanceMap)) {
                        // Sinks are deliberately not marked as explored, so other chains
                        // reaching the same sink are still reported.
                        discoveredGadgets.add(newChain);
                    } else {
                        methodsToExplore.addLast(newChain);
                        exploredMethods.add(newLink);
                    }
                }
            }
        }

        writeGadgetChains(discoveredGadgets);
        LOGGER.info("Found {} gadget chains.", discoveredGadgets.size());
    }

    /**
     * Dumps the method-implementation map to {@code methodimpl.dat}: one
     * tab-separated line per key, followed by one tab-indented line per implementation.
     */
    private static void writeMethodImplementations(
            Map<MethodReference.Handle, Set<MethodReference.Handle>> methodImplMap) throws IOException {
        try (BufferedWriter writer = Files.newBufferedWriter(Paths.get("methodimpl.dat"), StandardCharsets.UTF_8)) {
            for (Map.Entry<MethodReference.Handle, Set<MethodReference.Handle>> entry : methodImplMap.entrySet()) {
                writeMethodLine(writer, "", entry.getKey());
                for (MethodReference.Handle impl : entry.getValue()) {
                    writeMethodLine(writer, "\t", impl);
                }
            }
        }
    }

    /** Writes {@code prefix + class \t name \t desc \n} for one method handle. */
    private static void writeMethodLine(Writer writer, String prefix, MethodReference.Handle method)
            throws IOException {
        writer.write(prefix);
        writer.write(method.getClassReference().getName());
        writer.write("\t");
        writer.write(method.getName());
        writer.write("\t");
        writer.write(method.getDesc());
        writer.write("\n");
    }

    /**
     * Loads {@code callgraph.dat} and groups the calls by their caller method.
     *
     * @throws Exception mirrors {@link #discover()}; the exceptions declared by
     *         {@code DataLoader.loadData} are not visible from this file
     */
    private static Map<MethodReference.Handle, Set<GraphCall>> loadCallGraph() throws Exception {
        Map<MethodReference.Handle, Set<GraphCall>> graphCallMap = new HashMap<>();
        for (GraphCall graphCall : DataLoader.loadData(Paths.get("callgraph.dat"), new GraphCall.Factory())) {
            graphCallMap.computeIfAbsent(graphCall.getCallerMethod(), caller -> new HashSet<>()).add(graphCall);
        }
        return graphCallMap;
    }

    /** Writes every discovered chain to {@code gadget-chains.txt} as UTF-8. */
    private static void writeGadgetChains(Set<GadgetChain> discoveredGadgets) throws IOException {
        try (BufferedWriter writer = Files.newBufferedWriter(Paths.get("gadget-chains.txt"), StandardCharsets.UTF_8)) {
            for (GadgetChain chain : discoveredGadgets) {
                printGadgetChain(writer, chain);
            }
        }
    }

    /**
     * Writes one chain: the first link unindented, every following link indented
     * by two spaces, then a blank line. Each link is rendered as
     * {@code class.nameDesc (taintedArgIndex)}.
     */
    private static void printGadgetChain(Writer writer, GadgetChain chain) throws IOException {
        String indent = "";
        for (GadgetChainLink link : chain.links) {
            writer.write(String.format("%s%s.%s%s (%d)%n",
                    indent,
                    link.method.getClassReference().getName(),
                    link.method.getName(),
                    link.method.getDesc(),
                    link.taintedArgIndex));
            indent = "  ";
        }
        writer.write("\n");
    }

    /**
     * Decides whether reaching {@code method} with argument {@code argIndex}
     * terminates a chain. The checks are a fixed list of class/method names
     * (some additionally constrained by {@code argIndex}) plus two
     * inheritance-based rules evaluated through {@code inheritanceMap}.
     *
     * @param method         the candidate method
     * @param argIndex       index of the argument being tracked into {@code method}
     * @param inheritanceMap used for the {@code ClassLoader} and {@code MetaClass} subtype checks
     * @return {@code true} if the method is considered a sink
     */
    private boolean isSink(MethodReference.Handle method, int argIndex, InheritanceMap inheritanceMap) {
        String className = method.getClassReference().getName();

        if (className.equals("java/io/FileInputStream") && method.getName().equals("<init>")) {
            return true;
        }
        if (className.equals("java/io/FileOutputStream") && method.getName().equals("<init>")) {
            return true;
        }
        if (className.equals("java/nio/file/Files")
                && (method.getName().equals("newInputStream")
                    || method.getName().equals("newOutputStream")
                    || method.getName().equals("newBufferedReader")
                    || method.getName().equals("newBufferedWriter"))) {
            return true;
        }
        if (className.equals("java/lang/Runtime") && method.getName().equals("exec")) {
            return true;
        }
        if (className.equals("java/lang/reflect/Method") && method.getName().equals("invoke") && argIndex == 0) {
            return true;
        }
        if (className.equals("java/net/URLClassLoader") && method.getName().equals("newInstance")) {
            return true;
        }
        if (className.equals("java/lang/System") && method.getName().equals("exit")) {
            return true;
        }
        if (className.equals("java/lang/Shutdown") && method.getName().equals("exit")) {
            return true;
        }
        if (className.equals("java/lang/Runtime") && method.getName().equals("exit")) {
            return true;
        }
        if (className.equals("java/lang/ProcessBuilder") && method.getName().equals("<init>") && argIndex > 0) {
            return true;
        }
        if (inheritanceMap.isSubclassOf(method.getClassReference(), new ClassReference.Handle("java/lang/ClassLoader"))
                && method.getName().equals("<init>")) {
            return true;
        }
        if (className.equals("java/net/URL") && method.getName().equals("openStream")) {
            return true;
        }
        if (className.equals("org/codehaus/groovy/runtime/InvokerHelper")
                && method.getName().equals("invokeMethod")
                && argIndex == 1) {
            return true;
        }
        if (inheritanceMap.isSubclassOf(method.getClassReference(), new ClassReference.Handle("groovy/lang/MetaClass"))
                && GROOVY_META_CLASS_SINK_METHODS.contains(method.getName())) {
            return true;
        }
        return className.equals("org/python/core/PyCode") && method.getName().equals("call");
    }

    /** Runs discovery with a {@link JavaDeserializationConfig}; {@code args} is ignored. */
    public static void main(String[] args) throws Exception {
        GadgetChainDiscovery gadgetChainDiscovery = new GadgetChainDiscovery(new JavaDeserializationConfig());
        gadgetChainDiscovery.discover();
    }

    /** One step of a chain: a method together with the index of the argument tracked into it. */
    private static class GadgetChainLink {
        private final MethodReference.Handle method;
        private final int taintedArgIndex;

        private GadgetChainLink(MethodReference.Handle method, int taintedArgIndex) {
            this.method = method;
            this.taintedArgIndex = taintedArgIndex;
        }

        @Override
        public boolean equals(Object o) {
            if (this == o) {
                return true;
            }
            if (o == null || this.getClass() != o.getClass()) {
                return false;
            }
            GadgetChainLink that = (GadgetChainLink) o;
            return this.taintedArgIndex == that.taintedArgIndex && Objects.equals(this.method, that.method);
        }

        @Override
        public int hashCode() {
            return 31 * Objects.hashCode(this.method) + this.taintedArgIndex;
        }
    }

    /** An ordered list of links, from a source method to the most recently reached method. */
    private static class GadgetChain {
        private final List<GadgetChainLink> links;

        private GadgetChain(List<GadgetChainLink> links) {
            this.links = links;
        }

        /** Creates a new chain consisting of {@code gadgetChain}'s links followed by {@code link}. */
        private GadgetChain(GadgetChain gadgetChain, GadgetChainLink link) {
            List<GadgetChainLink> extended = new ArrayList<>(gadgetChain.links);
            extended.add(link);
            this.links = extended;
        }

        // Value equality so that a Set<GadgetChain> can recognise identical chains.
        @Override
        public boolean equals(Object o) {
            if (this == o) {
                return true;
            }
            if (o == null || this.getClass() != o.getClass()) {
                return false;
            }
            return this.links.equals(((GadgetChain) o).links);
        }

        @Override
        public int hashCode() {
            return this.links.hashCode();
        }
    }
}

