Commit f846df53 authored by Skylot's avatar Skylot

fix: rename field if collide with any root package (#647)

parent 4a39af7c
...@@ -5,6 +5,8 @@ import java.util.HashSet; ...@@ -5,6 +5,8 @@ import java.util.HashSet;
import java.util.List; import java.util.List;
import java.util.Set; import java.util.Set;
import org.jetbrains.annotations.Nullable;
import jadx.api.JadxArgs; import jadx.api.JadxArgs;
import jadx.core.Consts; import jadx.core.Consts;
import jadx.core.deobf.Deobfuscator; import jadx.core.deobf.Deobfuscator;
...@@ -65,6 +67,9 @@ public class RenameVisitor extends AbstractVisitor { ...@@ -65,6 +67,9 @@ public class RenameVisitor extends AbstractVisitor {
} }
} }
} }
if (args.isRenameValid()) {
checkFieldsCollisionWithRootPackage(classes);
}
} }
private void checkClassName(ClassNode cls, JadxArgs args) { private void checkClassName(ClassNode cls, JadxArgs args) {
...@@ -134,4 +139,42 @@ public class RenameVisitor extends AbstractVisitor { ...@@ -134,4 +139,42 @@ public class RenameVisitor extends AbstractVisitor {
} }
} }
} }
private void checkFieldsCollisionWithRootPackage(List<ClassNode> classes) {
Set<String> rootPkgs = collectRootPkgs(classes);
for (ClassNode cls : classes) {
for (FieldNode field : cls.getFields()) {
if (rootPkgs.contains(field.getAlias())) {
deobfuscator.forceRenameField(field);
}
}
}
}
private static Set<String> collectRootPkgs(List<ClassNode> classes) {
Set<String> fullPkgs = new HashSet<>();
for (ClassNode cls : classes) {
fullPkgs.add(cls.getAlias().getPackage());
}
Set<String> rootPkgs = new HashSet<>();
for (String pkg : fullPkgs) {
String rootPkg = getRootPkg(pkg);
if (rootPkg != null) {
rootPkgs.add(rootPkg);
}
}
return rootPkgs;
}
@Nullable
private static String getRootPkg(String pkg) {
if (pkg.isEmpty()) {
return null;
}
int dotPos = pkg.indexOf('.');
if (dotPos < 0) {
return pkg;
}
return pkg.substring(0, dotPos);
}
} }
...@@ -124,10 +124,20 @@ public abstract class IntegrationTest extends TestUtils { ...@@ -124,10 +124,20 @@ public abstract class IntegrationTest extends TestUtils {
assertThat("Class not found: " + clsName, cls, notNullValue()); assertThat("Class not found: " + clsName, cls, notNullValue());
assertThat(clsName, is(cls.getClassInfo().getFullName())); assertThat(clsName, is(cls.getClassInfo().getFullName()));
decompileAndCheckCls(d, cls); decompileAndCheck(d, Collections.singletonList(cls));
return cls; return cls;
} }
public ClassNode searchCls(List<ClassNode> list, String fullClsName) {
for (ClassNode cls : list) {
if (cls.getClassInfo().getFullName().equals(fullClsName)) {
return cls;
}
}
fail("Class not found by name " + fullClsName + " in list: " + list);
return null;
}
protected JadxDecompiler loadFiles(List<File> inputFiles) { protected JadxDecompiler loadFiles(List<File> inputFiles) {
JadxDecompiler d = null; JadxDecompiler d = null;
try { try {
...@@ -137,26 +147,29 @@ public abstract class IntegrationTest extends TestUtils { ...@@ -137,26 +147,29 @@ public abstract class IntegrationTest extends TestUtils {
} catch (Exception e) { } catch (Exception e) {
e.printStackTrace(); e.printStackTrace();
fail(e.getMessage()); fail(e.getMessage());
return null;
} }
RootNode root = JadxInternalAccess.getRoot(d); RootNode root = JadxInternalAccess.getRoot(d);
insertResources(root); insertResources(root);
return d; return d;
} }
protected void decompileAndCheckCls(JadxDecompiler d, ClassNode cls) { protected void decompileAndCheck(JadxDecompiler d, List<ClassNode> clsList) {
if (unloadCls) { if (unloadCls) {
decompile(d, cls); clsList.forEach(cls -> decompile(d, cls));
} else { } else {
decompileWithoutUnload(d, cls); clsList.forEach(cls -> decompileWithoutUnload(d, cls));
} }
System.out.println("-----------------------------------------------------------"); for (ClassNode cls : clsList) {
System.out.println(cls.getCode()); System.out.println("-----------------------------------------------------------");
System.out.println(cls.getCode());
}
System.out.println("-----------------------------------------------------------"); System.out.println("-----------------------------------------------------------");
checkCode(cls); clsList.forEach(IntegrationTest::checkCode);
compile(cls); compile(clsList);
runAutoCheck(cls.getClassInfo().getFullName()); clsList.forEach(this::runAutoCheck);
} }
private void insertResources(RootNode root) { private void insertResources(RootNode root) {
...@@ -221,7 +234,8 @@ public abstract class IntegrationTest extends TestUtils { ...@@ -221,7 +234,8 @@ public abstract class IntegrationTest extends TestUtils {
return false; return false;
} }
private void runAutoCheck(String clsName) { private void runAutoCheck(ClassNode cls) {
String clsName = cls.getClassInfo().getFullName();
try { try {
// run 'check' method from original class // run 'check' method from original class
Class<?> origCls; Class<?> origCls;
...@@ -252,7 +266,7 @@ public abstract class IntegrationTest extends TestUtils { ...@@ -252,7 +266,7 @@ public abstract class IntegrationTest extends TestUtils {
// run 'check' method from decompiled class // run 'check' method from decompiled class
if (compile) { if (compile) {
try { try {
limitExecTime(() -> invoke("check")); limitExecTime(() -> invoke(cls, "check"));
} catch (Exception e) { } catch (Exception e) {
rethrow("Decompiled check failed", e); rethrow("Decompiled check failed", e);
} }
...@@ -306,11 +320,15 @@ public abstract class IntegrationTest extends TestUtils { ...@@ -306,11 +320,15 @@ public abstract class IntegrationTest extends TestUtils {
} }
void compile(ClassNode cls) { void compile(ClassNode cls) {
compile(Collections.singletonList(cls));
}
void compile(List<ClassNode> clsList) {
if (!compile) { if (!compile) {
return; return;
} }
try { try {
dynamicCompiler = new DynamicCompiler(cls); dynamicCompiler = new DynamicCompiler(clsList);
boolean result = dynamicCompiler.compile(); boolean result = dynamicCompiler.compile();
assertTrue(result, "Compilation failed"); assertTrue(result, "Compilation failed");
System.out.println("Compilation: PASSED"); System.out.println("Compilation: PASSED");
...@@ -319,30 +337,13 @@ public abstract class IntegrationTest extends TestUtils { ...@@ -319,30 +337,13 @@ public abstract class IntegrationTest extends TestUtils {
} }
} }
public Object invoke(String method) throws Exception { public Object invoke(ClassNode cls, String method) throws Exception {
return invoke(method, new Class<?>[0]); return invoke(cls, method, new Class<?>[0]);
}
public Object invoke(String method, Class<?>[] types, Object... args) throws Exception {
Method mth = getReflectMethod(method, types);
return invoke(mth, args);
}
public Method getReflectMethod(String method, Class<?>... types) {
assertNotNull(dynamicCompiler, "dynamicCompiler not ready");
try {
return dynamicCompiler.getMethod(method, types);
} catch (Exception e) {
e.printStackTrace();
fail(e.getMessage());
}
return null;
} }
public Object invoke(Method mth, Object... args) throws Exception { public Object invoke(ClassNode cls, String methodName, Class<?>[] types, Object... args) throws Exception {
assertNotNull(dynamicCompiler, "dynamicCompiler not ready"); assertNotNull(dynamicCompiler, "dynamicCompiler not ready");
assertNotNull(mth, "unknown method"); return dynamicCompiler.invoke(cls, methodName, types, args);
return dynamicCompiler.invoke(mth, args);
} }
public File getJarForClass(Class<?> cls) throws IOException { public File getJarForClass(Class<?> cls) throws IOException {
......
...@@ -56,9 +56,7 @@ public abstract class SmaliTest extends IntegrationTest { ...@@ -56,9 +56,7 @@ public abstract class SmaliTest extends IntegrationTest {
JadxDecompiler d = loadFiles(Collections.singletonList(outDex)); JadxDecompiler d = loadFiles(Collections.singletonList(outDex));
RootNode root = JadxInternalAccess.getRoot(d); RootNode root = JadxInternalAccess.getRoot(d);
List<ClassNode> classes = root.getClasses(false); List<ClassNode> classes = root.getClasses(false);
for (ClassNode cls : classes) { decompileAndCheck(d, classes);
decompileAndCheckCls(d, cls);
}
return classes; return classes;
} }
......
...@@ -2,6 +2,7 @@ package jadx.tests.api.compiler; ...@@ -2,6 +2,7 @@ package jadx.tests.api.compiler;
import java.lang.reflect.Method; import java.lang.reflect.Method;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.Arrays;
import java.util.List; import java.util.List;
import javax.tools.JavaCompiler; import javax.tools.JavaCompiler;
...@@ -9,31 +10,28 @@ import javax.tools.JavaFileManager; ...@@ -9,31 +10,28 @@ import javax.tools.JavaFileManager;
import javax.tools.JavaFileObject; import javax.tools.JavaFileObject;
import javax.tools.ToolProvider; import javax.tools.ToolProvider;
import org.jetbrains.annotations.NotNull;
import org.slf4j.Logger; import org.slf4j.Logger;
import org.slf4j.LoggerFactory; import org.slf4j.LoggerFactory;
import jadx.core.dex.nodes.ClassNode; import jadx.core.dex.nodes.ClassNode;
import static javax.tools.JavaCompiler.CompilationTask; import static javax.tools.JavaCompiler.CompilationTask;
import static org.junit.jupiter.api.Assertions.assertNotNull;
import static org.junit.jupiter.api.Assertions.fail;
public class DynamicCompiler { public class DynamicCompiler {
private static final Logger LOG = LoggerFactory.getLogger(DynamicCompiler.class); private static final Logger LOG = LoggerFactory.getLogger(DynamicCompiler.class);
private final ClassNode clsNode; private final List<ClassNode> clsNodeList;
private JavaFileManager fileManager; private JavaFileManager fileManager;
private Object instance; public DynamicCompiler(List<ClassNode> clsNodeList) {
this.clsNodeList = clsNodeList;
public DynamicCompiler(ClassNode clsNode) {
this.clsNode = clsNode;
} }
public boolean compile() throws Exception { public boolean compile() {
String fullName = clsNode.getFullName();
String code = clsNode.getCode().toString();
JavaCompiler compiler = ToolProvider.getSystemJavaCompiler(); JavaCompiler compiler = ToolProvider.getSystemJavaCompiler();
if (compiler == null) { if (compiler == null) {
LOG.error("Can not find compiler, please use JDK instead"); LOG.error("Can not find compiler, please use JDK instead");
...@@ -41,8 +39,10 @@ public class DynamicCompiler { ...@@ -41,8 +39,10 @@ public class DynamicCompiler {
} }
fileManager = new ClassFileManager(compiler.getStandardFileManager(null, null, null)); fileManager = new ClassFileManager(compiler.getStandardFileManager(null, null, null));
List<JavaFileObject> jFiles = new ArrayList<>(1); List<JavaFileObject> jFiles = new ArrayList<>(clsNodeList.size());
jFiles.add(new CharSequenceJavaFileObject(fullName, code)); for (ClassNode clsNode : clsNodeList) {
jFiles.add(new CharSequenceJavaFileObject(clsNode.getFullName(), clsNode.getCode().toString()));
}
CompilationTask compilerTask = compiler.getTask(null, fileManager, null, null, null, jFiles); CompilationTask compilerTask = compiler.getTask(null, fileManager, null, null, null, jFiles);
return Boolean.TRUE.equals(compilerTask.call()); return Boolean.TRUE.equals(compilerTask.call());
...@@ -52,27 +52,29 @@ public class DynamicCompiler { ...@@ -52,27 +52,29 @@ public class DynamicCompiler {
return fileManager.getClassLoader(null); return fileManager.getClassLoader(null);
} }
private void makeInstance() throws Exception { public Object makeInstance(ClassNode cls) throws Exception {
String fullName = clsNode.getFullName(); String fullName = cls.getFullName();
instance = getClassLoader().loadClass(fullName).getConstructor().newInstance(); return getClassLoader().loadClass(fullName).getConstructor().newInstance();
} }
private Object getInstance() throws Exception { @NotNull
if (instance == null) { public Method getMethod(Object inst, String methodName, Class<?>[] types) throws Exception {
makeInstance();
}
return instance;
}
public Method getMethod(String method, Class<?>[] types) throws Exception {
for (Class<?> type : types) { for (Class<?> type : types) {
checkType(type); checkType(type);
} }
return getInstance().getClass().getMethod(method, types); return inst.getClass().getMethod(methodName, types);
} }
public Object invoke(Method mth, Object... args) throws Exception { public Object invoke(ClassNode cls, String methodName, Class<?>[] types, Object[] args) {
return mth.invoke(getInstance(), args); try {
Object inst = makeInstance(cls);
Method reflMth = getMethod(inst, methodName, types);
assertNotNull(reflMth, "Failed to get method " + methodName + '(' + Arrays.toString(types) + ')');
return reflMth.invoke(inst, args);
} catch (Exception e) {
fail(e.getMessage(), e);
return null;
}
} }
private Class<?> checkType(Class<?> type) throws ClassNotFoundException { private Class<?> checkType(Class<?> type) throws ClassNotFoundException {
......
package jadx.tests.integration.loops; package jadx.tests.integration.loops;
import java.lang.reflect.Method;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import jadx.core.dex.nodes.ClassNode; import jadx.core.dex.nodes.ClassNode;
...@@ -29,6 +27,12 @@ public class TestBreakWithLabel extends IntegrationTest { ...@@ -29,6 +27,12 @@ public class TestBreakWithLabel extends IntegrationTest {
System.out.println("found: " + found); System.out.println("found: " + found);
return found; return found;
} }
public void check() {
int[][] testArray = { { 1, 2 }, { 3, 4 } };
assertTrue(test(testArray, 3));
assertFalse(test(testArray, 5));
}
} }
@Test @Test
...@@ -38,10 +42,5 @@ public class TestBreakWithLabel extends IntegrationTest { ...@@ -38,10 +42,5 @@ public class TestBreakWithLabel extends IntegrationTest {
assertThat(code, containsOne("loop0:")); assertThat(code, containsOne("loop0:"));
assertThat(code, containsOne("break loop0;")); assertThat(code, containsOne("break loop0;"));
Method test = getReflectMethod("test", int[][].class, int.class);
int[][] testArray = { { 1, 2 }, { 3, 4 } };
assertTrue((Boolean) invoke(test, testArray, 3));
assertFalse((Boolean) invoke(test, testArray, 5));
} }
} }
package jadx.tests.integration.names;
import java.util.List;
import org.junit.jupiter.api.Test;
import jadx.core.dex.nodes.ClassNode;
import jadx.tests.api.SmaliTest;
import static org.hamcrest.MatcherAssert.assertThat;
import static org.hamcrest.Matchers.containsString;
import static org.hamcrest.Matchers.not;
public class TestFieldCollideWithPackage extends SmaliTest {
//@formatter:off
/*
-----------------------------------------------------------
package first;
public class A {
public A first;
public second.A second;
public String test() {
return second.A.call(); // compiler treat 'second' as field name
}
}
-----------------------------------------------------------
package second;
public class A {
public static String call() {
return null;
}
}
-----------------------------------------------------------
*/
//@formatter:on
@Test
public void test() {
List<ClassNode> clsList = loadFromSmaliFiles();
ClassNode firstA = searchCls(clsList, "first.A");
String code = firstA.getCode().toString();
assertThat(code, containsString("second.A"));
// expect field to be renamed
assertThat(code, not(containsString("public second.A second;")));
}
@Test
public void testWithoutImports() {
getArgs().setUseImports(false);
loadFromSmaliFiles();
}
@Test
public void testWithDeobfuscation() {
enableDeobfuscation();
loadFromSmaliFiles();
}
}
.class public Lfirst/A;
.super Ljava/lang/Object;
.field public first:Lfirst/A;
.field public second:Lsecond/A;
.method public test()Ljava/lang/String;
.registers 2
invoke-static {}, Lsecond/A;->call()Ljava/lang/String;
move-result-object v0
return-object v0
.end method
.class public Lsecond/A;
.super Ljava/lang/Object;
.method static public call()Ljava/lang/String;
.registers 1
const v0, 0
return-object v0
.end method
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment