Commit 9d2c0e4a authored by Skylot's avatar Skylot

core: fix type inference and const inline for arrays

parent 7277ebb9
......@@ -15,6 +15,7 @@ public enum AFlag {
DONT_WRAP,
DONT_SHRINK,
DONT_INLINE,
DONT_GENERATE,
SKIP,
......
......@@ -9,6 +9,7 @@ import jadx.core.dex.instructions.args.RegisterArg;
import jadx.core.dex.nodes.DexNode;
import jadx.core.dex.nodes.InsnNode;
import jadx.core.dex.nodes.MethodNode;
import jadx.core.utils.InsnUtils;
import jadx.core.utils.exceptions.DecodeException;
import com.android.dex.Code;
......@@ -620,16 +621,19 @@ public class InsnDecoder {
int resReg = getMoveResultRegister(insnArr, offset);
ArgType arrType = dex.getType(insn.getIndex());
ArgType elType = arrType.getArrayElement();
InsnArg[] regs = new InsnArg[insn.getRegisterCount()];
boolean typeImmutable = elType.isPrimitive();
int regsCount = insn.getRegisterCount();
InsnArg[] regs = new InsnArg[regsCount];
if (isRange) {
int r = insn.getA();
for (int i = 0; i < insn.getRegisterCount(); i++) {
regs[i] = InsnArg.reg(r, elType);
for (int i = 0; i < regsCount; i++) {
regs[i] = InsnArg.reg(r, elType, typeImmutable);
r++;
}
} else {
for (int i = 0; i < insn.getRegisterCount(); i++) {
regs[i] = InsnArg.reg(insn, i, elType);
for (int i = 0; i < regsCount; i++) {
int regNum = InsnUtils.getArg(insn, i);
regs[i] = InsnArg.reg(regNum, elType, typeImmutable);
}
}
return insn(InsnType.FILLED_NEW_ARRAY,
......
package jadx.core.dex.instructions;
import jadx.core.dex.attributes.AFlag;
import jadx.core.dex.instructions.args.ArgType;
import jadx.core.dex.instructions.args.InsnArg;
import jadx.core.dex.instructions.args.RegisterArg;
......@@ -14,6 +15,7 @@ public class PhiInsn extends InsnNode {
for (int i = 0; i < predecessors; i++) {
addReg(regNum, ArgType.UNKNOWN);
}
add(AFlag.DONT_INLINE);
}
@Override
......
......@@ -9,6 +9,8 @@ import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import org.jetbrains.annotations.Nullable;
public abstract class ArgType {
public static final ArgType INT = primitive(PrimitiveType.INT);
......@@ -93,10 +95,28 @@ public abstract class ArgType {
}
private abstract static class KnownType extends ArgType {
private static final PrimitiveType[] EMPTY_POSSIBLES = new PrimitiveType[0];
@Override
public boolean isTypeKnown() {
return true;
}
@Override
public boolean contains(PrimitiveType type) {
return getPrimitiveType() == type;
}
@Override
public ArgType selectFirst() {
return null;
}
@Override
public PrimitiveType[] getPossibleTypes() {
return EMPTY_POSSIBLES;
}
}
private static final class PrimitiveArg extends KnownType {
......@@ -269,6 +289,7 @@ public abstract class ArgType {
}
private static final class ArrayArg extends KnownType {
public static final PrimitiveType[] ARRAY_POSSIBLES = new PrimitiveType[]{PrimitiveType.ARRAY};
private final ArgType arrayElement;
public ArrayArg(ArgType arrayElement) {
......@@ -292,6 +313,21 @@ public abstract class ArgType {
}
@Override
public boolean isTypeKnown() {
return arrayElement.isTypeKnown();
}
@Override
public ArgType selectFirst() {
return array(arrayElement.selectFirst());
}
@Override
public PrimitiveType[] getPossibleTypes() {
return ARRAY_POSSIBLES;
}
@Override
public int getArrayDimension() {
return 1 + arrayElement.getArrayDimension();
}
......@@ -343,8 +379,10 @@ public abstract class ArgType {
@Override
public ArgType selectFirst() {
PrimitiveType f = possibleTypes[0];
if (f == PrimitiveType.OBJECT || f == PrimitiveType.ARRAY) {
return object(Consts.CLASS_OBJECT);
if (contains(PrimitiveType.OBJECT)) {
return OBJECT;
} else if (contains(PrimitiveType.ARRAY)) {
return array(OBJECT);
} else {
return primitive(f);
}
......@@ -428,18 +466,13 @@ public abstract class ArgType {
return this;
}
public boolean contains(PrimitiveType type) {
throw new UnsupportedOperationException();
}
public abstract boolean contains(PrimitiveType type);
public ArgType selectFirst() {
throw new UnsupportedOperationException();
}
public abstract ArgType selectFirst();
public PrimitiveType[] getPossibleTypes() {
return null;
}
public abstract PrimitiveType[] getPossibleTypes();
@Nullable
public static ArgType merge(ArgType a, ArgType b) {
if (a == null || b == null) {
return null;
......@@ -458,13 +491,18 @@ public abstract class ArgType {
if (a == UNKNOWN) {
return b;
}
if (a.isArray()) {
return mergeArrays((ArrayArg) a, b);
} else if (b.isArray()) {
return mergeArrays((ArrayArg) b, a);
}
if (!a.isTypeKnown()) {
if (b.isTypeKnown()) {
if (a.contains(b.getPrimitiveType())) {
return b;
} else {
return null;
}
return null;
} else {
// both types unknown
List<PrimitiveType> types = new ArrayList<PrimitiveType>();
......@@ -475,7 +513,8 @@ public abstract class ArgType {
}
if (types.isEmpty()) {
return null;
} else if (types.size() == 1) {
}
if (types.size() == 1) {
PrimitiveType nt = types.get(0);
if (nt == PrimitiveType.OBJECT || nt == PrimitiveType.ARRAY) {
return unknown(nt);
......@@ -499,31 +538,15 @@ public abstract class ArgType {
String bObj = b.getObject();
if (aObj.equals(bObj)) {
return a.getGenericTypes() != null ? a : b;
} else if (aObj.equals(Consts.CLASS_OBJECT)) {
}
if (aObj.equals(Consts.CLASS_OBJECT)) {
return b;
} else if (bObj.equals(Consts.CLASS_OBJECT)) {
return a;
} else {
// different objects
String obj = clsp.getCommonAncestor(aObj, bObj);
return obj == null ? null : object(obj);
}
}
if (a.isArray()) {
if (b.isArray()) {
ArgType ea = a.getArrayElement();
ArgType eb = b.getArrayElement();
if (ea.isPrimitive() && eb.isPrimitive()) {
return OBJECT;
} else {
ArgType res = merge(ea, eb);
return res == null ? null : array(res);
}
} else if (b.equals(OBJECT)) {
return OBJECT;
} else {
return null;
if (bObj.equals(Consts.CLASS_OBJECT)) {
return a;
}
String obj = clsp.getCommonAncestor(aObj, bObj);
return obj == null ? null : object(obj);
}
if (a.isPrimitive() && b.isPrimitive() && a.getRegCount() == b.getRegCount()) {
return primitive(PrimitiveType.getSmaller(a.getPrimitiveType(), b.getPrimitiveType()));
......@@ -532,6 +555,25 @@ public abstract class ArgType {
return null;
}
private static ArgType mergeArrays(ArrayArg array, ArgType b) {
if (b.isArray()) {
ArgType ea = array.getArrayElement();
ArgType eb = b.getArrayElement();
if (ea.isPrimitive() && eb.isPrimitive()) {
return OBJECT;
}
ArgType res = merge(ea, eb);
return res == null ? null : array(res);
}
if (b.contains(PrimitiveType.ARRAY)) {
return array;
}
if (b.equals(OBJECT)) {
return OBJECT;
}
return null;
}
public static boolean isCastNeeded(ArgType from, ArgType to) {
if (from.equals(to)) {
return false;
......
......@@ -34,6 +34,10 @@ public abstract class InsnArg extends Typed {
return new TypeImmutableArg(regNum, type);
}
public static RegisterArg reg(int regNum, ArgType type, boolean typeImmutable) {
return typeImmutable ? new TypeImmutableArg(regNum, type) : new RegisterArg(regNum, type);
}
public static LiteralArg lit(long literal, ArgType type) {
return new LiteralArg(literal, type);
}
......
package jadx.core.dex.visitors;
import jadx.core.dex.attributes.AFlag;
import jadx.core.dex.attributes.AType;
import jadx.core.dex.instructions.IfNode;
import jadx.core.dex.instructions.InsnType;
......@@ -59,6 +60,7 @@ public class BlockProcessingHelper {
RegisterArg resArg = me.getResult();
resArg = InsnArg.reg(resArg.getRegNum(), type);
me.setResult(resArg);
me.add(AFlag.DONT_INLINE);
excHandler.setArg(resArg);
}
......
......@@ -2,7 +2,6 @@ package jadx.core.dex.visitors;
import jadx.core.dex.attributes.AFlag;
import jadx.core.dex.instructions.InsnType;
import jadx.core.dex.instructions.PhiInsn;
import jadx.core.dex.instructions.args.InsnArg;
import jadx.core.dex.instructions.args.InsnWrapArg;
import jadx.core.dex.instructions.args.RegisterArg;
......@@ -213,9 +212,7 @@ public class CodeShrinker extends AbstractVisitor {
continue;
}
InsnNode assignInsn = sVar.getAssign().getParentInsn();
if (assignInsn == null
|| assignInsn instanceof PhiInsn
|| assignInsn.getType() == InsnType.MOVE_EXCEPTION) {
if (assignInsn == null || assignInsn.contains(AFlag.DONT_INLINE)) {
continue;
}
int assignPos = insnList.getIndex(assignInsn);
......
package jadx.core.dex.visitors;
import jadx.core.dex.attributes.AFlag;
import jadx.core.dex.info.FieldInfo;
import jadx.core.dex.instructions.IndexInsnNode;
import jadx.core.dex.instructions.InsnType;
......@@ -52,21 +53,13 @@ public class ConstInlinerVisitor extends AbstractVisitor {
SSAVar sVar = insn.getResult().getSVar();
if (lit == 0) {
// don't inline null object if:
// - used as instance arg in invoke instruction
for (RegisterArg useArg : sVar.getUseList()) {
InsnNode parentInsn = useArg.getParentInsn();
if (parentInsn != null) {
InsnType insnType = parentInsn.getType();
if (insnType == InsnType.INVOKE) {
InvokeNode inv = (InvokeNode) parentInsn;
if (inv.getInvokeType() != InvokeType.STATIC
&& inv.getArg(0) == useArg) {
return false;
}
}
if (checkObjectInline(sVar)) {
if (sVar.getUseCount() == 1) {
insn.getResult().getAssignInsn().add(AFlag.DONT_INLINE);
}
return false;
}
}
ArgType resType = insn.getResult().getType();
// make sure arg has correct type
......@@ -76,6 +69,32 @@ public class ConstInlinerVisitor extends AbstractVisitor {
return replaceConst(mth, sVar, lit);
}
/**
* Don't inline null object if:
* - used as instance arg in invoke instruction
* - used in 'array.length'
*/
private static boolean checkObjectInline(SSAVar sVar) {
for (RegisterArg useArg : sVar.getUseList()) {
InsnNode insn = useArg.getParentInsn();
if (insn != null) {
InsnType insnType = insn.getType();
if (insnType == InsnType.INVOKE) {
InvokeNode inv = (InvokeNode) insn;
if (inv.getInvokeType() != InvokeType.STATIC
&& inv.getArg(0) == useArg) {
return true;
}
} else if (insnType == InsnType.ARRAY_LENGTH) {
if (insn.getArg(0) == useArg) {
return true;
}
}
}
}
return false;
}
private static boolean replaceConst(MethodNode mth, SSAVar sVar, long literal) {
List<RegisterArg> use = new ArrayList<RegisterArg>(sVar.getUseList());
int replaceCount = 0;
......
......@@ -107,7 +107,7 @@ public class LoopRegionVisitor extends AbstractVisitor implements IRegionVisitor
List<RegisterArg> args = new LinkedList<RegisterArg>();
incrInsn.getRegisterArgs(args);
for (RegisterArg iArg : args) {
if (assignOnlyInLoop(mth, loopRegion, (RegisterArg) iArg)) {
if (assignOnlyInLoop(mth, loopRegion, iArg)) {
return false;
}
}
......
......@@ -40,7 +40,7 @@ public class ErrorsCounter {
if (e.getClass() == JadxOverflowException.class) {
// don't print full stack trace
e = new JadxOverflowException(e.getMessage());
LOG.error(msg);
LOG.error(msg + ", message: " + e.getMessage());
} else {
LOG.error(msg, e);
}
......
......@@ -35,11 +35,7 @@ public class InsnUtils {
}
public static String insnTypeToString(InsnType type) {
return insnTypeToString(type.toString());
}
public static String insnTypeToString(String str) {
return String.format("%s ", str);
return type.toString() + " ";
}
public static String indexToString(Object index) {
......@@ -49,7 +45,7 @@ public class InsnUtils {
if (index instanceof String) {
return "\"" + index + "\"";
} else {
return " " + index;
return index.toString();
}
}
}
......@@ -201,6 +201,7 @@ public abstract class IntegrationTest extends TestUtils {
dynamicCompiler = new DynamicCompiler(cls);
boolean result = dynamicCompiler.compile();
assertTrue("Compilation failed", result);
System.out.println("Compilation: PASSED");
} catch (Exception e) {
e.printStackTrace();
fail(e.getMessage());
......
......@@ -20,6 +20,7 @@ import static jadx.core.dex.instructions.args.ArgType.OBJECT;
import static jadx.core.dex.instructions.args.ArgType.STRING;
import static jadx.core.dex.instructions.args.ArgType.UNKNOWN;
import static jadx.core.dex.instructions.args.ArgType.UNKNOWN_OBJECT;
import static jadx.core.dex.instructions.args.ArgType.array;
import static jadx.core.dex.instructions.args.ArgType.genericType;
import static jadx.core.dex.instructions.args.ArgType.object;
import static jadx.core.dex.instructions.args.ArgType.unknown;
......@@ -59,13 +60,6 @@ public class TypeMergeTest {
unknown(PrimitiveType.OBJECT, PrimitiveType.ARRAY),
unknown(PrimitiveType.OBJECT));
check(ArgType.array(INT), ArgType.array(BYTE), ArgType.OBJECT);
first(ArgType.array(INT), ArgType.array(INT));
first(ArgType.array(STRING), ArgType.array(STRING));
first(OBJECT, ArgType.array(INT));
first(OBJECT, ArgType.array(STRING));
ArgType objExc = object("java.lang.Exception");
ArgType objThr = object("java.lang.Throwable");
ArgType objIO = object("java.io.IOException");
......@@ -83,6 +77,18 @@ public class TypeMergeTest {
first(generic, objExc);
}
@Test
public void testArrayMerge() {
check(array(INT), array(BYTE), ArgType.OBJECT);
first(array(INT), array(INT));
first(array(STRING), array(STRING));
first(OBJECT, array(INT));
first(OBJECT, array(STRING));
first(array(unknown(PrimitiveType.INT, PrimitiveType.FLOAT)), unknown(PrimitiveType.ARRAY));
}
private void first(ArgType t1, ArgType t2) {
check(t1, t2, t1);
}
......
......@@ -5,6 +5,7 @@ import jadx.tests.api.IntegrationTest;
import org.junit.Test;
import static jadx.tests.api.utils.JadxMatchers.containsOne;
import static org.hamcrest.CoreMatchers.containsString;
import static org.hamcrest.CoreMatchers.not;
import static org.junit.Assert.assertThat;
......@@ -28,13 +29,19 @@ public class TestWrongCode extends IntegrationTest {
@Test
public void test() {
disableCompilation();
ClassNode cls = getClassNode(TestCls.class);
String code = cls.getCode().toString();
assertThat(code, not(containsString("return false.length;")));
assertThat(code, containsString("return null.length;"));
assertThat(code, containsOne("int[] a = null;"));
assertThat(code, containsOne("return a.length;"));
assertThat(code, containsString("return a == 0 ? a : a;"));
}
@Test
public void testNoDebug() {
noDebugInfo();
getClassNode(TestCls.class);
}
}
package jadx.tests.integration.arrays;
import jadx.core.dex.nodes.ClassNode;
import jadx.tests.api.IntegrationTest;
import org.junit.Test;
import static jadx.tests.api.utils.JadxMatchers.containsOne;
import static org.junit.Assert.assertThat;
public class TestArrays extends IntegrationTest {
public static class TestCls {
public int test1(int i) {
int[] a = new int[]{1, 2, 3, 5};
return a[i];
}
public int test2(int i) {
int[][] a = new int[i][i + 1];
return a.length;
}
}
@Test
public void test() {
noDebugInfo();
ClassNode cls = getClassNode(TestCls.class);
String code = cls.getCode().toString();
assertThat(code, containsOne("return new int[]{1, 2, 3, 5}[i];"));
}
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment