Commit 9d2c0e4a authored by Skylot's avatar Skylot

core: fix type inference and const inline for arrays

parent 7277ebb9
...@@ -15,6 +15,7 @@ public enum AFlag { ...@@ -15,6 +15,7 @@ public enum AFlag {
DONT_WRAP, DONT_WRAP,
DONT_SHRINK, DONT_SHRINK,
DONT_INLINE,
DONT_GENERATE, DONT_GENERATE,
SKIP, SKIP,
......
...@@ -9,6 +9,7 @@ import jadx.core.dex.instructions.args.RegisterArg; ...@@ -9,6 +9,7 @@ import jadx.core.dex.instructions.args.RegisterArg;
import jadx.core.dex.nodes.DexNode; import jadx.core.dex.nodes.DexNode;
import jadx.core.dex.nodes.InsnNode; import jadx.core.dex.nodes.InsnNode;
import jadx.core.dex.nodes.MethodNode; import jadx.core.dex.nodes.MethodNode;
import jadx.core.utils.InsnUtils;
import jadx.core.utils.exceptions.DecodeException; import jadx.core.utils.exceptions.DecodeException;
import com.android.dex.Code; import com.android.dex.Code;
...@@ -620,16 +621,19 @@ public class InsnDecoder { ...@@ -620,16 +621,19 @@ public class InsnDecoder {
int resReg = getMoveResultRegister(insnArr, offset); int resReg = getMoveResultRegister(insnArr, offset);
ArgType arrType = dex.getType(insn.getIndex()); ArgType arrType = dex.getType(insn.getIndex());
ArgType elType = arrType.getArrayElement(); ArgType elType = arrType.getArrayElement();
InsnArg[] regs = new InsnArg[insn.getRegisterCount()]; boolean typeImmutable = elType.isPrimitive();
int regsCount = insn.getRegisterCount();
InsnArg[] regs = new InsnArg[regsCount];
if (isRange) { if (isRange) {
int r = insn.getA(); int r = insn.getA();
for (int i = 0; i < insn.getRegisterCount(); i++) { for (int i = 0; i < regsCount; i++) {
regs[i] = InsnArg.reg(r, elType); regs[i] = InsnArg.reg(r, elType, typeImmutable);
r++; r++;
} }
} else { } else {
for (int i = 0; i < insn.getRegisterCount(); i++) { for (int i = 0; i < regsCount; i++) {
regs[i] = InsnArg.reg(insn, i, elType); int regNum = InsnUtils.getArg(insn, i);
regs[i] = InsnArg.reg(regNum, elType, typeImmutable);
} }
} }
return insn(InsnType.FILLED_NEW_ARRAY, return insn(InsnType.FILLED_NEW_ARRAY,
......
package jadx.core.dex.instructions; package jadx.core.dex.instructions;
import jadx.core.dex.attributes.AFlag;
import jadx.core.dex.instructions.args.ArgType; import jadx.core.dex.instructions.args.ArgType;
import jadx.core.dex.instructions.args.InsnArg; import jadx.core.dex.instructions.args.InsnArg;
import jadx.core.dex.instructions.args.RegisterArg; import jadx.core.dex.instructions.args.RegisterArg;
...@@ -14,6 +15,7 @@ public class PhiInsn extends InsnNode { ...@@ -14,6 +15,7 @@ public class PhiInsn extends InsnNode {
for (int i = 0; i < predecessors; i++) { for (int i = 0; i < predecessors; i++) {
addReg(regNum, ArgType.UNKNOWN); addReg(regNum, ArgType.UNKNOWN);
} }
add(AFlag.DONT_INLINE);
} }
@Override @Override
......
...@@ -9,6 +9,8 @@ import java.util.ArrayList; ...@@ -9,6 +9,8 @@ import java.util.ArrayList;
import java.util.Arrays; import java.util.Arrays;
import java.util.List; import java.util.List;
import org.jetbrains.annotations.Nullable;
public abstract class ArgType { public abstract class ArgType {
public static final ArgType INT = primitive(PrimitiveType.INT); public static final ArgType INT = primitive(PrimitiveType.INT);
...@@ -93,10 +95,28 @@ public abstract class ArgType { ...@@ -93,10 +95,28 @@ public abstract class ArgType {
} }
private abstract static class KnownType extends ArgType { private abstract static class KnownType extends ArgType {
private static final PrimitiveType[] EMPTY_POSSIBLES = new PrimitiveType[0];
@Override @Override
public boolean isTypeKnown() { public boolean isTypeKnown() {
return true; return true;
} }
@Override
public boolean contains(PrimitiveType type) {
return getPrimitiveType() == type;
}
@Override
public ArgType selectFirst() {
return null;
}
@Override
public PrimitiveType[] getPossibleTypes() {
return EMPTY_POSSIBLES;
}
} }
private static final class PrimitiveArg extends KnownType { private static final class PrimitiveArg extends KnownType {
...@@ -269,6 +289,7 @@ public abstract class ArgType { ...@@ -269,6 +289,7 @@ public abstract class ArgType {
} }
private static final class ArrayArg extends KnownType { private static final class ArrayArg extends KnownType {
public static final PrimitiveType[] ARRAY_POSSIBLES = new PrimitiveType[]{PrimitiveType.ARRAY};
private final ArgType arrayElement; private final ArgType arrayElement;
public ArrayArg(ArgType arrayElement) { public ArrayArg(ArgType arrayElement) {
...@@ -292,6 +313,21 @@ public abstract class ArgType { ...@@ -292,6 +313,21 @@ public abstract class ArgType {
} }
@Override @Override
public boolean isTypeKnown() {
return arrayElement.isTypeKnown();
}
@Override
public ArgType selectFirst() {
return array(arrayElement.selectFirst());
}
@Override
public PrimitiveType[] getPossibleTypes() {
return ARRAY_POSSIBLES;
}
@Override
public int getArrayDimension() { public int getArrayDimension() {
return 1 + arrayElement.getArrayDimension(); return 1 + arrayElement.getArrayDimension();
} }
...@@ -343,8 +379,10 @@ public abstract class ArgType { ...@@ -343,8 +379,10 @@ public abstract class ArgType {
@Override @Override
public ArgType selectFirst() { public ArgType selectFirst() {
PrimitiveType f = possibleTypes[0]; PrimitiveType f = possibleTypes[0];
if (f == PrimitiveType.OBJECT || f == PrimitiveType.ARRAY) { if (contains(PrimitiveType.OBJECT)) {
return object(Consts.CLASS_OBJECT); return OBJECT;
} else if (contains(PrimitiveType.ARRAY)) {
return array(OBJECT);
} else { } else {
return primitive(f); return primitive(f);
} }
...@@ -428,18 +466,13 @@ public abstract class ArgType { ...@@ -428,18 +466,13 @@ public abstract class ArgType {
return this; return this;
} }
public boolean contains(PrimitiveType type) { public abstract boolean contains(PrimitiveType type);
throw new UnsupportedOperationException();
}
public ArgType selectFirst() { public abstract ArgType selectFirst();
throw new UnsupportedOperationException();
}
public PrimitiveType[] getPossibleTypes() { public abstract PrimitiveType[] getPossibleTypes();
return null;
}
@Nullable
public static ArgType merge(ArgType a, ArgType b) { public static ArgType merge(ArgType a, ArgType b) {
if (a == null || b == null) { if (a == null || b == null) {
return null; return null;
...@@ -458,13 +491,18 @@ public abstract class ArgType { ...@@ -458,13 +491,18 @@ public abstract class ArgType {
if (a == UNKNOWN) { if (a == UNKNOWN) {
return b; return b;
} }
if (a.isArray()) {
return mergeArrays((ArrayArg) a, b);
} else if (b.isArray()) {
return mergeArrays((ArrayArg) b, a);
}
if (!a.isTypeKnown()) { if (!a.isTypeKnown()) {
if (b.isTypeKnown()) { if (b.isTypeKnown()) {
if (a.contains(b.getPrimitiveType())) { if (a.contains(b.getPrimitiveType())) {
return b; return b;
} else {
return null;
} }
return null;
} else { } else {
// both types unknown // both types unknown
List<PrimitiveType> types = new ArrayList<PrimitiveType>(); List<PrimitiveType> types = new ArrayList<PrimitiveType>();
...@@ -475,7 +513,8 @@ public abstract class ArgType { ...@@ -475,7 +513,8 @@ public abstract class ArgType {
} }
if (types.isEmpty()) { if (types.isEmpty()) {
return null; return null;
} else if (types.size() == 1) { }
if (types.size() == 1) {
PrimitiveType nt = types.get(0); PrimitiveType nt = types.get(0);
if (nt == PrimitiveType.OBJECT || nt == PrimitiveType.ARRAY) { if (nt == PrimitiveType.OBJECT || nt == PrimitiveType.ARRAY) {
return unknown(nt); return unknown(nt);
...@@ -499,35 +538,38 @@ public abstract class ArgType { ...@@ -499,35 +538,38 @@ public abstract class ArgType {
String bObj = b.getObject(); String bObj = b.getObject();
if (aObj.equals(bObj)) { if (aObj.equals(bObj)) {
return a.getGenericTypes() != null ? a : b; return a.getGenericTypes() != null ? a : b;
} else if (aObj.equals(Consts.CLASS_OBJECT)) { }
if (aObj.equals(Consts.CLASS_OBJECT)) {
return b; return b;
} else if (bObj.equals(Consts.CLASS_OBJECT)) { }
if (bObj.equals(Consts.CLASS_OBJECT)) {
return a; return a;
} else { }
// different objects
String obj = clsp.getCommonAncestor(aObj, bObj); String obj = clsp.getCommonAncestor(aObj, bObj);
return obj == null ? null : object(obj); return obj == null ? null : object(obj);
} }
if (a.isPrimitive() && b.isPrimitive() && a.getRegCount() == b.getRegCount()) {
return primitive(PrimitiveType.getSmaller(a.getPrimitiveType(), b.getPrimitiveType()));
}
} }
if (a.isArray()) { return null;
}
private static ArgType mergeArrays(ArrayArg array, ArgType b) {
if (b.isArray()) { if (b.isArray()) {
ArgType ea = a.getArrayElement(); ArgType ea = array.getArrayElement();
ArgType eb = b.getArrayElement(); ArgType eb = b.getArrayElement();
if (ea.isPrimitive() && eb.isPrimitive()) { if (ea.isPrimitive() && eb.isPrimitive()) {
return OBJECT; return OBJECT;
} else { }
ArgType res = merge(ea, eb); ArgType res = merge(ea, eb);
return res == null ? null : array(res); return res == null ? null : array(res);
} }
} else if (b.equals(OBJECT)) { if (b.contains(PrimitiveType.ARRAY)) {
return OBJECT; return array;
} else {
return null;
}
}
if (a.isPrimitive() && b.isPrimitive() && a.getRegCount() == b.getRegCount()) {
return primitive(PrimitiveType.getSmaller(a.getPrimitiveType(), b.getPrimitiveType()));
} }
if (b.equals(OBJECT)) {
return OBJECT;
} }
return null; return null;
} }
......
...@@ -34,6 +34,10 @@ public abstract class InsnArg extends Typed { ...@@ -34,6 +34,10 @@ public abstract class InsnArg extends Typed {
return new TypeImmutableArg(regNum, type); return new TypeImmutableArg(regNum, type);
} }
public static RegisterArg reg(int regNum, ArgType type, boolean typeImmutable) {
return typeImmutable ? new TypeImmutableArg(regNum, type) : new RegisterArg(regNum, type);
}
public static LiteralArg lit(long literal, ArgType type) { public static LiteralArg lit(long literal, ArgType type) {
return new LiteralArg(literal, type); return new LiteralArg(literal, type);
} }
......
package jadx.core.dex.visitors; package jadx.core.dex.visitors;
import jadx.core.dex.attributes.AFlag;
import jadx.core.dex.attributes.AType; import jadx.core.dex.attributes.AType;
import jadx.core.dex.instructions.IfNode; import jadx.core.dex.instructions.IfNode;
import jadx.core.dex.instructions.InsnType; import jadx.core.dex.instructions.InsnType;
...@@ -59,6 +60,7 @@ public class BlockProcessingHelper { ...@@ -59,6 +60,7 @@ public class BlockProcessingHelper {
RegisterArg resArg = me.getResult(); RegisterArg resArg = me.getResult();
resArg = InsnArg.reg(resArg.getRegNum(), type); resArg = InsnArg.reg(resArg.getRegNum(), type);
me.setResult(resArg); me.setResult(resArg);
me.add(AFlag.DONT_INLINE);
excHandler.setArg(resArg); excHandler.setArg(resArg);
} }
......
...@@ -2,7 +2,6 @@ package jadx.core.dex.visitors; ...@@ -2,7 +2,6 @@ package jadx.core.dex.visitors;
import jadx.core.dex.attributes.AFlag; import jadx.core.dex.attributes.AFlag;
import jadx.core.dex.instructions.InsnType; import jadx.core.dex.instructions.InsnType;
import jadx.core.dex.instructions.PhiInsn;
import jadx.core.dex.instructions.args.InsnArg; import jadx.core.dex.instructions.args.InsnArg;
import jadx.core.dex.instructions.args.InsnWrapArg; import jadx.core.dex.instructions.args.InsnWrapArg;
import jadx.core.dex.instructions.args.RegisterArg; import jadx.core.dex.instructions.args.RegisterArg;
...@@ -213,9 +212,7 @@ public class CodeShrinker extends AbstractVisitor { ...@@ -213,9 +212,7 @@ public class CodeShrinker extends AbstractVisitor {
continue; continue;
} }
InsnNode assignInsn = sVar.getAssign().getParentInsn(); InsnNode assignInsn = sVar.getAssign().getParentInsn();
if (assignInsn == null if (assignInsn == null || assignInsn.contains(AFlag.DONT_INLINE)) {
|| assignInsn instanceof PhiInsn
|| assignInsn.getType() == InsnType.MOVE_EXCEPTION) {
continue; continue;
} }
int assignPos = insnList.getIndex(assignInsn); int assignPos = insnList.getIndex(assignInsn);
......
package jadx.core.dex.visitors; package jadx.core.dex.visitors;
import jadx.core.dex.attributes.AFlag;
import jadx.core.dex.info.FieldInfo; import jadx.core.dex.info.FieldInfo;
import jadx.core.dex.instructions.IndexInsnNode; import jadx.core.dex.instructions.IndexInsnNode;
import jadx.core.dex.instructions.InsnType; import jadx.core.dex.instructions.InsnType;
...@@ -52,21 +53,13 @@ public class ConstInlinerVisitor extends AbstractVisitor { ...@@ -52,21 +53,13 @@ public class ConstInlinerVisitor extends AbstractVisitor {
SSAVar sVar = insn.getResult().getSVar(); SSAVar sVar = insn.getResult().getSVar();
if (lit == 0) { if (lit == 0) {
// don't inline null object if: if (checkObjectInline(sVar)) {
// - used as instance arg in invoke instruction if (sVar.getUseCount() == 1) {
for (RegisterArg useArg : sVar.getUseList()) { insn.getResult().getAssignInsn().add(AFlag.DONT_INLINE);
InsnNode parentInsn = useArg.getParentInsn();
if (parentInsn != null) {
InsnType insnType = parentInsn.getType();
if (insnType == InsnType.INVOKE) {
InvokeNode inv = (InvokeNode) parentInsn;
if (inv.getInvokeType() != InvokeType.STATIC
&& inv.getArg(0) == useArg) {
return false;
}
}
} }
return false;
} }
} }
ArgType resType = insn.getResult().getType(); ArgType resType = insn.getResult().getType();
// make sure arg has correct type // make sure arg has correct type
...@@ -76,6 +69,32 @@ public class ConstInlinerVisitor extends AbstractVisitor { ...@@ -76,6 +69,32 @@ public class ConstInlinerVisitor extends AbstractVisitor {
return replaceConst(mth, sVar, lit); return replaceConst(mth, sVar, lit);
} }
/**
* Don't inline null object if:
* - used as instance arg in invoke instruction
* - used in 'array.length'
*/
private static boolean checkObjectInline(SSAVar sVar) {
for (RegisterArg useArg : sVar.getUseList()) {
InsnNode insn = useArg.getParentInsn();
if (insn != null) {
InsnType insnType = insn.getType();
if (insnType == InsnType.INVOKE) {
InvokeNode inv = (InvokeNode) insn;
if (inv.getInvokeType() != InvokeType.STATIC
&& inv.getArg(0) == useArg) {
return true;
}
} else if (insnType == InsnType.ARRAY_LENGTH) {
if (insn.getArg(0) == useArg) {
return true;
}
}
}
}
return false;
}
private static boolean replaceConst(MethodNode mth, SSAVar sVar, long literal) { private static boolean replaceConst(MethodNode mth, SSAVar sVar, long literal) {
List<RegisterArg> use = new ArrayList<RegisterArg>(sVar.getUseList()); List<RegisterArg> use = new ArrayList<RegisterArg>(sVar.getUseList());
int replaceCount = 0; int replaceCount = 0;
......
...@@ -107,7 +107,7 @@ public class LoopRegionVisitor extends AbstractVisitor implements IRegionVisitor ...@@ -107,7 +107,7 @@ public class LoopRegionVisitor extends AbstractVisitor implements IRegionVisitor
List<RegisterArg> args = new LinkedList<RegisterArg>(); List<RegisterArg> args = new LinkedList<RegisterArg>();
incrInsn.getRegisterArgs(args); incrInsn.getRegisterArgs(args);
for (RegisterArg iArg : args) { for (RegisterArg iArg : args) {
if (assignOnlyInLoop(mth, loopRegion, (RegisterArg) iArg)) { if (assignOnlyInLoop(mth, loopRegion, iArg)) {
return false; return false;
} }
} }
......
...@@ -40,7 +40,7 @@ public class ErrorsCounter { ...@@ -40,7 +40,7 @@ public class ErrorsCounter {
if (e.getClass() == JadxOverflowException.class) { if (e.getClass() == JadxOverflowException.class) {
// don't print full stack trace // don't print full stack trace
e = new JadxOverflowException(e.getMessage()); e = new JadxOverflowException(e.getMessage());
LOG.error(msg); LOG.error(msg + ", message: " + e.getMessage());
} else { } else {
LOG.error(msg, e); LOG.error(msg, e);
} }
......
...@@ -35,11 +35,7 @@ public class InsnUtils { ...@@ -35,11 +35,7 @@ public class InsnUtils {
} }
public static String insnTypeToString(InsnType type) { public static String insnTypeToString(InsnType type) {
return insnTypeToString(type.toString()); return type.toString() + " ";
}
public static String insnTypeToString(String str) {
return String.format("%s ", str);
} }
public static String indexToString(Object index) { public static String indexToString(Object index) {
...@@ -49,7 +45,7 @@ public class InsnUtils { ...@@ -49,7 +45,7 @@ public class InsnUtils {
if (index instanceof String) { if (index instanceof String) {
return "\"" + index + "\""; return "\"" + index + "\"";
} else { } else {
return " " + index; return index.toString();
} }
} }
} }
...@@ -201,6 +201,7 @@ public abstract class IntegrationTest extends TestUtils { ...@@ -201,6 +201,7 @@ public abstract class IntegrationTest extends TestUtils {
dynamicCompiler = new DynamicCompiler(cls); dynamicCompiler = new DynamicCompiler(cls);
boolean result = dynamicCompiler.compile(); boolean result = dynamicCompiler.compile();
assertTrue("Compilation failed", result); assertTrue("Compilation failed", result);
System.out.println("Compilation: PASSED");
} catch (Exception e) { } catch (Exception e) {
e.printStackTrace(); e.printStackTrace();
fail(e.getMessage()); fail(e.getMessage());
......
...@@ -20,6 +20,7 @@ import static jadx.core.dex.instructions.args.ArgType.OBJECT; ...@@ -20,6 +20,7 @@ import static jadx.core.dex.instructions.args.ArgType.OBJECT;
import static jadx.core.dex.instructions.args.ArgType.STRING; import static jadx.core.dex.instructions.args.ArgType.STRING;
import static jadx.core.dex.instructions.args.ArgType.UNKNOWN; import static jadx.core.dex.instructions.args.ArgType.UNKNOWN;
import static jadx.core.dex.instructions.args.ArgType.UNKNOWN_OBJECT; import static jadx.core.dex.instructions.args.ArgType.UNKNOWN_OBJECT;
import static jadx.core.dex.instructions.args.ArgType.array;
import static jadx.core.dex.instructions.args.ArgType.genericType; import static jadx.core.dex.instructions.args.ArgType.genericType;
import static jadx.core.dex.instructions.args.ArgType.object; import static jadx.core.dex.instructions.args.ArgType.object;
import static jadx.core.dex.instructions.args.ArgType.unknown; import static jadx.core.dex.instructions.args.ArgType.unknown;
...@@ -59,13 +60,6 @@ public class TypeMergeTest { ...@@ -59,13 +60,6 @@ public class TypeMergeTest {
unknown(PrimitiveType.OBJECT, PrimitiveType.ARRAY), unknown(PrimitiveType.OBJECT, PrimitiveType.ARRAY),
unknown(PrimitiveType.OBJECT)); unknown(PrimitiveType.OBJECT));
check(ArgType.array(INT), ArgType.array(BYTE), ArgType.OBJECT);
first(ArgType.array(INT), ArgType.array(INT));
first(ArgType.array(STRING), ArgType.array(STRING));
first(OBJECT, ArgType.array(INT));
first(OBJECT, ArgType.array(STRING));
ArgType objExc = object("java.lang.Exception"); ArgType objExc = object("java.lang.Exception");
ArgType objThr = object("java.lang.Throwable"); ArgType objThr = object("java.lang.Throwable");
ArgType objIO = object("java.io.IOException"); ArgType objIO = object("java.io.IOException");
...@@ -83,6 +77,18 @@ public class TypeMergeTest { ...@@ -83,6 +77,18 @@ public class TypeMergeTest {
first(generic, objExc); first(generic, objExc);
} }
@Test
public void testArrayMerge() {
check(array(INT), array(BYTE), ArgType.OBJECT);
first(array(INT), array(INT));
first(array(STRING), array(STRING));
first(OBJECT, array(INT));
first(OBJECT, array(STRING));
first(array(unknown(PrimitiveType.INT, PrimitiveType.FLOAT)), unknown(PrimitiveType.ARRAY));
}
private void first(ArgType t1, ArgType t2) { private void first(ArgType t1, ArgType t2) {
check(t1, t2, t1); check(t1, t2, t1);
} }
......
...@@ -5,6 +5,7 @@ import jadx.tests.api.IntegrationTest; ...@@ -5,6 +5,7 @@ import jadx.tests.api.IntegrationTest;
import org.junit.Test; import org.junit.Test;
import static jadx.tests.api.utils.JadxMatchers.containsOne;
import static org.hamcrest.CoreMatchers.containsString; import static org.hamcrest.CoreMatchers.containsString;
import static org.hamcrest.CoreMatchers.not; import static org.hamcrest.CoreMatchers.not;
import static org.junit.Assert.assertThat; import static org.junit.Assert.assertThat;
...@@ -28,13 +29,19 @@ public class TestWrongCode extends IntegrationTest { ...@@ -28,13 +29,19 @@ public class TestWrongCode extends IntegrationTest {
@Test @Test
public void test() { public void test() {
disableCompilation();
ClassNode cls = getClassNode(TestCls.class); ClassNode cls = getClassNode(TestCls.class);
String code = cls.getCode().toString(); String code = cls.getCode().toString();
assertThat(code, not(containsString("return false.length;"))); assertThat(code, not(containsString("return false.length;")));
assertThat(code, containsString("return null.length;")); assertThat(code, containsOne("int[] a = null;"));
assertThat(code, containsOne("return a.length;"));
assertThat(code, containsString("return a == 0 ? a : a;")); assertThat(code, containsString("return a == 0 ? a : a;"));
} }
@Test
public void testNoDebug() {
noDebugInfo();
getClassNode(TestCls.class);
}
} }
package jadx.tests.integration.arrays;
import jadx.core.dex.nodes.ClassNode;
import jadx.tests.api.IntegrationTest;
import org.junit.Test;
import static jadx.tests.api.utils.JadxMatchers.containsOne;
import static org.junit.Assert.assertThat;
public class TestArrays extends IntegrationTest {
public static class TestCls {
public int test1(int i) {
int[] a = new int[]{1, 2, 3, 5};
return a[i];
}
public int test2(int i) {
int[][] a = new int[i][i + 1];
return a.length;
}
}
@Test
public void test() {
noDebugInfo();
ClassNode cls = getClassNode(TestCls.class);
String code = cls.getCode().toString();
assertThat(code, containsOne("return new int[]{1, 2, 3, 5}[i];"));
}
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment