Commit c552fb85 authored by Skylot's avatar Skylot

core tests: replace several classes in dynamic class loader, add additional checks

parent 8a4ec47b
...@@ -18,7 +18,9 @@ import jadx.tests.api.utils.TestUtils; ...@@ -18,7 +18,9 @@ import jadx.tests.api.utils.TestUtils;
import java.io.File; import java.io.File;
import java.io.FileOutputStream; import java.io.FileOutputStream;
import java.io.IOException; import java.io.IOException;
import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method; import java.lang.reflect.Method;
import java.lang.reflect.Modifier;
import java.net.URISyntaxException; import java.net.URISyntaxException;
import java.net.URL; import java.net.URL;
import java.util.ArrayList; import java.util.ArrayList;
...@@ -71,8 +73,11 @@ public abstract class IntegrationTest extends TestUtils { ...@@ -71,8 +73,11 @@ public abstract class IntegrationTest extends TestUtils {
} }
// don't unload class // don't unload class
System.out.println(cls.getCode());
checkCode(cls); checkCode(cls);
compile(cls); compile(cls);
runAutoCheck(clsName);
return cls; return cls;
} }
...@@ -102,9 +107,65 @@ public abstract class IntegrationTest extends TestUtils { ...@@ -102,9 +107,65 @@ public abstract class IntegrationTest extends TestUtils {
public boolean isFallbackMode() { public boolean isFallbackMode() {
return isFallback; return isFallback;
} }
@Override
public boolean isShowInconsistentCode() {
return true;
}
}, new File(outDir)); }, new File(outDir));
} }
private void runAutoCheck(String clsName) {
try {
// run 'check' method from original class
Class<?> origCls;
try {
origCls = Class.forName(clsName);
} catch (ClassNotFoundException e) {
// ignore
return;
}
Method checkMth;
try {
checkMth = origCls.getMethod("check");
} catch (NoSuchMethodException e) {
// ignore
return;
}
if (!checkMth.getReturnType().equals(void.class)
|| !Modifier.isPublic(checkMth.getModifiers())
|| Modifier.isStatic(checkMth.getModifiers())) {
fail("Wrong 'check' method");
return;
}
try {
checkMth.invoke(origCls.newInstance());
} catch (InvocationTargetException ie) {
rethrow("Java check failed", ie);
}
// run 'check' method from decompiled class
try {
invoke("check");
} catch (InvocationTargetException ie) {
rethrow("Decompiled check failed", ie);
}
} catch (Exception e) {
e.printStackTrace();
fail("Auto check exception: " + e.getMessage());
}
}
private void rethrow(String msg, InvocationTargetException ie) {
Throwable cause = ie.getCause();
if (cause instanceof AssertionError) {
System.err.println(msg);
throw ((AssertionError) cause);
} else {
cause.printStackTrace();
fail(msg + cause.getMessage());
}
}
protected MethodNode getMethod(ClassNode cls, String method) { protected MethodNode getMethod(ClassNode cls, String method) {
for (MethodNode mth : cls.getMethods()) { for (MethodNode mth : cls.getMethods()) {
if (mth.getName().equals(method)) { if (mth.getName().equals(method)) {
...@@ -133,7 +194,7 @@ public abstract class IntegrationTest extends TestUtils { ...@@ -133,7 +194,7 @@ public abstract class IntegrationTest extends TestUtils {
return invoke(method, new Class[0]); return invoke(method, new Class[0]);
} }
public Object invoke(String method, Class[] types, Object... args) { public Object invoke(String method, Class[] types, Object... args) throws Exception {
Method mth = getReflectMethod(method, types); Method mth = getReflectMethod(method, types);
return invoke(mth, args); return invoke(mth, args);
} }
...@@ -149,16 +210,10 @@ public abstract class IntegrationTest extends TestUtils { ...@@ -149,16 +210,10 @@ public abstract class IntegrationTest extends TestUtils {
return null; return null;
} }
public Object invoke(Method mth, Object... args) { public Object invoke(Method mth, Object... args) throws Exception {
assertNotNull("dynamicCompiler not ready", dynamicCompiler); assertNotNull("dynamicCompiler not ready", dynamicCompiler);
assertNotNull("unknown method", mth); assertNotNull("unknown method", mth);
try { return dynamicCompiler.invoke(mth, args);
return dynamicCompiler.invoke(mth, args);
} catch (Exception e) {
e.printStackTrace();
fail(e.getMessage());
}
return null;
} }
public File getJarForClass(Class<?> cls) throws IOException { public File getJarForClass(Class<?> cls) throws IOException {
......
...@@ -6,32 +6,71 @@ import javax.tools.JavaFileObject; ...@@ -6,32 +6,71 @@ import javax.tools.JavaFileObject;
import javax.tools.StandardJavaFileManager; import javax.tools.StandardJavaFileManager;
import java.io.IOException; import java.io.IOException;
import java.security.SecureClassLoader; import java.security.SecureClassLoader;
import java.util.HashMap;
import java.util.Map;
import static javax.tools.JavaFileObject.Kind; import static javax.tools.JavaFileObject.Kind;
public class ClassFileManager extends ForwardingJavaFileManager<StandardJavaFileManager> { public class ClassFileManager extends ForwardingJavaFileManager<StandardJavaFileManager> {
private JavaClassObject jClsObject; private DynamicClassLoader classLoader;
public ClassFileManager(StandardJavaFileManager standardManager) { public ClassFileManager(StandardJavaFileManager standardManager) {
super(standardManager); super(standardManager);
classLoader = new DynamicClassLoader();
} }
@Override @Override
public JavaFileObject getJavaFileForOutput(Location location, String className, public JavaFileObject getJavaFileForOutput(Location location, String className,
Kind kind, FileObject sibling) throws IOException { Kind kind, FileObject sibling) throws IOException {
jClsObject = new JavaClassObject(className, kind); JavaClassObject clsObject = new JavaClassObject(className, kind);
return jClsObject; classLoader.getClsMap().put(className, clsObject);
return clsObject;
} }
@Override @Override
public ClassLoader getClassLoader(Location location) { public ClassLoader getClassLoader(Location location) {
return new SecureClassLoader() { return classLoader;
@Override }
protected Class<?> findClass(String name) throws ClassNotFoundException {
byte[] clsBytes = jClsObject.getBytes(); private class DynamicClassLoader extends SecureClassLoader {
return super.defineClass(name, clsBytes, 0, clsBytes.length); private final Map<String, JavaClassObject> clsMap = new HashMap<String, JavaClassObject>();
private final Map<String, Class> clsCache = new HashMap<String, Class>();
@Override
protected Class<?> findClass(String name) throws ClassNotFoundException {
Class<?> cls = replaceClass(name);
if (cls != null) {
return cls;
}
return super.findClass(name);
}
public Class<?> loadClass(String name) throws ClassNotFoundException {
Class<?> cls = replaceClass(name);
if (cls != null) {
return cls;
} }
}; return super.loadClass(name);
}
public Class<?> replaceClass(String name) throws ClassNotFoundException {
Class cacheCls = clsCache.get(name);
if (cacheCls != null) {
return cacheCls;
}
JavaClassObject clsObject = clsMap.get(name);
if (clsObject == null) {
return null;
}
byte[] clsBytes = clsObject.getBytes();
Class<?> cls = super.defineClass(name, clsBytes, 0, clsBytes.length);
clsCache.put(name, cls);
return cls;
}
public Map<String, JavaClassObject> getClsMap() {
return clsMap;
}
} }
} }
...@@ -38,9 +38,13 @@ public class DynamicCompiler { ...@@ -38,9 +38,13 @@ public class DynamicCompiler {
return Boolean.TRUE.equals(compilerTask.call()); return Boolean.TRUE.equals(compilerTask.call());
} }
private ClassLoader getClassLoader() {
return fileManager.getClassLoader(null);
}
private void makeInstance() throws Exception { private void makeInstance() throws Exception {
String fullName = clsNode.getFullName(); String fullName = clsNode.getFullName();
instance = fileManager.getClassLoader(null).loadClass(fullName).newInstance(); instance = getClassLoader().loadClass(fullName).newInstance();
if (instance == null) { if (instance == null) {
throw new NullPointerException("Instantiation failed"); throw new NullPointerException("Instantiation failed");
} }
...@@ -54,6 +58,9 @@ public class DynamicCompiler { ...@@ -54,6 +58,9 @@ public class DynamicCompiler {
} }
public Method getMethod(String method, Class[] types) throws Exception { public Method getMethod(String method, Class[] types) throws Exception {
for (Class type : types) {
checkType(type);
}
return getInstance().getClass().getMethod(method, types); return getInstance().getClass().getMethod(method, types);
} }
...@@ -61,12 +68,17 @@ public class DynamicCompiler { ...@@ -61,12 +68,17 @@ public class DynamicCompiler {
return mth.invoke(getInstance(), args); return mth.invoke(getInstance(), args);
} }
public Object invoke(String method) throws Exception { private Class<?> checkType(Class type) throws ClassNotFoundException {
return invoke(method, new Class[0]); if (type.isPrimitive()) {
} return type;
}
public Object invoke(String method, Class[] types, Object... args) throws Exception { if (type.isArray()) {
Method mth = getMethod(method, types); return checkType(type.getComponentType());
return invoke(mth, args); }
Class<?> decompiledCls = getClassLoader().loadClass(type.getName());
if (type != decompiledCls) {
throw new IllegalArgumentException("Internal test class cannot be used in method invoke");
}
return decompiledCls;
} }
} }
...@@ -3,8 +3,6 @@ package jadx.tests.integration.enums; ...@@ -3,8 +3,6 @@ package jadx.tests.integration.enums;
import jadx.core.dex.nodes.ClassNode; import jadx.core.dex.nodes.ClassNode;
import jadx.tests.api.IntegrationTest; import jadx.tests.api.IntegrationTest;
import java.lang.reflect.Method;
import org.junit.Test; import org.junit.Test;
import static jadx.tests.api.utils.JadxMatchers.countString; import static jadx.tests.api.utils.JadxMatchers.countString;
...@@ -27,19 +25,19 @@ public class TestSwitchOverEnum extends IntegrationTest { ...@@ -27,19 +25,19 @@ public class TestSwitchOverEnum extends IntegrationTest {
return 0; return 0;
} }
public void check() {
assertEquals(1, testEnum(Count.ONE));
assertEquals(2, testEnum(Count.TWO));
assertEquals(0, testEnum(Count.THREE));
}
@Test @Test
public void test() { public void test() {
ClassNode cls = getClassNode(TestSwitchOverEnum.class); ClassNode cls = getClassNode(TestSwitchOverEnum.class);
String code = cls.getCode().toString(); String code = cls.getCode().toString();
System.out.println(code);
assertThat(code, countString(1, "synthetic")); assertThat(code, countString(1, "synthetic"));
assertThat(code, countString(2, "switch (c) {")); assertThat(code, countString(2, "switch (c) {"));
assertThat(code, countString(2, "case ONE:")); assertThat(code, countString(2, "case ONE:"));
Method mth = getReflectMethod("testEnum", Count.class);
assertEquals(1, invoke(mth, Count.ONE));
assertEquals(2, invoke(mth, Count.TWO));
assertEquals(0, invoke(mth, Count.THREE));
} }
} }
...@@ -33,10 +33,9 @@ public class TestBreakWithLabel extends IntegrationTest { ...@@ -33,10 +33,9 @@ public class TestBreakWithLabel extends IntegrationTest {
} }
@Test @Test
public void test() { public void test() throws Exception {
ClassNode cls = getClassNode(TestCls.class); ClassNode cls = getClassNode(TestCls.class);
String code = cls.getCode().toString(); String code = cls.getCode().toString();
System.out.println(code);
assertThat(code, containsOne("loop0:")); assertThat(code, containsOne("loop0:"));
assertThat(code, containsOne("break loop0;")); assertThat(code, containsOne("break loop0;"));
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment