Commit f846df53 authored by Skylot's avatar Skylot

fix: rename field if collide with any root package (#647)

parent 4a39af7c
......@@ -5,6 +5,8 @@ import java.util.HashSet;
import java.util.List;
import java.util.Set;
import org.jetbrains.annotations.Nullable;
import jadx.api.JadxArgs;
import jadx.core.Consts;
import jadx.core.deobf.Deobfuscator;
......@@ -65,6 +67,9 @@ public class RenameVisitor extends AbstractVisitor {
}
}
}
if (args.isRenameValid()) {
checkFieldsCollisionWithRootPackage(classes);
}
}
private void checkClassName(ClassNode cls, JadxArgs args) {
......@@ -134,4 +139,42 @@ public class RenameVisitor extends AbstractVisitor {
}
}
}
private void checkFieldsCollisionWithRootPackage(List<ClassNode> classes) {
Set<String> rootPkgs = collectRootPkgs(classes);
for (ClassNode cls : classes) {
for (FieldNode field : cls.getFields()) {
if (rootPkgs.contains(field.getAlias())) {
deobfuscator.forceRenameField(field);
}
}
}
}
private static Set<String> collectRootPkgs(List<ClassNode> classes) {
Set<String> fullPkgs = new HashSet<>();
for (ClassNode cls : classes) {
fullPkgs.add(cls.getAlias().getPackage());
}
Set<String> rootPkgs = new HashSet<>();
for (String pkg : fullPkgs) {
String rootPkg = getRootPkg(pkg);
if (rootPkg != null) {
rootPkgs.add(rootPkg);
}
}
return rootPkgs;
}
@Nullable
private static String getRootPkg(String pkg) {
if (pkg.isEmpty()) {
return null;
}
int dotPos = pkg.indexOf('.');
if (dotPos < 0) {
return pkg;
}
return pkg.substring(0, dotPos);
}
}
......@@ -124,10 +124,20 @@ public abstract class IntegrationTest extends TestUtils {
assertThat("Class not found: " + clsName, cls, notNullValue());
assertThat(clsName, is(cls.getClassInfo().getFullName()));
decompileAndCheckCls(d, cls);
decompileAndCheck(d, Collections.singletonList(cls));
return cls;
}
public ClassNode searchCls(List<ClassNode> list, String fullClsName) {
for (ClassNode cls : list) {
if (cls.getClassInfo().getFullName().equals(fullClsName)) {
return cls;
}
}
fail("Class not found by name " + fullClsName + " in list: " + list);
return null;
}
protected JadxDecompiler loadFiles(List<File> inputFiles) {
JadxDecompiler d = null;
try {
......@@ -137,26 +147,29 @@ public abstract class IntegrationTest extends TestUtils {
} catch (Exception e) {
e.printStackTrace();
fail(e.getMessage());
return null;
}
RootNode root = JadxInternalAccess.getRoot(d);
insertResources(root);
return d;
}
protected void decompileAndCheckCls(JadxDecompiler d, ClassNode cls) {
protected void decompileAndCheck(JadxDecompiler d, List<ClassNode> clsList) {
if (unloadCls) {
decompile(d, cls);
clsList.forEach(cls -> decompile(d, cls));
} else {
decompileWithoutUnload(d, cls);
clsList.forEach(cls -> decompileWithoutUnload(d, cls));
}
System.out.println("-----------------------------------------------------------");
System.out.println(cls.getCode());
for (ClassNode cls : clsList) {
System.out.println("-----------------------------------------------------------");
System.out.println(cls.getCode());
}
System.out.println("-----------------------------------------------------------");
checkCode(cls);
compile(cls);
runAutoCheck(cls.getClassInfo().getFullName());
clsList.forEach(IntegrationTest::checkCode);
compile(clsList);
clsList.forEach(this::runAutoCheck);
}
private void insertResources(RootNode root) {
......@@ -221,7 +234,8 @@ public abstract class IntegrationTest extends TestUtils {
return false;
}
private void runAutoCheck(String clsName) {
private void runAutoCheck(ClassNode cls) {
String clsName = cls.getClassInfo().getFullName();
try {
// run 'check' method from original class
Class<?> origCls;
......@@ -252,7 +266,7 @@ public abstract class IntegrationTest extends TestUtils {
// run 'check' method from decompiled class
if (compile) {
try {
limitExecTime(() -> invoke("check"));
limitExecTime(() -> invoke(cls, "check"));
} catch (Exception e) {
rethrow("Decompiled check failed", e);
}
......@@ -306,11 +320,15 @@ public abstract class IntegrationTest extends TestUtils {
}
void compile(ClassNode cls) {
compile(Collections.singletonList(cls));
}
void compile(List<ClassNode> clsList) {
if (!compile) {
return;
}
try {
dynamicCompiler = new DynamicCompiler(cls);
dynamicCompiler = new DynamicCompiler(clsList);
boolean result = dynamicCompiler.compile();
assertTrue(result, "Compilation failed");
System.out.println("Compilation: PASSED");
......@@ -319,30 +337,13 @@ public abstract class IntegrationTest extends TestUtils {
}
}
public Object invoke(String method) throws Exception {
return invoke(method, new Class<?>[0]);
}
public Object invoke(String method, Class<?>[] types, Object... args) throws Exception {
Method mth = getReflectMethod(method, types);
return invoke(mth, args);
}
public Method getReflectMethod(String method, Class<?>... types) {
assertNotNull(dynamicCompiler, "dynamicCompiler not ready");
try {
return dynamicCompiler.getMethod(method, types);
} catch (Exception e) {
e.printStackTrace();
fail(e.getMessage());
}
return null;
public Object invoke(ClassNode cls, String method) throws Exception {
return invoke(cls, method, new Class<?>[0]);
}
public Object invoke(Method mth, Object... args) throws Exception {
public Object invoke(ClassNode cls, String methodName, Class<?>[] types, Object... args) throws Exception {
assertNotNull(dynamicCompiler, "dynamicCompiler not ready");
assertNotNull(mth, "unknown method");
return dynamicCompiler.invoke(mth, args);
return dynamicCompiler.invoke(cls, methodName, types, args);
}
public File getJarForClass(Class<?> cls) throws IOException {
......
......@@ -56,9 +56,7 @@ public abstract class SmaliTest extends IntegrationTest {
JadxDecompiler d = loadFiles(Collections.singletonList(outDex));
RootNode root = JadxInternalAccess.getRoot(d);
List<ClassNode> classes = root.getClasses(false);
for (ClassNode cls : classes) {
decompileAndCheckCls(d, cls);
}
decompileAndCheck(d, classes);
return classes;
}
......
......@@ -2,6 +2,7 @@ package jadx.tests.api.compiler;
import java.lang.reflect.Method;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import javax.tools.JavaCompiler;
......@@ -9,31 +10,28 @@ import javax.tools.JavaFileManager;
import javax.tools.JavaFileObject;
import javax.tools.ToolProvider;
import org.jetbrains.annotations.NotNull;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import jadx.core.dex.nodes.ClassNode;
import static javax.tools.JavaCompiler.CompilationTask;
import static org.junit.jupiter.api.Assertions.assertNotNull;
import static org.junit.jupiter.api.Assertions.fail;
public class DynamicCompiler {
private static final Logger LOG = LoggerFactory.getLogger(DynamicCompiler.class);
private final ClassNode clsNode;
private final List<ClassNode> clsNodeList;
private JavaFileManager fileManager;
private Object instance;
public DynamicCompiler(ClassNode clsNode) {
this.clsNode = clsNode;
public DynamicCompiler(List<ClassNode> clsNodeList) {
this.clsNodeList = clsNodeList;
}
public boolean compile() throws Exception {
String fullName = clsNode.getFullName();
String code = clsNode.getCode().toString();
public boolean compile() {
JavaCompiler compiler = ToolProvider.getSystemJavaCompiler();
if (compiler == null) {
LOG.error("Can not find compiler, please use JDK instead");
......@@ -41,8 +39,10 @@ public class DynamicCompiler {
}
fileManager = new ClassFileManager(compiler.getStandardFileManager(null, null, null));
List<JavaFileObject> jFiles = new ArrayList<>(1);
jFiles.add(new CharSequenceJavaFileObject(fullName, code));
List<JavaFileObject> jFiles = new ArrayList<>(clsNodeList.size());
for (ClassNode clsNode : clsNodeList) {
jFiles.add(new CharSequenceJavaFileObject(clsNode.getFullName(), clsNode.getCode().toString()));
}
CompilationTask compilerTask = compiler.getTask(null, fileManager, null, null, null, jFiles);
return Boolean.TRUE.equals(compilerTask.call());
......@@ -52,27 +52,29 @@ public class DynamicCompiler {
return fileManager.getClassLoader(null);
}
private void makeInstance() throws Exception {
String fullName = clsNode.getFullName();
instance = getClassLoader().loadClass(fullName).getConstructor().newInstance();
public Object makeInstance(ClassNode cls) throws Exception {
String fullName = cls.getFullName();
return getClassLoader().loadClass(fullName).getConstructor().newInstance();
}
private Object getInstance() throws Exception {
if (instance == null) {
makeInstance();
}
return instance;
}
public Method getMethod(String method, Class<?>[] types) throws Exception {
@NotNull
public Method getMethod(Object inst, String methodName, Class<?>[] types) throws Exception {
for (Class<?> type : types) {
checkType(type);
}
return getInstance().getClass().getMethod(method, types);
return inst.getClass().getMethod(methodName, types);
}
public Object invoke(Method mth, Object... args) throws Exception {
return mth.invoke(getInstance(), args);
public Object invoke(ClassNode cls, String methodName, Class<?>[] types, Object[] args) {
try {
Object inst = makeInstance(cls);
Method reflMth = getMethod(inst, methodName, types);
assertNotNull(reflMth, "Failed to get method " + methodName + '(' + Arrays.toString(types) + ')');
return reflMth.invoke(inst, args);
} catch (Exception e) {
fail(e.getMessage(), e);
return null;
}
}
private Class<?> checkType(Class<?> type) throws ClassNotFoundException {
......
package jadx.tests.integration.loops;
import java.lang.reflect.Method;
import org.junit.jupiter.api.Test;
import jadx.core.dex.nodes.ClassNode;
......@@ -29,6 +27,12 @@ public class TestBreakWithLabel extends IntegrationTest {
System.out.println("found: " + found);
return found;
}
public void check() {
int[][] testArray = { { 1, 2 }, { 3, 4 } };
assertTrue(test(testArray, 3));
assertFalse(test(testArray, 5));
}
}
@Test
......@@ -38,10 +42,5 @@ public class TestBreakWithLabel extends IntegrationTest {
assertThat(code, containsOne("loop0:"));
assertThat(code, containsOne("break loop0;"));
Method test = getReflectMethod("test", int[][].class, int.class);
int[][] testArray = { { 1, 2 }, { 3, 4 } };
assertTrue((Boolean) invoke(test, testArray, 3));
assertFalse((Boolean) invoke(test, testArray, 5));
}
}
package jadx.tests.integration.names;
import java.util.List;
import org.junit.jupiter.api.Test;
import jadx.core.dex.nodes.ClassNode;
import jadx.tests.api.SmaliTest;
import static org.hamcrest.MatcherAssert.assertThat;
import static org.hamcrest.Matchers.containsString;
import static org.hamcrest.Matchers.not;
public class TestFieldCollideWithPackage extends SmaliTest {
//@formatter:off
/*
-----------------------------------------------------------
package first;
public class A {
public A first;
public second.A second;
public String test() {
return second.A.call(); // compiler treat 'second' as field name
}
}
-----------------------------------------------------------
package second;
public class A {
public static String call() {
return null;
}
}
-----------------------------------------------------------
*/
//@formatter:on
@Test
public void test() {
List<ClassNode> clsList = loadFromSmaliFiles();
ClassNode firstA = searchCls(clsList, "first.A");
String code = firstA.getCode().toString();
assertThat(code, containsString("second.A"));
// expect field to be renamed
assertThat(code, not(containsString("public second.A second;")));
}
@Test
public void testWithoutImports() {
getArgs().setUseImports(false);
loadFromSmaliFiles();
}
@Test
public void testWithDeobfuscation() {
enableDeobfuscation();
loadFromSmaliFiles();
}
}
.class public Lfirst/A;
.super Ljava/lang/Object;
.field public first:Lfirst/A;
.field public second:Lsecond/A;
.method public test()Ljava/lang/String;
.registers 2
invoke-static {}, Lsecond/A;->call()Ljava/lang/String;
move-result-object v0
return-object v0
.end method
.class public Lsecond/A;
.super Ljava/lang/Object;
.method static public call()Ljava/lang/String;
.registers 1
const v0, 0
return-object v0
.end method
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment