Commit 10de4ff4 authored by Skylot's avatar Skylot

core: process dependant classes before code generation

parent eed65421
......@@ -2,6 +2,7 @@ package jadx.api;
import jadx.core.Jadx;
import jadx.core.ProcessClass;
import jadx.core.codegen.CodeGen;
import jadx.core.codegen.CodeWriter;
import jadx.core.deobf.DefaultDeobfuscator;
import jadx.core.deobf.Deobfuscator;
......@@ -57,6 +58,8 @@ public final class JadxDecompiler {
private RootNode root;
private List<IDexTreeVisitor> passes;
private CodeGen codeGen;
private List<JavaClass> classes;
private List<ResourceFile> resources;
......@@ -83,6 +86,7 @@ public final class JadxDecompiler {
outDir = new DefaultJadxArgs().getOutDir();
}
this.passes = Jadx.getPassesList(args, outDir);
this.codeGen = new CodeGen(args);
}
void reset() {
......@@ -305,7 +309,7 @@ public final class JadxDecompiler {
}
void processClass(ClassNode cls) {
ProcessClass.process(cls, passes);
ProcessClass.process(cls, passes, codeGen);
}
RootNode getRoot() {
......@@ -331,6 +335,10 @@ public final class JadxDecompiler {
return null;
}
public IJadxArgs getArgs() {
return args;
}
@Override
public String toString() {
return "jadx decompiler " + getVersion();
......
package jadx.core;
import jadx.api.IJadxArgs;
import jadx.core.codegen.CodeGen;
import jadx.core.dex.visitors.ClassModifier;
import jadx.core.dex.visitors.CodeShrinker;
import jadx.core.dex.visitors.ConstInlineVisitor;
import jadx.core.dex.visitors.DebugInfoVisitor;
import jadx.core.dex.visitors.DependencyCollector;
import jadx.core.dex.visitors.DotGraphVisitor;
import jadx.core.dex.visitors.EnumVisitor;
import jadx.core.dex.visitors.FallbackModeVisitor;
......@@ -104,8 +104,9 @@ public class Jadx {
passes.add(new PrepareForCodeGen());
passes.add(new LoopRegionVisitor());
passes.add(new ProcessVariables());
passes.add(new DependencyCollector());
}
passes.add(new CodeGen(args));
return passes;
}
......
package jadx.core;
import jadx.core.codegen.CodeGen;
import jadx.core.dex.nodes.ClassNode;
import jadx.core.dex.visitors.DepthTraversal;
import jadx.core.dex.visitors.IDexTreeVisitor;
import jadx.core.utils.ErrorsCounter;
import java.util.List;
import org.jetbrains.annotations.Nullable;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import static jadx.core.dex.nodes.ProcessState.GENERATED;
import static jadx.core.dex.nodes.ProcessState.NOT_LOADED;
import static jadx.core.dex.nodes.ProcessState.PROCESSED;
import static jadx.core.dex.nodes.ProcessState.STARTED;
import static jadx.core.dex.nodes.ProcessState.UNLOADED;
public final class ProcessClass {
private static final Logger LOG = LoggerFactory.getLogger(ProcessClass.class);
private ProcessClass() {
}
public static void process(ClassNode cls, List<IDexTreeVisitor> passes) {
try {
cls.load();
for (IDexTreeVisitor visitor : passes) {
DepthTraversal.visit(visitor, cls);
public static void process(ClassNode cls, List<IDexTreeVisitor> passes, @Nullable CodeGen codeGen) {
synchronized (cls) {
try {
if (cls.getState() == NOT_LOADED) {
cls.load();
cls.setState(STARTED);
for (IDexTreeVisitor visitor : passes) {
DepthTraversal.visit(visitor, cls);
}
for (ClassNode clsNode : cls.getDependencies()) {
process(clsNode, passes, null);
}
cls.setState(PROCESSED);
}
if (cls.getState() == PROCESSED && codeGen != null) {
codeGen.visit(cls);
cls.setState(GENERATED);
}
} catch (Exception e) {
ErrorsCounter.classError(cls, e.getClass().getSimpleName(), e);
} finally {
if (cls.getState() == GENERATED) {
cls.unload();
cls.setState(UNLOADED);
}
}
} catch (Exception e) {
LOG.error("Class process exception: {}", cls, e);
} finally {
cls.unload();
}
}
}
......@@ -24,9 +24,11 @@ import jadx.core.utils.exceptions.JadxRuntimeException;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashSet;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;
import org.jetbrains.annotations.TestOnly;
import org.slf4j.Logger;
......@@ -58,6 +60,9 @@ public class ClassNode extends LineAttrNode implements ILoadable {
// store parent for inner classes or 'this' otherwise
private ClassNode parentClass;
private ProcessState state = ProcessState.NOT_LOADED;
private final Set<ClassNode> dependencies = new HashSet<ClassNode>();
public ClassNode(DexNode dex, ClassDef cls) throws DecodeException {
this.dex = dex;
this.clsInfo = ClassInfo.fromDex(dex, cls.getTypeIndex());
......@@ -452,6 +457,18 @@ public class ClassNode extends LineAttrNode implements ILoadable {
return code;
}
public ProcessState getState() {
return state;
}
public void setState(ProcessState state) {
this.state = state;
}
public Set<ClassNode> getDependencies() {
return dependencies;
}
@Override
public String toString() {
return getFullName();
......
package jadx.core.dex.nodes;
public enum ProcessState {
NOT_LOADED,
STARTED,
PROCESSED,
GENERATED,
UNLOADED
}
package jadx.core.dex.visitors;
import jadx.core.dex.info.ClassInfo;
import jadx.core.dex.instructions.args.ArgType;
import jadx.core.dex.instructions.args.InsnArg;
import jadx.core.dex.instructions.args.InsnWrapArg;
import jadx.core.dex.instructions.args.RegisterArg;
import jadx.core.dex.nodes.BlockNode;
import jadx.core.dex.nodes.ClassNode;
import jadx.core.dex.nodes.DexNode;
import jadx.core.dex.nodes.FieldNode;
import jadx.core.dex.nodes.InsnNode;
import jadx.core.dex.nodes.MethodNode;
import jadx.core.utils.exceptions.JadxException;
import java.util.Set;
public class DependencyCollector extends AbstractVisitor {
@Override
public boolean visit(ClassNode cls) throws JadxException {
DexNode dex = cls.dex();
Set<ClassNode> depList = cls.getDependencies();
processClass(cls, dex, depList);
for (ClassNode inner : cls.getInnerClasses()) {
processClass(inner, dex, depList);
}
depList.remove(cls);
return false;
}
private static void processClass(ClassNode cls, DexNode dex, Set<ClassNode> depList) {
addDep(dex, depList, cls.getSuperClass());
for (ClassInfo clsInfo : cls.getInterfaces()) {
addDep(dex, depList, clsInfo);
}
for (FieldNode fieldNode : cls.getFields()) {
addDep(dex, depList, fieldNode.getType());
}
// TODO: process annotations and generics
for (MethodNode methodNode : cls.getMethods()) {
if (methodNode.isNoCode()) {
continue;
}
processMethod(dex, depList, methodNode);
}
}
private static void processMethod(DexNode dex, Set<ClassNode> depList, MethodNode methodNode) {
addDep(dex, depList, methodNode.getParentClass());
addDep(dex, depList, methodNode.getReturnType());
for (ArgType arg : methodNode.getMethodInfo().getArgumentsTypes()) {
addDep(dex, depList, arg);
}
for (BlockNode block : methodNode.getBasicBlocks()) {
for (InsnNode insnNode : block.getInstructions()) {
processInsn(dex, depList, insnNode);
}
}
}
// TODO: add custom instructions processing
private static void processInsn(DexNode dex, Set<ClassNode> depList, InsnNode insnNode) {
RegisterArg result = insnNode.getResult();
if (result != null) {
addDep(dex, depList, result.getType());
}
for (InsnArg arg : insnNode.getArguments()) {
if (arg.isInsnWrap()) {
processInsn(dex, depList, ((InsnWrapArg) arg).getWrapInsn());
} else {
addDep(dex, depList, arg.getType());
}
}
}
private static void addDep(DexNode dex, Set<ClassNode> depList, ArgType type) {
if (type != null) {
if (type.isObject()) {
addDep(dex, depList, ClassInfo.fromName(type.getObject()));
ArgType[] genericTypes = type.getGenericTypes();
if (type.isGeneric() && genericTypes != null) {
for (ArgType argType : genericTypes) {
addDep(dex, depList, argType);
}
}
} else if (type.isArray()) {
addDep(dex, depList, type.getArrayRootElement());
}
}
}
private static void addDep(DexNode dex, Set<ClassNode> depList, ClassInfo clsInfo) {
if (clsInfo != null) {
ClassNode node = dex.resolveClass(clsInfo);
if (node != null) {
depList.add(node);
}
}
}
private static void addDep(DexNode dex, Set<ClassNode> depList, ClassNode clsNode) {
if (clsNode != null) {
depList.add(clsNode);
}
}
}
......@@ -4,6 +4,8 @@ import jadx.api.DefaultJadxArgs;
import jadx.api.JadxDecompiler;
import jadx.api.JadxInternalAccess;
import jadx.core.Jadx;
import jadx.core.ProcessClass;
import jadx.core.codegen.CodeGen;
import jadx.core.dex.attributes.AFlag;
import jadx.core.dex.attributes.AType;
import jadx.core.dex.nodes.ClassNode;
......@@ -11,6 +13,7 @@ import jadx.core.dex.nodes.MethodNode;
import jadx.core.dex.nodes.RootNode;
import jadx.core.dex.visitors.DepthTraversal;
import jadx.core.dex.visitors.IDexTreeVisitor;
import jadx.core.utils.exceptions.CodegenException;
import jadx.core.utils.exceptions.JadxException;
import jadx.core.utils.files.FileUtils;
import jadx.tests.api.compiler.DynamicCompiler;
......@@ -51,6 +54,7 @@ public abstract class IntegrationTest extends TestUtils {
protected boolean isFallback = false;
protected boolean deleteTmpFiles = true;
protected boolean withDebugInfo = true;
protected boolean unloadCls = true;
protected Map<Integer, String> resMap = Collections.emptyMap();
......@@ -64,16 +68,18 @@ public abstract class IntegrationTest extends TestUtils {
File jar = getJarForClass(clazz);
return getClassNodeFromFile(jar, clazz.getName());
} catch (Exception e) {
e.printStackTrace();
fail(e.getMessage());
}
return null;
}
public ClassNode getClassNodeFromFile(File file, String clsName) {
JadxDecompiler d = new JadxDecompiler();
JadxDecompiler d = new JadxDecompiler(getArgs());
try {
d.loadFile(file);
} catch (JadxException e) {
e.printStackTrace();
fail(e.getMessage());
}
RootNode root = JadxInternalAccess.getRoot(d);
......@@ -83,11 +89,11 @@ public abstract class IntegrationTest extends TestUtils {
assertNotNull("Class not found: " + clsName, cls);
assertEquals(cls.getFullName(), clsName);
cls.load();
for (IDexTreeVisitor visitor : getPasses()) {
DepthTraversal.visit(visitor, cls);
if (unloadCls) {
decompile(d, cls);
} else {
decompileWithoutUnload(d, cls);
}
// don't unload class
System.out.println("-----------------------------------------------------------");
System.out.println(cls.getCode());
......@@ -99,6 +105,26 @@ public abstract class IntegrationTest extends TestUtils {
return cls;
}
private void decompile(JadxDecompiler jadx, ClassNode cls) {
List<IDexTreeVisitor> passes = Jadx.getPassesList(jadx.getArgs(), new File(outDir));
ProcessClass.process(cls, passes, new CodeGen(jadx.getArgs()));
}
private void decompileWithoutUnload(JadxDecompiler d, ClassNode cls) {
cls.load();
List<IDexTreeVisitor> passes = Jadx.getPassesList(d.getArgs(), new File(outDir));
for (IDexTreeVisitor visitor : passes) {
DepthTraversal.visit(visitor, cls);
}
try {
new CodeGen(d.getArgs()).visit(cls);
} catch (CodegenException e) {
e.printStackTrace();
fail(e.getMessage());
}
// don't unload class
}
private static void checkCode(ClassNode cls) {
assertTrue("Inconsistent cls: " + cls,
!cls.contains(AFlag.INCONSISTENT_CODE) && !cls.contains(AType.JADX_ERROR));
......@@ -109,8 +135,8 @@ public abstract class IntegrationTest extends TestUtils {
assertThat(cls.getCode().toString(), not(containsString("inconsistent")));
}
protected List<IDexTreeVisitor> getPasses() {
return Jadx.getPassesList(new DefaultJadxArgs() {
private DefaultJadxArgs getArgs() {
return new DefaultJadxArgs() {
@Override
public boolean isCFGOutput() {
return outputCFG;
......@@ -140,7 +166,7 @@ public abstract class IntegrationTest extends TestUtils {
public boolean isSkipResources() {
return true;
}
}, new File(outDir));
};
}
private void runAutoCheck(String clsName) {
......@@ -363,6 +389,10 @@ public abstract class IntegrationTest extends TestUtils {
this.compile = false;
}
protected void dontUnloadClass() {
this.unloadCls = false;
}
// Use only for debug purpose
@Deprecated
protected void setOutputCFG() {
......
......@@ -31,6 +31,7 @@ public class TestDuplicateCast extends IntegrationTest {
@Test
public void test() {
dontUnloadClass();
ClassNode cls = getClassNode(TestCls.class);
MethodNode mth = getMethod(cls, "method");
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment