Commit 10de4ff4 authored by Skylot's avatar Skylot

core: process dependant classes before code generation

parent eed65421
...@@ -2,6 +2,7 @@ package jadx.api; ...@@ -2,6 +2,7 @@ package jadx.api;
import jadx.core.Jadx; import jadx.core.Jadx;
import jadx.core.ProcessClass; import jadx.core.ProcessClass;
import jadx.core.codegen.CodeGen;
import jadx.core.codegen.CodeWriter; import jadx.core.codegen.CodeWriter;
import jadx.core.deobf.DefaultDeobfuscator; import jadx.core.deobf.DefaultDeobfuscator;
import jadx.core.deobf.Deobfuscator; import jadx.core.deobf.Deobfuscator;
...@@ -57,6 +58,8 @@ public final class JadxDecompiler { ...@@ -57,6 +58,8 @@ public final class JadxDecompiler {
private RootNode root; private RootNode root;
private List<IDexTreeVisitor> passes; private List<IDexTreeVisitor> passes;
private CodeGen codeGen;
private List<JavaClass> classes; private List<JavaClass> classes;
private List<ResourceFile> resources; private List<ResourceFile> resources;
...@@ -83,6 +86,7 @@ public final class JadxDecompiler { ...@@ -83,6 +86,7 @@ public final class JadxDecompiler {
outDir = new DefaultJadxArgs().getOutDir(); outDir = new DefaultJadxArgs().getOutDir();
} }
this.passes = Jadx.getPassesList(args, outDir); this.passes = Jadx.getPassesList(args, outDir);
this.codeGen = new CodeGen(args);
} }
void reset() { void reset() {
...@@ -305,7 +309,7 @@ public final class JadxDecompiler { ...@@ -305,7 +309,7 @@ public final class JadxDecompiler {
} }
void processClass(ClassNode cls) { void processClass(ClassNode cls) {
ProcessClass.process(cls, passes); ProcessClass.process(cls, passes, codeGen);
} }
RootNode getRoot() { RootNode getRoot() {
...@@ -331,6 +335,10 @@ public final class JadxDecompiler { ...@@ -331,6 +335,10 @@ public final class JadxDecompiler {
return null; return null;
} }
public IJadxArgs getArgs() {
return args;
}
@Override @Override
public String toString() { public String toString() {
return "jadx decompiler " + getVersion(); return "jadx decompiler " + getVersion();
......
package jadx.core; package jadx.core;
import jadx.api.IJadxArgs; import jadx.api.IJadxArgs;
import jadx.core.codegen.CodeGen;
import jadx.core.dex.visitors.ClassModifier; import jadx.core.dex.visitors.ClassModifier;
import jadx.core.dex.visitors.CodeShrinker; import jadx.core.dex.visitors.CodeShrinker;
import jadx.core.dex.visitors.ConstInlineVisitor; import jadx.core.dex.visitors.ConstInlineVisitor;
import jadx.core.dex.visitors.DebugInfoVisitor; import jadx.core.dex.visitors.DebugInfoVisitor;
import jadx.core.dex.visitors.DependencyCollector;
import jadx.core.dex.visitors.DotGraphVisitor; import jadx.core.dex.visitors.DotGraphVisitor;
import jadx.core.dex.visitors.EnumVisitor; import jadx.core.dex.visitors.EnumVisitor;
import jadx.core.dex.visitors.FallbackModeVisitor; import jadx.core.dex.visitors.FallbackModeVisitor;
...@@ -104,8 +104,9 @@ public class Jadx { ...@@ -104,8 +104,9 @@ public class Jadx {
passes.add(new PrepareForCodeGen()); passes.add(new PrepareForCodeGen());
passes.add(new LoopRegionVisitor()); passes.add(new LoopRegionVisitor());
passes.add(new ProcessVariables()); passes.add(new ProcessVariables());
passes.add(new DependencyCollector());
} }
passes.add(new CodeGen(args));
return passes; return passes;
} }
......
package jadx.core; package jadx.core;
import jadx.core.codegen.CodeGen;
import jadx.core.dex.nodes.ClassNode; import jadx.core.dex.nodes.ClassNode;
import jadx.core.dex.visitors.DepthTraversal; import jadx.core.dex.visitors.DepthTraversal;
import jadx.core.dex.visitors.IDexTreeVisitor; import jadx.core.dex.visitors.IDexTreeVisitor;
import jadx.core.utils.ErrorsCounter;
import java.util.List; import java.util.List;
import org.jetbrains.annotations.Nullable;
import org.slf4j.Logger; import org.slf4j.Logger;
import org.slf4j.LoggerFactory; import org.slf4j.LoggerFactory;
import static jadx.core.dex.nodes.ProcessState.GENERATED;
import static jadx.core.dex.nodes.ProcessState.NOT_LOADED;
import static jadx.core.dex.nodes.ProcessState.PROCESSED;
import static jadx.core.dex.nodes.ProcessState.STARTED;
import static jadx.core.dex.nodes.ProcessState.UNLOADED;
public final class ProcessClass { public final class ProcessClass {
private static final Logger LOG = LoggerFactory.getLogger(ProcessClass.class); private static final Logger LOG = LoggerFactory.getLogger(ProcessClass.class);
private ProcessClass() { private ProcessClass() {
} }
public static void process(ClassNode cls, List<IDexTreeVisitor> passes) { public static void process(ClassNode cls, List<IDexTreeVisitor> passes, @Nullable CodeGen codeGen) {
try { synchronized (cls) {
cls.load(); try {
for (IDexTreeVisitor visitor : passes) { if (cls.getState() == NOT_LOADED) {
DepthTraversal.visit(visitor, cls); cls.load();
cls.setState(STARTED);
for (IDexTreeVisitor visitor : passes) {
DepthTraversal.visit(visitor, cls);
}
for (ClassNode clsNode : cls.getDependencies()) {
process(clsNode, passes, null);
}
cls.setState(PROCESSED);
}
if (cls.getState() == PROCESSED && codeGen != null) {
codeGen.visit(cls);
cls.setState(GENERATED);
}
} catch (Exception e) {
ErrorsCounter.classError(cls, e.getClass().getSimpleName(), e);
} finally {
if (cls.getState() == GENERATED) {
cls.unload();
cls.setState(UNLOADED);
}
} }
} catch (Exception e) {
LOG.error("Class process exception: {}", cls, e);
} finally {
cls.unload();
} }
} }
} }
...@@ -24,9 +24,11 @@ import jadx.core.utils.exceptions.JadxRuntimeException; ...@@ -24,9 +24,11 @@ import jadx.core.utils.exceptions.JadxRuntimeException;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.Collections; import java.util.Collections;
import java.util.HashSet;
import java.util.LinkedHashMap; import java.util.LinkedHashMap;
import java.util.List; import java.util.List;
import java.util.Map; import java.util.Map;
import java.util.Set;
import org.jetbrains.annotations.TestOnly; import org.jetbrains.annotations.TestOnly;
import org.slf4j.Logger; import org.slf4j.Logger;
...@@ -58,6 +60,9 @@ public class ClassNode extends LineAttrNode implements ILoadable { ...@@ -58,6 +60,9 @@ public class ClassNode extends LineAttrNode implements ILoadable {
// store parent for inner classes or 'this' otherwise // store parent for inner classes or 'this' otherwise
private ClassNode parentClass; private ClassNode parentClass;
private ProcessState state = ProcessState.NOT_LOADED;
private final Set<ClassNode> dependencies = new HashSet<ClassNode>();
public ClassNode(DexNode dex, ClassDef cls) throws DecodeException { public ClassNode(DexNode dex, ClassDef cls) throws DecodeException {
this.dex = dex; this.dex = dex;
this.clsInfo = ClassInfo.fromDex(dex, cls.getTypeIndex()); this.clsInfo = ClassInfo.fromDex(dex, cls.getTypeIndex());
...@@ -452,6 +457,18 @@ public class ClassNode extends LineAttrNode implements ILoadable { ...@@ -452,6 +457,18 @@ public class ClassNode extends LineAttrNode implements ILoadable {
return code; return code;
} }
public ProcessState getState() {
return state;
}
public void setState(ProcessState state) {
this.state = state;
}
public Set<ClassNode> getDependencies() {
return dependencies;
}
@Override @Override
public String toString() { public String toString() {
return getFullName(); return getFullName();
......
package jadx.core.dex.nodes;
public enum ProcessState {
NOT_LOADED,
STARTED,
PROCESSED,
GENERATED,
UNLOADED
}
package jadx.core.dex.visitors;
import jadx.core.dex.info.ClassInfo;
import jadx.core.dex.instructions.args.ArgType;
import jadx.core.dex.instructions.args.InsnArg;
import jadx.core.dex.instructions.args.InsnWrapArg;
import jadx.core.dex.instructions.args.RegisterArg;
import jadx.core.dex.nodes.BlockNode;
import jadx.core.dex.nodes.ClassNode;
import jadx.core.dex.nodes.DexNode;
import jadx.core.dex.nodes.FieldNode;
import jadx.core.dex.nodes.InsnNode;
import jadx.core.dex.nodes.MethodNode;
import jadx.core.utils.exceptions.JadxException;
import java.util.Set;
public class DependencyCollector extends AbstractVisitor {
@Override
public boolean visit(ClassNode cls) throws JadxException {
DexNode dex = cls.dex();
Set<ClassNode> depList = cls.getDependencies();
processClass(cls, dex, depList);
for (ClassNode inner : cls.getInnerClasses()) {
processClass(inner, dex, depList);
}
depList.remove(cls);
return false;
}
private static void processClass(ClassNode cls, DexNode dex, Set<ClassNode> depList) {
addDep(dex, depList, cls.getSuperClass());
for (ClassInfo clsInfo : cls.getInterfaces()) {
addDep(dex, depList, clsInfo);
}
for (FieldNode fieldNode : cls.getFields()) {
addDep(dex, depList, fieldNode.getType());
}
// TODO: process annotations and generics
for (MethodNode methodNode : cls.getMethods()) {
if (methodNode.isNoCode()) {
continue;
}
processMethod(dex, depList, methodNode);
}
}
private static void processMethod(DexNode dex, Set<ClassNode> depList, MethodNode methodNode) {
addDep(dex, depList, methodNode.getParentClass());
addDep(dex, depList, methodNode.getReturnType());
for (ArgType arg : methodNode.getMethodInfo().getArgumentsTypes()) {
addDep(dex, depList, arg);
}
for (BlockNode block : methodNode.getBasicBlocks()) {
for (InsnNode insnNode : block.getInstructions()) {
processInsn(dex, depList, insnNode);
}
}
}
// TODO: add custom instructions processing
private static void processInsn(DexNode dex, Set<ClassNode> depList, InsnNode insnNode) {
RegisterArg result = insnNode.getResult();
if (result != null) {
addDep(dex, depList, result.getType());
}
for (InsnArg arg : insnNode.getArguments()) {
if (arg.isInsnWrap()) {
processInsn(dex, depList, ((InsnWrapArg) arg).getWrapInsn());
} else {
addDep(dex, depList, arg.getType());
}
}
}
private static void addDep(DexNode dex, Set<ClassNode> depList, ArgType type) {
if (type != null) {
if (type.isObject()) {
addDep(dex, depList, ClassInfo.fromName(type.getObject()));
ArgType[] genericTypes = type.getGenericTypes();
if (type.isGeneric() && genericTypes != null) {
for (ArgType argType : genericTypes) {
addDep(dex, depList, argType);
}
}
} else if (type.isArray()) {
addDep(dex, depList, type.getArrayRootElement());
}
}
}
private static void addDep(DexNode dex, Set<ClassNode> depList, ClassInfo clsInfo) {
if (clsInfo != null) {
ClassNode node = dex.resolveClass(clsInfo);
if (node != null) {
depList.add(node);
}
}
}
private static void addDep(DexNode dex, Set<ClassNode> depList, ClassNode clsNode) {
if (clsNode != null) {
depList.add(clsNode);
}
}
}
...@@ -4,6 +4,8 @@ import jadx.api.DefaultJadxArgs; ...@@ -4,6 +4,8 @@ import jadx.api.DefaultJadxArgs;
import jadx.api.JadxDecompiler; import jadx.api.JadxDecompiler;
import jadx.api.JadxInternalAccess; import jadx.api.JadxInternalAccess;
import jadx.core.Jadx; import jadx.core.Jadx;
import jadx.core.ProcessClass;
import jadx.core.codegen.CodeGen;
import jadx.core.dex.attributes.AFlag; import jadx.core.dex.attributes.AFlag;
import jadx.core.dex.attributes.AType; import jadx.core.dex.attributes.AType;
import jadx.core.dex.nodes.ClassNode; import jadx.core.dex.nodes.ClassNode;
...@@ -11,6 +13,7 @@ import jadx.core.dex.nodes.MethodNode; ...@@ -11,6 +13,7 @@ import jadx.core.dex.nodes.MethodNode;
import jadx.core.dex.nodes.RootNode; import jadx.core.dex.nodes.RootNode;
import jadx.core.dex.visitors.DepthTraversal; import jadx.core.dex.visitors.DepthTraversal;
import jadx.core.dex.visitors.IDexTreeVisitor; import jadx.core.dex.visitors.IDexTreeVisitor;
import jadx.core.utils.exceptions.CodegenException;
import jadx.core.utils.exceptions.JadxException; import jadx.core.utils.exceptions.JadxException;
import jadx.core.utils.files.FileUtils; import jadx.core.utils.files.FileUtils;
import jadx.tests.api.compiler.DynamicCompiler; import jadx.tests.api.compiler.DynamicCompiler;
...@@ -51,6 +54,7 @@ public abstract class IntegrationTest extends TestUtils { ...@@ -51,6 +54,7 @@ public abstract class IntegrationTest extends TestUtils {
protected boolean isFallback = false; protected boolean isFallback = false;
protected boolean deleteTmpFiles = true; protected boolean deleteTmpFiles = true;
protected boolean withDebugInfo = true; protected boolean withDebugInfo = true;
protected boolean unloadCls = true;
protected Map<Integer, String> resMap = Collections.emptyMap(); protected Map<Integer, String> resMap = Collections.emptyMap();
...@@ -64,16 +68,18 @@ public abstract class IntegrationTest extends TestUtils { ...@@ -64,16 +68,18 @@ public abstract class IntegrationTest extends TestUtils {
File jar = getJarForClass(clazz); File jar = getJarForClass(clazz);
return getClassNodeFromFile(jar, clazz.getName()); return getClassNodeFromFile(jar, clazz.getName());
} catch (Exception e) { } catch (Exception e) {
e.printStackTrace();
fail(e.getMessage()); fail(e.getMessage());
} }
return null; return null;
} }
public ClassNode getClassNodeFromFile(File file, String clsName) { public ClassNode getClassNodeFromFile(File file, String clsName) {
JadxDecompiler d = new JadxDecompiler(); JadxDecompiler d = new JadxDecompiler(getArgs());
try { try {
d.loadFile(file); d.loadFile(file);
} catch (JadxException e) { } catch (JadxException e) {
e.printStackTrace();
fail(e.getMessage()); fail(e.getMessage());
} }
RootNode root = JadxInternalAccess.getRoot(d); RootNode root = JadxInternalAccess.getRoot(d);
...@@ -83,11 +89,11 @@ public abstract class IntegrationTest extends TestUtils { ...@@ -83,11 +89,11 @@ public abstract class IntegrationTest extends TestUtils {
assertNotNull("Class not found: " + clsName, cls); assertNotNull("Class not found: " + clsName, cls);
assertEquals(cls.getFullName(), clsName); assertEquals(cls.getFullName(), clsName);
cls.load(); if (unloadCls) {
for (IDexTreeVisitor visitor : getPasses()) { decompile(d, cls);
DepthTraversal.visit(visitor, cls); } else {
decompileWithoutUnload(d, cls);
} }
// don't unload class
System.out.println("-----------------------------------------------------------"); System.out.println("-----------------------------------------------------------");
System.out.println(cls.getCode()); System.out.println(cls.getCode());
...@@ -99,6 +105,26 @@ public abstract class IntegrationTest extends TestUtils { ...@@ -99,6 +105,26 @@ public abstract class IntegrationTest extends TestUtils {
return cls; return cls;
} }
private void decompile(JadxDecompiler jadx, ClassNode cls) {
List<IDexTreeVisitor> passes = Jadx.getPassesList(jadx.getArgs(), new File(outDir));
ProcessClass.process(cls, passes, new CodeGen(jadx.getArgs()));
}
private void decompileWithoutUnload(JadxDecompiler d, ClassNode cls) {
cls.load();
List<IDexTreeVisitor> passes = Jadx.getPassesList(d.getArgs(), new File(outDir));
for (IDexTreeVisitor visitor : passes) {
DepthTraversal.visit(visitor, cls);
}
try {
new CodeGen(d.getArgs()).visit(cls);
} catch (CodegenException e) {
e.printStackTrace();
fail(e.getMessage());
}
// don't unload class
}
private static void checkCode(ClassNode cls) { private static void checkCode(ClassNode cls) {
assertTrue("Inconsistent cls: " + cls, assertTrue("Inconsistent cls: " + cls,
!cls.contains(AFlag.INCONSISTENT_CODE) && !cls.contains(AType.JADX_ERROR)); !cls.contains(AFlag.INCONSISTENT_CODE) && !cls.contains(AType.JADX_ERROR));
...@@ -109,8 +135,8 @@ public abstract class IntegrationTest extends TestUtils { ...@@ -109,8 +135,8 @@ public abstract class IntegrationTest extends TestUtils {
assertThat(cls.getCode().toString(), not(containsString("inconsistent"))); assertThat(cls.getCode().toString(), not(containsString("inconsistent")));
} }
protected List<IDexTreeVisitor> getPasses() { private DefaultJadxArgs getArgs() {
return Jadx.getPassesList(new DefaultJadxArgs() { return new DefaultJadxArgs() {
@Override @Override
public boolean isCFGOutput() { public boolean isCFGOutput() {
return outputCFG; return outputCFG;
...@@ -140,7 +166,7 @@ public abstract class IntegrationTest extends TestUtils { ...@@ -140,7 +166,7 @@ public abstract class IntegrationTest extends TestUtils {
public boolean isSkipResources() { public boolean isSkipResources() {
return true; return true;
} }
}, new File(outDir)); };
} }
private void runAutoCheck(String clsName) { private void runAutoCheck(String clsName) {
...@@ -363,6 +389,10 @@ public abstract class IntegrationTest extends TestUtils { ...@@ -363,6 +389,10 @@ public abstract class IntegrationTest extends TestUtils {
this.compile = false; this.compile = false;
} }
protected void dontUnloadClass() {
this.unloadCls = false;
}
// Use only for debug purpose // Use only for debug purpose
@Deprecated @Deprecated
protected void setOutputCFG() { protected void setOutputCFG() {
......
...@@ -31,6 +31,7 @@ public class TestDuplicateCast extends IntegrationTest { ...@@ -31,6 +31,7 @@ public class TestDuplicateCast extends IntegrationTest {
@Test @Test
public void test() { public void test() {
dontUnloadClass();
ClassNode cls = getClassNode(TestCls.class); ClassNode cls = getClassNode(TestCls.class);
MethodNode mth = getMethod(cls, "method"); MethodNode mth = getMethod(cls, "method");
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment