Commit c1292dff authored by Skylot's avatar Skylot

core refactor: don't use static field in ArgType class

parent 1d81cab4
......@@ -4,6 +4,7 @@ import jadx.core.dex.instructions.args.ArgType;
import jadx.core.dex.instructions.args.InsnArg;
import jadx.core.dex.instructions.args.LiteralArg;
import jadx.core.dex.instructions.args.PrimitiveType;
import jadx.core.dex.nodes.DexNode;
import jadx.core.dex.nodes.InsnNode;
import jadx.core.utils.exceptions.JadxRuntimeException;
......@@ -57,8 +58,8 @@ public final class FillArrayNode extends InsnNode {
return elemType;
}
public void mergeElementType(ArgType foundElemType) {
ArgType r = ArgType.merge(elemType, foundElemType);
public void mergeElementType(DexNode dex, ArgType foundElemType) {
ArgType r = ArgType.merge(dex, elemType, foundElemType);
if (r != null) {
elemType = r;
}
......
package jadx.core.dex.instructions.args;
import jadx.core.Consts;
import jadx.core.clsp.ClspGraph;
import jadx.core.dex.nodes.DexNode;
import jadx.core.dex.nodes.parser.SignatureParser;
import jadx.core.utils.Utils;
......@@ -45,16 +45,6 @@ public abstract class ArgType {
protected int hash;
private static ClspGraph clsp;
public static void setClsp(ClspGraph clsp) {
ArgType.clsp = clsp;
}
public static boolean isClspSet() {
return ArgType.clsp != null;
}
private static ArgType primitive(PrimitiveType stype) {
return new PrimitiveArg(stype);
}
......@@ -474,29 +464,28 @@ public abstract class ArgType {
public abstract PrimitiveType[] getPossibleTypes();
@Nullable
public static ArgType merge(ArgType a, ArgType b) {
public static ArgType merge(@Nullable DexNode dex, ArgType a, ArgType b) {
if (a == null || b == null) {
return null;
}
if (a.equals(b)) {
return a;
}
ArgType res = mergeInternal(a, b);
ArgType res = mergeInternal(dex, a, b);
if (res == null) {
res = mergeInternal(b, a); // swap
res = mergeInternal(dex, b, a); // swap
}
return res;
}
private static ArgType mergeInternal(ArgType a, ArgType b) {
private static ArgType mergeInternal(@Nullable DexNode dex, ArgType a, ArgType b) {
if (a == UNKNOWN) {
return b;
}
if (a.isArray()) {
return mergeArrays((ArrayArg) a, b);
return mergeArrays(dex, (ArrayArg) a, b);
} else if (b.isArray()) {
return mergeArrays((ArrayArg) b, a);
return mergeArrays(dex, (ArrayArg) b, a);
}
if (!a.isTypeKnown()) {
if (b.isTypeKnown()) {
......@@ -546,7 +535,10 @@ public abstract class ArgType {
if (bObj.equals(Consts.CLASS_OBJECT)) {
return a;
}
String obj = clsp.getCommonAncestor(aObj, bObj);
if (dex == null) {
return null;
}
String obj = dex.root().getClsp().getCommonAncestor(aObj, bObj);
return obj == null ? null : object(obj);
}
if (a.isPrimitive() && b.isPrimitive() && a.getRegCount() == b.getRegCount()) {
......@@ -556,14 +548,14 @@ public abstract class ArgType {
return null;
}
private static ArgType mergeArrays(ArrayArg array, ArgType b) {
private static ArgType mergeArrays(DexNode dex, ArrayArg array, ArgType b) {
if (b.isArray()) {
ArgType ea = array.getArrayElement();
ArgType eb = b.getArrayElement();
if (ea.isPrimitive() && eb.isPrimitive()) {
return OBJECT;
}
ArgType res = merge(ea, eb);
ArgType res = merge(dex, ea, eb);
return res == null ? null : array(res);
}
if (b.contains(PrimitiveType.ARRAY)) {
......@@ -575,25 +567,25 @@ public abstract class ArgType {
return null;
}
public static boolean isCastNeeded(ArgType from, ArgType to) {
public static boolean isCastNeeded(DexNode dex, ArgType from, ArgType to) {
if (from.equals(to)) {
return false;
}
if (from.isObject() && to.isObject()
&& clsp.isImplements(from.getObject(), to.getObject())) {
&& dex.root().getClsp().isImplements(from.getObject(), to.getObject())) {
return false;
}
return true;
}
public static boolean isInstanceOf(ArgType type, ArgType of) {
public static boolean isInstanceOf(DexNode dex, ArgType type, ArgType of) {
if (type.equals(of)) {
return true;
}
if (!type.isObject() || !of.isObject()) {
return false;
}
return clsp.isImplements(type.getObject(), of.getObject());
return dex.root().getClsp().isImplements(type.getObject(), of.getObject());
}
public static ArgType parse(String type) {
......
......@@ -17,7 +17,7 @@ public final class LiteralArg extends InsnArg {
} else if (!type.isTypeKnown()
&& !type.contains(PrimitiveType.LONG)
&& !type.contains(PrimitiveType.DOUBLE)) {
ArgType m = ArgType.merge(type, ArgType.NARROW_NUMBERS);
ArgType m = ArgType.merge(null, type, ArgType.NARROW_NUMBERS);
if (m != null) {
type = m;
}
......
package jadx.core.dex.instructions.args;
import jadx.core.dex.nodes.DexNode;
public abstract class Typed {
protected ArgType type;
......@@ -16,8 +18,8 @@ public abstract class Typed {
return false;
}
public boolean merge(ArgType newType) {
ArgType m = ArgType.merge(type, newType);
public boolean merge(DexNode dex, ArgType newType) {
ArgType m = ArgType.merge(dex, type, newType);
if (m != null && !m.equals(type)) {
setType(m);
return true;
......@@ -25,7 +27,7 @@ public abstract class Typed {
return false;
}
public boolean merge(InsnArg arg) {
return merge(arg.getType());
public boolean merge(DexNode dex, InsnArg arg) {
return merge(dex, arg.getType());
}
}
......@@ -6,7 +6,6 @@ import jadx.api.ResourceType;
import jadx.api.ResourcesLoader;
import jadx.core.clsp.ClspGraph;
import jadx.core.dex.info.ClassInfo;
import jadx.core.dex.instructions.args.ArgType;
import jadx.core.utils.ErrorsCounter;
import jadx.core.utils.exceptions.DecodeException;
import jadx.core.utils.exceptions.JadxException;
......@@ -36,6 +35,7 @@ public class RootNode {
@Nullable
private String appPackage;
private ClassNode appResClass;
private ClspGraph clsp;
public RootNode(IJadxArgs args) {
this.args = args;
......@@ -112,7 +112,7 @@ public class RootNode {
public void initClassPath() throws DecodeException {
try {
if (!ArgType.isClspSet()) {
if (this.clsp == null) {
ClspGraph clsp = new ClspGraph();
clsp.load();
......@@ -122,7 +122,7 @@ public class RootNode {
}
clsp.addApp(classes);
ArgType.setClsp(clsp);
this.clsp = clsp;
}
} catch (IOException e) {
throw new DecodeException("Error loading classpath", e);
......@@ -166,6 +166,10 @@ public class RootNode {
return dexNodes;
}
public ClspGraph getClsp() {
return clsp;
}
public ErrorsCounter getErrorsCounter() {
return errorsCounter;
}
......
......@@ -13,6 +13,7 @@ import jadx.core.dex.instructions.args.PrimitiveType;
import jadx.core.dex.instructions.args.RegisterArg;
import jadx.core.dex.instructions.args.SSAVar;
import jadx.core.dex.nodes.BlockNode;
import jadx.core.dex.nodes.DexNode;
import jadx.core.dex.nodes.FieldNode;
import jadx.core.dex.nodes.InsnNode;
import jadx.core.dex.nodes.MethodNode;
......@@ -65,7 +66,7 @@ public class ConstInlineVisitor extends AbstractVisitor {
ArgType resType = insn.getResult().getType();
// make sure arg has correct type
if (!arg.getType().isTypeKnown()) {
arg.merge(resType);
arg.merge(mth.dex(), resType);
}
return replaceConst(mth, insn, lit);
}
......@@ -149,30 +150,31 @@ public class ConstInlineVisitor extends AbstractVisitor {
* but contains some expensive operations needed only after constant inline
*/
private static void fixTypes(MethodNode mth, InsnNode insn, LiteralArg litArg) {
DexNode dex = mth.dex();
PostTypeInference.process(mth, insn);
switch (insn.getType()) {
case CONST:
insn.getArg(0).merge(insn.getResult());
insn.getArg(0).merge(dex, insn.getResult());
break;
case MOVE:
insn.getResult().merge(insn.getArg(0));
insn.getArg(0).merge(insn.getResult());
insn.getResult().merge(dex, insn.getArg(0));
insn.getArg(0).merge(dex, insn.getResult());
break;
case IPUT:
case SPUT:
IndexInsnNode node = (IndexInsnNode) insn;
insn.getArg(0).merge(((FieldInfo) node.getIndex()).getType());
insn.getArg(0).merge(dex, ((FieldInfo) node.getIndex()).getType());
break;
case IF: {
InsnArg arg0 = insn.getArg(0);
InsnArg arg1 = insn.getArg(1);
if (arg0 == litArg) {
arg0.merge(arg1);
arg0.merge(dex, arg1);
} else {
arg1.merge(arg0);
arg1.merge(dex, arg0);
}
break;
}
......@@ -181,15 +183,15 @@ public class ConstInlineVisitor extends AbstractVisitor {
InsnArg arg0 = insn.getArg(0);
InsnArg arg1 = insn.getArg(1);
if (arg0 == litArg) {
arg0.merge(arg1);
arg0.merge(dex, arg1);
} else {
arg1.merge(arg0);
arg1.merge(dex, arg0);
}
break;
case RETURN:
if (insn.getArgsCount() != 0) {
insn.getArg(0).merge(mth.getReturnType());
insn.getArg(0).merge(dex, mth.getReturnType());
}
break;
......@@ -207,26 +209,26 @@ public class ConstInlineVisitor extends AbstractVisitor {
} else {
type = mth.getParentClass().getClassInfo().getType();
}
arg.merge(type);
arg.merge(dex, type);
}
k++;
}
break;
case ARITH:
litArg.merge(insn.getResult());
litArg.merge(dex, insn.getResult());
break;
case APUT:
case AGET:
if (litArg == insn.getArg(1)) {
litArg.merge(ArgType.INT);
litArg.merge(dex, ArgType.INT);
}
break;
case NEW_ARRAY:
if (litArg == insn.getArg(0)) {
litArg.merge(ArgType.INT);
litArg.merge(dex, ArgType.INT);
}
break;
......
......@@ -108,9 +108,9 @@ public class ModVisitor extends AbstractVisitor {
if (next < size) {
InsnNode ni = block.getInstructions().get(next);
if (ni.getType() == InsnType.FILL_ARRAY) {
ni.getResult().merge(insn.getResult());
ni.getResult().merge(mth.dex(), insn.getResult());
ArgType arrType = ((NewArrayNode) insn).getArrayType();
((FillArrayNode) ni).mergeElementType(arrType.getArrayElement());
((FillArrayNode) ni).mergeElementType(mth.dex(), arrType.getArrayElement());
remover.add(insn);
}
}
......@@ -263,7 +263,7 @@ public class ModVisitor extends AbstractVisitor {
throw new JadxRuntimeException("Null array element type");
}
}
insn.mergeElementType(elType);
insn.mergeElementType(mth.dex(), elType);
elType = insn.getElementType();
List<LiteralArg> list = insn.getLiteralArgs();
......
......@@ -88,7 +88,7 @@ public class SimplifyVisitor extends AbstractVisitor {
}
}
ArgType castToType = (ArgType) ((IndexInsnNode) insn).getIndex();
if (!ArgType.isCastNeeded(argType, castToType)) {
if (!ArgType.isCastNeeded(mth.dex(), argType, castToType)) {
InsnNode insnNode = new InsnNode(InsnType.MOVE, 1);
insnNode.setOffset(insn.getOffset());
insnNode.setResult(insn.getResult());
......
......@@ -16,6 +16,7 @@ import jadx.core.dex.instructions.args.LiteralArg;
import jadx.core.dex.instructions.args.RegisterArg;
import jadx.core.dex.instructions.args.SSAVar;
import jadx.core.dex.nodes.BlockNode;
import jadx.core.dex.nodes.DexNode;
import jadx.core.dex.nodes.IBlock;
import jadx.core.dex.nodes.IRegion;
import jadx.core.dex.nodes.InsnNode;
......@@ -254,7 +255,7 @@ public class LoopRegionVisitor extends AbstractVisitor implements IRegionVisitor
} else {
toSkip.add(nextCall);
}
if (iterVar == null || !fixIterableType(iterableArg, iterVar)) {
if (iterVar == null || !fixIterableType(mth.dex(), iterableArg, iterVar)) {
return false;
}
......@@ -266,7 +267,7 @@ public class LoopRegionVisitor extends AbstractVisitor implements IRegionVisitor
return true;
}
private static boolean fixIterableType(InsnArg iterableArg, RegisterArg iterVar) {
private static boolean fixIterableType(DexNode dex, InsnArg iterableArg, RegisterArg iterVar) {
ArgType iterableType = iterableArg.getType();
ArgType varType = iterVar.getType();
if (iterableType.isGeneric()) {
......@@ -282,7 +283,7 @@ public class LoopRegionVisitor extends AbstractVisitor implements IRegionVisitor
iterVar.setType(gType);
return true;
}
if (ArgType.isInstanceOf(gType, varType)) {
if (ArgType.isInstanceOf(dex, gType, varType)) {
return true;
}
LOG.warn("Generic type differs: {} and {}", gType, varType);
......
package jadx.core.dex.visitors.typeinference;
import jadx.core.dex.nodes.BlockNode;
import jadx.core.dex.nodes.DexNode;
import jadx.core.dex.nodes.InsnNode;
import jadx.core.dex.nodes.MethodNode;
import jadx.core.dex.visitors.AbstractVisitor;
......@@ -31,9 +32,10 @@ public class FinishTypeInference extends AbstractVisitor {
} while (change);
// last chance to set correct value (just use first type from 'possible' list)
DexNode dex = mth.dex();
for (BlockNode block : mth.getBasicBlocks()) {
for (InsnNode insn : block.getInstructions()) {
SelectTypeVisitor.visit(insn);
SelectTypeVisitor.visit(dex, insn);
}
}
......
......@@ -8,6 +8,7 @@ import jadx.core.dex.instructions.args.ArgType;
import jadx.core.dex.instructions.args.InsnArg;
import jadx.core.dex.instructions.args.LiteralArg;
import jadx.core.dex.instructions.args.RegisterArg;
import jadx.core.dex.nodes.DexNode;
import jadx.core.dex.nodes.InsnNode;
import jadx.core.dex.nodes.MethodNode;
......@@ -19,6 +20,7 @@ public class PostTypeInference {
}
public static boolean process(MethodNode mth, InsnNode insn) {
DexNode dex = mth.dex();
switch (insn.getType()) {
case CONST:
RegisterArg res = insn.getResult();
......@@ -34,31 +36,31 @@ public class PostTypeInference {
return true;
}
}
return litArg.merge(res);
return litArg.merge(dex, res);
case MOVE: {
boolean change = false;
if (insn.getResult().merge(insn.getArg(0))) {
if (insn.getResult().merge(dex, insn.getArg(0))) {
change = true;
}
if (insn.getArg(0).merge(insn.getResult())) {
if (insn.getArg(0).merge(dex, insn.getResult())) {
change = true;
}
return change;
}
case AGET:
return fixArrayTypes(insn.getArg(0), insn.getResult());
return fixArrayTypes(dex, insn.getArg(0), insn.getResult());
case APUT:
return fixArrayTypes(insn.getArg(0), insn.getArg(2));
return fixArrayTypes(dex, insn.getArg(0), insn.getArg(2));
case IF: {
boolean change = false;
if (insn.getArg(1).merge(insn.getArg(0))) {
if (insn.getArg(1).merge(dex, insn.getArg(0))) {
change = true;
}
if (insn.getArg(0).merge(insn.getArg(1))) {
if (insn.getArg(0).merge(dex, insn.getArg(1))) {
change = true;
}
return change;
......@@ -138,12 +140,12 @@ public class PostTypeInference {
return false;
}
private static boolean fixArrayTypes(InsnArg array, InsnArg elem) {
private static boolean fixArrayTypes(DexNode dex, InsnArg array, InsnArg elem) {
boolean change = false;
if (!elem.getType().isTypeKnown() && elem.merge(array.getType().getArrayElement())) {
if (!elem.getType().isTypeKnown() && elem.merge(dex, array.getType().getArrayElement())) {
change = true;
}
if (!array.getType().isTypeKnown() && array.merge(ArgType.array(elem.getType()))) {
if (!array.getType().isTypeKnown() && array.merge(dex, ArgType.array(elem.getType()))) {
change = true;
}
return change;
......
......@@ -2,6 +2,7 @@ package jadx.core.dex.visitors.typeinference;
import jadx.core.dex.instructions.args.ArgType;
import jadx.core.dex.instructions.args.InsnArg;
import jadx.core.dex.nodes.DexNode;
import jadx.core.dex.nodes.InsnNode;
public class SelectTypeVisitor {
......@@ -9,21 +10,21 @@ public class SelectTypeVisitor {
private SelectTypeVisitor() {
}
public static void visit(InsnNode insn) {
public static void visit(DexNode dex, InsnNode insn) {
InsnArg res = insn.getResult();
if (res != null && !res.getType().isTypeKnown()) {
selectType(res);
selectType(dex, res);
}
for (InsnArg arg : insn.getArguments()) {
if (!arg.getType().isTypeKnown()) {
selectType(arg);
selectType(dex, arg);
}
}
}
private static void selectType(InsnArg arg) {
private static void selectType(DexNode dex, InsnArg arg) {
ArgType t = arg.getType();
ArgType newType = ArgType.merge(t, t.selectFirst());
ArgType newType = ArgType.merge(dex, t, t.selectFirst());
arg.setType(newType);
}
......
......@@ -5,6 +5,7 @@ import jadx.core.dex.instructions.args.ArgType;
import jadx.core.dex.instructions.args.InsnArg;
import jadx.core.dex.instructions.args.RegisterArg;
import jadx.core.dex.instructions.args.SSAVar;
import jadx.core.dex.nodes.DexNode;
import jadx.core.dex.nodes.MethodNode;
import jadx.core.dex.visitors.AbstractVisitor;
import jadx.core.utils.exceptions.JadxException;
......@@ -18,9 +19,10 @@ public class TypeInference extends AbstractVisitor {
if (mth.isNoCode()) {
return;
}
DexNode dex = mth.dex();
for (SSAVar var : mth.getSVars()) {
// inference variable type
ArgType type = processType(var);
ArgType type = processType(dex, var);
if (type == null) {
type = ArgType.UNKNOWN;
}
......@@ -40,7 +42,7 @@ public class TypeInference extends AbstractVisitor {
}
}
private static ArgType processType(SSAVar var) {
private static ArgType processType(DexNode dex, SSAVar var) {
RegisterArg assign = var.getAssign();
List<RegisterArg> useList = var.getUseList();
if (useList.isEmpty() || var.isTypeImmutable()) {
......@@ -49,7 +51,7 @@ public class TypeInference extends AbstractVisitor {
ArgType type = assign.getType();
for (RegisterArg arg : useList) {
ArgType useType = arg.getType();
ArgType newType = ArgType.merge(type, useType);
ArgType newType = ArgType.merge(dex, type, useType);
if (newType != null) {
type = newType;
}
......
......@@ -2,6 +2,8 @@ package jadx.tests.functional;
import jadx.core.clsp.ClspGraph;
import jadx.core.dex.instructions.args.ArgType;
import jadx.core.dex.nodes.DexNode;
import jadx.core.dex.nodes.RootNode;
import jadx.core.utils.exceptions.DecodeException;
import java.io.IOException;
......@@ -13,19 +15,25 @@ import static jadx.core.dex.instructions.args.ArgType.STRING;
import static jadx.core.dex.instructions.args.ArgType.object;
import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertTrue;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;
public class JadxClasspathTest {
private static final String JAVA_LANG_EXCEPTION = "java.lang.Exception";
private static final String JAVA_LANG_THROWABLE = "java.lang.Throwable";
ClspGraph clsp;
private DexNode dex;
private ClspGraph clsp;
@Before
public void initClsp() throws IOException, DecodeException {
clsp = new ClspGraph();
clsp.load();
ArgType.setClsp(clsp);
dex = mock(DexNode.class);
RootNode rootNode = mock(RootNode.class);
when(rootNode.getClsp()).thenReturn(clsp);
when(dex.root()).thenReturn(rootNode);
}
@Test
......@@ -36,9 +44,9 @@ public class JadxClasspathTest {
assertTrue(clsp.isImplements(JAVA_LANG_EXCEPTION, JAVA_LANG_THROWABLE));
assertFalse(clsp.isImplements(JAVA_LANG_THROWABLE, JAVA_LANG_EXCEPTION));
assertFalse(ArgType.isCastNeeded(objExc, objThr));
assertTrue(ArgType.isCastNeeded(objThr, objExc));
assertFalse(ArgType.isCastNeeded(dex, objExc, objThr));
assertTrue(ArgType.isCastNeeded(dex, objThr, objExc));
assertTrue(ArgType.isCastNeeded(ArgType.OBJECT, STRING));
assertTrue(ArgType.isCastNeeded(dex, ArgType.OBJECT, STRING));
}
}
......@@ -3,6 +3,8 @@ package jadx.tests.functional;
import jadx.core.clsp.ClspGraph;
import jadx.core.dex.instructions.args.ArgType;
import jadx.core.dex.instructions.args.PrimitiveType;
import jadx.core.dex.nodes.DexNode;
import jadx.core.dex.nodes.RootNode;
import jadx.core.utils.exceptions.DecodeException;
import java.io.IOException;
......@@ -27,14 +29,21 @@ import static jadx.core.dex.instructions.args.ArgType.unknown;
import static org.junit.Assert.assertNotNull;
import static org.junit.Assert.assertNull;
import static org.junit.Assert.assertTrue;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;
public class TypeMergeTest {
private DexNode dex;
@Before
public void initClsp() throws IOException, DecodeException {
ClspGraph clsp = new ClspGraph();
clsp.load();
ArgType.setClsp(clsp);
dex = mock(DexNode.class);
RootNode rootNode = mock(RootNode.class);
when(rootNode.getClsp()).thenReturn(clsp);
when(dex.root()).thenReturn(rootNode);
}
@Test
......@@ -103,7 +112,7 @@ public class TypeMergeTest {
}
private void merge(ArgType t1, ArgType t2, ArgType exp) {
ArgType res = ArgType.merge(t1, t2);
ArgType res = ArgType.merge(dex, t1, t2);
String msg = format(t1, t2, exp, res);
if (exp == null) {
assertNull("Incorrect accept: " + msg, res);
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment