Commit c1292dff authored by Skylot's avatar Skylot

core refactor: don't use static field in ArgType class

parent 1d81cab4
...@@ -4,6 +4,7 @@ import jadx.core.dex.instructions.args.ArgType; ...@@ -4,6 +4,7 @@ import jadx.core.dex.instructions.args.ArgType;
import jadx.core.dex.instructions.args.InsnArg; import jadx.core.dex.instructions.args.InsnArg;
import jadx.core.dex.instructions.args.LiteralArg; import jadx.core.dex.instructions.args.LiteralArg;
import jadx.core.dex.instructions.args.PrimitiveType; import jadx.core.dex.instructions.args.PrimitiveType;
import jadx.core.dex.nodes.DexNode;
import jadx.core.dex.nodes.InsnNode; import jadx.core.dex.nodes.InsnNode;
import jadx.core.utils.exceptions.JadxRuntimeException; import jadx.core.utils.exceptions.JadxRuntimeException;
...@@ -57,8 +58,8 @@ public final class FillArrayNode extends InsnNode { ...@@ -57,8 +58,8 @@ public final class FillArrayNode extends InsnNode {
return elemType; return elemType;
} }
public void mergeElementType(ArgType foundElemType) { public void mergeElementType(DexNode dex, ArgType foundElemType) {
ArgType r = ArgType.merge(elemType, foundElemType); ArgType r = ArgType.merge(dex, elemType, foundElemType);
if (r != null) { if (r != null) {
elemType = r; elemType = r;
} }
......
package jadx.core.dex.instructions.args; package jadx.core.dex.instructions.args;
import jadx.core.Consts; import jadx.core.Consts;
import jadx.core.clsp.ClspGraph; import jadx.core.dex.nodes.DexNode;
import jadx.core.dex.nodes.parser.SignatureParser; import jadx.core.dex.nodes.parser.SignatureParser;
import jadx.core.utils.Utils; import jadx.core.utils.Utils;
...@@ -45,16 +45,6 @@ public abstract class ArgType { ...@@ -45,16 +45,6 @@ public abstract class ArgType {
protected int hash; protected int hash;
private static ClspGraph clsp;
public static void setClsp(ClspGraph clsp) {
ArgType.clsp = clsp;
}
public static boolean isClspSet() {
return ArgType.clsp != null;
}
private static ArgType primitive(PrimitiveType stype) { private static ArgType primitive(PrimitiveType stype) {
return new PrimitiveArg(stype); return new PrimitiveArg(stype);
} }
...@@ -474,29 +464,28 @@ public abstract class ArgType { ...@@ -474,29 +464,28 @@ public abstract class ArgType {
public abstract PrimitiveType[] getPossibleTypes(); public abstract PrimitiveType[] getPossibleTypes();
@Nullable @Nullable
public static ArgType merge(ArgType a, ArgType b) { public static ArgType merge(@Nullable DexNode dex, ArgType a, ArgType b) {
if (a == null || b == null) { if (a == null || b == null) {
return null; return null;
} }
if (a.equals(b)) { if (a.equals(b)) {
return a; return a;
} }
ArgType res = mergeInternal(a, b); ArgType res = mergeInternal(dex, a, b);
if (res == null) { if (res == null) {
res = mergeInternal(b, a); // swap res = mergeInternal(dex, b, a); // swap
} }
return res; return res;
} }
private static ArgType mergeInternal(ArgType a, ArgType b) { private static ArgType mergeInternal(@Nullable DexNode dex, ArgType a, ArgType b) {
if (a == UNKNOWN) { if (a == UNKNOWN) {
return b; return b;
} }
if (a.isArray()) { if (a.isArray()) {
return mergeArrays((ArrayArg) a, b); return mergeArrays(dex, (ArrayArg) a, b);
} else if (b.isArray()) { } else if (b.isArray()) {
return mergeArrays((ArrayArg) b, a); return mergeArrays(dex, (ArrayArg) b, a);
} }
if (!a.isTypeKnown()) { if (!a.isTypeKnown()) {
if (b.isTypeKnown()) { if (b.isTypeKnown()) {
...@@ -546,7 +535,10 @@ public abstract class ArgType { ...@@ -546,7 +535,10 @@ public abstract class ArgType {
if (bObj.equals(Consts.CLASS_OBJECT)) { if (bObj.equals(Consts.CLASS_OBJECT)) {
return a; return a;
} }
String obj = clsp.getCommonAncestor(aObj, bObj); if (dex == null) {
return null;
}
String obj = dex.root().getClsp().getCommonAncestor(aObj, bObj);
return obj == null ? null : object(obj); return obj == null ? null : object(obj);
} }
if (a.isPrimitive() && b.isPrimitive() && a.getRegCount() == b.getRegCount()) { if (a.isPrimitive() && b.isPrimitive() && a.getRegCount() == b.getRegCount()) {
...@@ -556,14 +548,14 @@ public abstract class ArgType { ...@@ -556,14 +548,14 @@ public abstract class ArgType {
return null; return null;
} }
private static ArgType mergeArrays(ArrayArg array, ArgType b) { private static ArgType mergeArrays(DexNode dex, ArrayArg array, ArgType b) {
if (b.isArray()) { if (b.isArray()) {
ArgType ea = array.getArrayElement(); ArgType ea = array.getArrayElement();
ArgType eb = b.getArrayElement(); ArgType eb = b.getArrayElement();
if (ea.isPrimitive() && eb.isPrimitive()) { if (ea.isPrimitive() && eb.isPrimitive()) {
return OBJECT; return OBJECT;
} }
ArgType res = merge(ea, eb); ArgType res = merge(dex, ea, eb);
return res == null ? null : array(res); return res == null ? null : array(res);
} }
if (b.contains(PrimitiveType.ARRAY)) { if (b.contains(PrimitiveType.ARRAY)) {
...@@ -575,25 +567,25 @@ public abstract class ArgType { ...@@ -575,25 +567,25 @@ public abstract class ArgType {
return null; return null;
} }
public static boolean isCastNeeded(ArgType from, ArgType to) { public static boolean isCastNeeded(DexNode dex, ArgType from, ArgType to) {
if (from.equals(to)) { if (from.equals(to)) {
return false; return false;
} }
if (from.isObject() && to.isObject() if (from.isObject() && to.isObject()
&& clsp.isImplements(from.getObject(), to.getObject())) { && dex.root().getClsp().isImplements(from.getObject(), to.getObject())) {
return false; return false;
} }
return true; return true;
} }
public static boolean isInstanceOf(ArgType type, ArgType of) { public static boolean isInstanceOf(DexNode dex, ArgType type, ArgType of) {
if (type.equals(of)) { if (type.equals(of)) {
return true; return true;
} }
if (!type.isObject() || !of.isObject()) { if (!type.isObject() || !of.isObject()) {
return false; return false;
} }
return clsp.isImplements(type.getObject(), of.getObject()); return dex.root().getClsp().isImplements(type.getObject(), of.getObject());
} }
public static ArgType parse(String type) { public static ArgType parse(String type) {
......
...@@ -17,7 +17,7 @@ public final class LiteralArg extends InsnArg { ...@@ -17,7 +17,7 @@ public final class LiteralArg extends InsnArg {
} else if (!type.isTypeKnown() } else if (!type.isTypeKnown()
&& !type.contains(PrimitiveType.LONG) && !type.contains(PrimitiveType.LONG)
&& !type.contains(PrimitiveType.DOUBLE)) { && !type.contains(PrimitiveType.DOUBLE)) {
ArgType m = ArgType.merge(type, ArgType.NARROW_NUMBERS); ArgType m = ArgType.merge(null, type, ArgType.NARROW_NUMBERS);
if (m != null) { if (m != null) {
type = m; type = m;
} }
......
package jadx.core.dex.instructions.args; package jadx.core.dex.instructions.args;
import jadx.core.dex.nodes.DexNode;
public abstract class Typed { public abstract class Typed {
protected ArgType type; protected ArgType type;
...@@ -16,8 +18,8 @@ public abstract class Typed { ...@@ -16,8 +18,8 @@ public abstract class Typed {
return false; return false;
} }
public boolean merge(ArgType newType) { public boolean merge(DexNode dex, ArgType newType) {
ArgType m = ArgType.merge(type, newType); ArgType m = ArgType.merge(dex, type, newType);
if (m != null && !m.equals(type)) { if (m != null && !m.equals(type)) {
setType(m); setType(m);
return true; return true;
...@@ -25,7 +27,7 @@ public abstract class Typed { ...@@ -25,7 +27,7 @@ public abstract class Typed {
return false; return false;
} }
public boolean merge(InsnArg arg) { public boolean merge(DexNode dex, InsnArg arg) {
return merge(arg.getType()); return merge(dex, arg.getType());
} }
} }
...@@ -6,7 +6,6 @@ import jadx.api.ResourceType; ...@@ -6,7 +6,6 @@ import jadx.api.ResourceType;
import jadx.api.ResourcesLoader; import jadx.api.ResourcesLoader;
import jadx.core.clsp.ClspGraph; import jadx.core.clsp.ClspGraph;
import jadx.core.dex.info.ClassInfo; import jadx.core.dex.info.ClassInfo;
import jadx.core.dex.instructions.args.ArgType;
import jadx.core.utils.ErrorsCounter; import jadx.core.utils.ErrorsCounter;
import jadx.core.utils.exceptions.DecodeException; import jadx.core.utils.exceptions.DecodeException;
import jadx.core.utils.exceptions.JadxException; import jadx.core.utils.exceptions.JadxException;
...@@ -36,6 +35,7 @@ public class RootNode { ...@@ -36,6 +35,7 @@ public class RootNode {
@Nullable @Nullable
private String appPackage; private String appPackage;
private ClassNode appResClass; private ClassNode appResClass;
private ClspGraph clsp;
public RootNode(IJadxArgs args) { public RootNode(IJadxArgs args) {
this.args = args; this.args = args;
...@@ -112,7 +112,7 @@ public class RootNode { ...@@ -112,7 +112,7 @@ public class RootNode {
public void initClassPath() throws DecodeException { public void initClassPath() throws DecodeException {
try { try {
if (!ArgType.isClspSet()) { if (this.clsp == null) {
ClspGraph clsp = new ClspGraph(); ClspGraph clsp = new ClspGraph();
clsp.load(); clsp.load();
...@@ -122,7 +122,7 @@ public class RootNode { ...@@ -122,7 +122,7 @@ public class RootNode {
} }
clsp.addApp(classes); clsp.addApp(classes);
ArgType.setClsp(clsp); this.clsp = clsp;
} }
} catch (IOException e) { } catch (IOException e) {
throw new DecodeException("Error loading classpath", e); throw new DecodeException("Error loading classpath", e);
...@@ -166,6 +166,10 @@ public class RootNode { ...@@ -166,6 +166,10 @@ public class RootNode {
return dexNodes; return dexNodes;
} }
public ClspGraph getClsp() {
return clsp;
}
public ErrorsCounter getErrorsCounter() { public ErrorsCounter getErrorsCounter() {
return errorsCounter; return errorsCounter;
} }
......
...@@ -13,6 +13,7 @@ import jadx.core.dex.instructions.args.PrimitiveType; ...@@ -13,6 +13,7 @@ import jadx.core.dex.instructions.args.PrimitiveType;
import jadx.core.dex.instructions.args.RegisterArg; import jadx.core.dex.instructions.args.RegisterArg;
import jadx.core.dex.instructions.args.SSAVar; import jadx.core.dex.instructions.args.SSAVar;
import jadx.core.dex.nodes.BlockNode; import jadx.core.dex.nodes.BlockNode;
import jadx.core.dex.nodes.DexNode;
import jadx.core.dex.nodes.FieldNode; import jadx.core.dex.nodes.FieldNode;
import jadx.core.dex.nodes.InsnNode; import jadx.core.dex.nodes.InsnNode;
import jadx.core.dex.nodes.MethodNode; import jadx.core.dex.nodes.MethodNode;
...@@ -65,7 +66,7 @@ public class ConstInlineVisitor extends AbstractVisitor { ...@@ -65,7 +66,7 @@ public class ConstInlineVisitor extends AbstractVisitor {
ArgType resType = insn.getResult().getType(); ArgType resType = insn.getResult().getType();
// make sure arg has correct type // make sure arg has correct type
if (!arg.getType().isTypeKnown()) { if (!arg.getType().isTypeKnown()) {
arg.merge(resType); arg.merge(mth.dex(), resType);
} }
return replaceConst(mth, insn, lit); return replaceConst(mth, insn, lit);
} }
...@@ -149,30 +150,31 @@ public class ConstInlineVisitor extends AbstractVisitor { ...@@ -149,30 +150,31 @@ public class ConstInlineVisitor extends AbstractVisitor {
* but contains some expensive operations needed only after constant inline * but contains some expensive operations needed only after constant inline
*/ */
private static void fixTypes(MethodNode mth, InsnNode insn, LiteralArg litArg) { private static void fixTypes(MethodNode mth, InsnNode insn, LiteralArg litArg) {
DexNode dex = mth.dex();
PostTypeInference.process(mth, insn); PostTypeInference.process(mth, insn);
switch (insn.getType()) { switch (insn.getType()) {
case CONST: case CONST:
insn.getArg(0).merge(insn.getResult()); insn.getArg(0).merge(dex, insn.getResult());
break; break;
case MOVE: case MOVE:
insn.getResult().merge(insn.getArg(0)); insn.getResult().merge(dex, insn.getArg(0));
insn.getArg(0).merge(insn.getResult()); insn.getArg(0).merge(dex, insn.getResult());
break; break;
case IPUT: case IPUT:
case SPUT: case SPUT:
IndexInsnNode node = (IndexInsnNode) insn; IndexInsnNode node = (IndexInsnNode) insn;
insn.getArg(0).merge(((FieldInfo) node.getIndex()).getType()); insn.getArg(0).merge(dex, ((FieldInfo) node.getIndex()).getType());
break; break;
case IF: { case IF: {
InsnArg arg0 = insn.getArg(0); InsnArg arg0 = insn.getArg(0);
InsnArg arg1 = insn.getArg(1); InsnArg arg1 = insn.getArg(1);
if (arg0 == litArg) { if (arg0 == litArg) {
arg0.merge(arg1); arg0.merge(dex, arg1);
} else { } else {
arg1.merge(arg0); arg1.merge(dex, arg0);
} }
break; break;
} }
...@@ -181,15 +183,15 @@ public class ConstInlineVisitor extends AbstractVisitor { ...@@ -181,15 +183,15 @@ public class ConstInlineVisitor extends AbstractVisitor {
InsnArg arg0 = insn.getArg(0); InsnArg arg0 = insn.getArg(0);
InsnArg arg1 = insn.getArg(1); InsnArg arg1 = insn.getArg(1);
if (arg0 == litArg) { if (arg0 == litArg) {
arg0.merge(arg1); arg0.merge(dex, arg1);
} else { } else {
arg1.merge(arg0); arg1.merge(dex, arg0);
} }
break; break;
case RETURN: case RETURN:
if (insn.getArgsCount() != 0) { if (insn.getArgsCount() != 0) {
insn.getArg(0).merge(mth.getReturnType()); insn.getArg(0).merge(dex, mth.getReturnType());
} }
break; break;
...@@ -207,26 +209,26 @@ public class ConstInlineVisitor extends AbstractVisitor { ...@@ -207,26 +209,26 @@ public class ConstInlineVisitor extends AbstractVisitor {
} else { } else {
type = mth.getParentClass().getClassInfo().getType(); type = mth.getParentClass().getClassInfo().getType();
} }
arg.merge(type); arg.merge(dex, type);
} }
k++; k++;
} }
break; break;
case ARITH: case ARITH:
litArg.merge(insn.getResult()); litArg.merge(dex, insn.getResult());
break; break;
case APUT: case APUT:
case AGET: case AGET:
if (litArg == insn.getArg(1)) { if (litArg == insn.getArg(1)) {
litArg.merge(ArgType.INT); litArg.merge(dex, ArgType.INT);
} }
break; break;
case NEW_ARRAY: case NEW_ARRAY:
if (litArg == insn.getArg(0)) { if (litArg == insn.getArg(0)) {
litArg.merge(ArgType.INT); litArg.merge(dex, ArgType.INT);
} }
break; break;
......
...@@ -108,9 +108,9 @@ public class ModVisitor extends AbstractVisitor { ...@@ -108,9 +108,9 @@ public class ModVisitor extends AbstractVisitor {
if (next < size) { if (next < size) {
InsnNode ni = block.getInstructions().get(next); InsnNode ni = block.getInstructions().get(next);
if (ni.getType() == InsnType.FILL_ARRAY) { if (ni.getType() == InsnType.FILL_ARRAY) {
ni.getResult().merge(insn.getResult()); ni.getResult().merge(mth.dex(), insn.getResult());
ArgType arrType = ((NewArrayNode) insn).getArrayType(); ArgType arrType = ((NewArrayNode) insn).getArrayType();
((FillArrayNode) ni).mergeElementType(arrType.getArrayElement()); ((FillArrayNode) ni).mergeElementType(mth.dex(), arrType.getArrayElement());
remover.add(insn); remover.add(insn);
} }
} }
...@@ -263,7 +263,7 @@ public class ModVisitor extends AbstractVisitor { ...@@ -263,7 +263,7 @@ public class ModVisitor extends AbstractVisitor {
throw new JadxRuntimeException("Null array element type"); throw new JadxRuntimeException("Null array element type");
} }
} }
insn.mergeElementType(elType); insn.mergeElementType(mth.dex(), elType);
elType = insn.getElementType(); elType = insn.getElementType();
List<LiteralArg> list = insn.getLiteralArgs(); List<LiteralArg> list = insn.getLiteralArgs();
......
...@@ -88,7 +88,7 @@ public class SimplifyVisitor extends AbstractVisitor { ...@@ -88,7 +88,7 @@ public class SimplifyVisitor extends AbstractVisitor {
} }
} }
ArgType castToType = (ArgType) ((IndexInsnNode) insn).getIndex(); ArgType castToType = (ArgType) ((IndexInsnNode) insn).getIndex();
if (!ArgType.isCastNeeded(argType, castToType)) { if (!ArgType.isCastNeeded(mth.dex(), argType, castToType)) {
InsnNode insnNode = new InsnNode(InsnType.MOVE, 1); InsnNode insnNode = new InsnNode(InsnType.MOVE, 1);
insnNode.setOffset(insn.getOffset()); insnNode.setOffset(insn.getOffset());
insnNode.setResult(insn.getResult()); insnNode.setResult(insn.getResult());
......
...@@ -16,6 +16,7 @@ import jadx.core.dex.instructions.args.LiteralArg; ...@@ -16,6 +16,7 @@ import jadx.core.dex.instructions.args.LiteralArg;
import jadx.core.dex.instructions.args.RegisterArg; import jadx.core.dex.instructions.args.RegisterArg;
import jadx.core.dex.instructions.args.SSAVar; import jadx.core.dex.instructions.args.SSAVar;
import jadx.core.dex.nodes.BlockNode; import jadx.core.dex.nodes.BlockNode;
import jadx.core.dex.nodes.DexNode;
import jadx.core.dex.nodes.IBlock; import jadx.core.dex.nodes.IBlock;
import jadx.core.dex.nodes.IRegion; import jadx.core.dex.nodes.IRegion;
import jadx.core.dex.nodes.InsnNode; import jadx.core.dex.nodes.InsnNode;
...@@ -254,7 +255,7 @@ public class LoopRegionVisitor extends AbstractVisitor implements IRegionVisitor ...@@ -254,7 +255,7 @@ public class LoopRegionVisitor extends AbstractVisitor implements IRegionVisitor
} else { } else {
toSkip.add(nextCall); toSkip.add(nextCall);
} }
if (iterVar == null || !fixIterableType(iterableArg, iterVar)) { if (iterVar == null || !fixIterableType(mth.dex(), iterableArg, iterVar)) {
return false; return false;
} }
...@@ -266,7 +267,7 @@ public class LoopRegionVisitor extends AbstractVisitor implements IRegionVisitor ...@@ -266,7 +267,7 @@ public class LoopRegionVisitor extends AbstractVisitor implements IRegionVisitor
return true; return true;
} }
private static boolean fixIterableType(InsnArg iterableArg, RegisterArg iterVar) { private static boolean fixIterableType(DexNode dex, InsnArg iterableArg, RegisterArg iterVar) {
ArgType iterableType = iterableArg.getType(); ArgType iterableType = iterableArg.getType();
ArgType varType = iterVar.getType(); ArgType varType = iterVar.getType();
if (iterableType.isGeneric()) { if (iterableType.isGeneric()) {
...@@ -282,7 +283,7 @@ public class LoopRegionVisitor extends AbstractVisitor implements IRegionVisitor ...@@ -282,7 +283,7 @@ public class LoopRegionVisitor extends AbstractVisitor implements IRegionVisitor
iterVar.setType(gType); iterVar.setType(gType);
return true; return true;
} }
if (ArgType.isInstanceOf(gType, varType)) { if (ArgType.isInstanceOf(dex, gType, varType)) {
return true; return true;
} }
LOG.warn("Generic type differs: {} and {}", gType, varType); LOG.warn("Generic type differs: {} and {}", gType, varType);
......
package jadx.core.dex.visitors.typeinference; package jadx.core.dex.visitors.typeinference;
import jadx.core.dex.nodes.BlockNode; import jadx.core.dex.nodes.BlockNode;
import jadx.core.dex.nodes.DexNode;
import jadx.core.dex.nodes.InsnNode; import jadx.core.dex.nodes.InsnNode;
import jadx.core.dex.nodes.MethodNode; import jadx.core.dex.nodes.MethodNode;
import jadx.core.dex.visitors.AbstractVisitor; import jadx.core.dex.visitors.AbstractVisitor;
...@@ -31,9 +32,10 @@ public class FinishTypeInference extends AbstractVisitor { ...@@ -31,9 +32,10 @@ public class FinishTypeInference extends AbstractVisitor {
} while (change); } while (change);
// last chance to set correct value (just use first type from 'possible' list) // last chance to set correct value (just use first type from 'possible' list)
DexNode dex = mth.dex();
for (BlockNode block : mth.getBasicBlocks()) { for (BlockNode block : mth.getBasicBlocks()) {
for (InsnNode insn : block.getInstructions()) { for (InsnNode insn : block.getInstructions()) {
SelectTypeVisitor.visit(insn); SelectTypeVisitor.visit(dex, insn);
} }
} }
......
...@@ -8,6 +8,7 @@ import jadx.core.dex.instructions.args.ArgType; ...@@ -8,6 +8,7 @@ import jadx.core.dex.instructions.args.ArgType;
import jadx.core.dex.instructions.args.InsnArg; import jadx.core.dex.instructions.args.InsnArg;
import jadx.core.dex.instructions.args.LiteralArg; import jadx.core.dex.instructions.args.LiteralArg;
import jadx.core.dex.instructions.args.RegisterArg; import jadx.core.dex.instructions.args.RegisterArg;
import jadx.core.dex.nodes.DexNode;
import jadx.core.dex.nodes.InsnNode; import jadx.core.dex.nodes.InsnNode;
import jadx.core.dex.nodes.MethodNode; import jadx.core.dex.nodes.MethodNode;
...@@ -19,6 +20,7 @@ public class PostTypeInference { ...@@ -19,6 +20,7 @@ public class PostTypeInference {
} }
public static boolean process(MethodNode mth, InsnNode insn) { public static boolean process(MethodNode mth, InsnNode insn) {
DexNode dex = mth.dex();
switch (insn.getType()) { switch (insn.getType()) {
case CONST: case CONST:
RegisterArg res = insn.getResult(); RegisterArg res = insn.getResult();
...@@ -34,31 +36,31 @@ public class PostTypeInference { ...@@ -34,31 +36,31 @@ public class PostTypeInference {
return true; return true;
} }
} }
return litArg.merge(res); return litArg.merge(dex, res);
case MOVE: { case MOVE: {
boolean change = false; boolean change = false;
if (insn.getResult().merge(insn.getArg(0))) { if (insn.getResult().merge(dex, insn.getArg(0))) {
change = true; change = true;
} }
if (insn.getArg(0).merge(insn.getResult())) { if (insn.getArg(0).merge(dex, insn.getResult())) {
change = true; change = true;
} }
return change; return change;
} }
case AGET: case AGET:
return fixArrayTypes(insn.getArg(0), insn.getResult()); return fixArrayTypes(dex, insn.getArg(0), insn.getResult());
case APUT: case APUT:
return fixArrayTypes(insn.getArg(0), insn.getArg(2)); return fixArrayTypes(dex, insn.getArg(0), insn.getArg(2));
case IF: { case IF: {
boolean change = false; boolean change = false;
if (insn.getArg(1).merge(insn.getArg(0))) { if (insn.getArg(1).merge(dex, insn.getArg(0))) {
change = true; change = true;
} }
if (insn.getArg(0).merge(insn.getArg(1))) { if (insn.getArg(0).merge(dex, insn.getArg(1))) {
change = true; change = true;
} }
return change; return change;
...@@ -138,12 +140,12 @@ public class PostTypeInference { ...@@ -138,12 +140,12 @@ public class PostTypeInference {
return false; return false;
} }
private static boolean fixArrayTypes(InsnArg array, InsnArg elem) { private static boolean fixArrayTypes(DexNode dex, InsnArg array, InsnArg elem) {
boolean change = false; boolean change = false;
if (!elem.getType().isTypeKnown() && elem.merge(array.getType().getArrayElement())) { if (!elem.getType().isTypeKnown() && elem.merge(dex, array.getType().getArrayElement())) {
change = true; change = true;
} }
if (!array.getType().isTypeKnown() && array.merge(ArgType.array(elem.getType()))) { if (!array.getType().isTypeKnown() && array.merge(dex, ArgType.array(elem.getType()))) {
change = true; change = true;
} }
return change; return change;
......
...@@ -2,6 +2,7 @@ package jadx.core.dex.visitors.typeinference; ...@@ -2,6 +2,7 @@ package jadx.core.dex.visitors.typeinference;
import jadx.core.dex.instructions.args.ArgType; import jadx.core.dex.instructions.args.ArgType;
import jadx.core.dex.instructions.args.InsnArg; import jadx.core.dex.instructions.args.InsnArg;
import jadx.core.dex.nodes.DexNode;
import jadx.core.dex.nodes.InsnNode; import jadx.core.dex.nodes.InsnNode;
public class SelectTypeVisitor { public class SelectTypeVisitor {
...@@ -9,21 +10,21 @@ public class SelectTypeVisitor { ...@@ -9,21 +10,21 @@ public class SelectTypeVisitor {
private SelectTypeVisitor() { private SelectTypeVisitor() {
} }
public static void visit(InsnNode insn) { public static void visit(DexNode dex, InsnNode insn) {
InsnArg res = insn.getResult(); InsnArg res = insn.getResult();
if (res != null && !res.getType().isTypeKnown()) { if (res != null && !res.getType().isTypeKnown()) {
selectType(res); selectType(dex, res);
} }
for (InsnArg arg : insn.getArguments()) { for (InsnArg arg : insn.getArguments()) {
if (!arg.getType().isTypeKnown()) { if (!arg.getType().isTypeKnown()) {
selectType(arg); selectType(dex, arg);
} }
} }
} }
private static void selectType(InsnArg arg) { private static void selectType(DexNode dex, InsnArg arg) {
ArgType t = arg.getType(); ArgType t = arg.getType();
ArgType newType = ArgType.merge(t, t.selectFirst()); ArgType newType = ArgType.merge(dex, t, t.selectFirst());
arg.setType(newType); arg.setType(newType);
} }
......
...@@ -5,6 +5,7 @@ import jadx.core.dex.instructions.args.ArgType; ...@@ -5,6 +5,7 @@ import jadx.core.dex.instructions.args.ArgType;
import jadx.core.dex.instructions.args.InsnArg; import jadx.core.dex.instructions.args.InsnArg;
import jadx.core.dex.instructions.args.RegisterArg; import jadx.core.dex.instructions.args.RegisterArg;
import jadx.core.dex.instructions.args.SSAVar; import jadx.core.dex.instructions.args.SSAVar;
import jadx.core.dex.nodes.DexNode;
import jadx.core.dex.nodes.MethodNode; import jadx.core.dex.nodes.MethodNode;
import jadx.core.dex.visitors.AbstractVisitor; import jadx.core.dex.visitors.AbstractVisitor;
import jadx.core.utils.exceptions.JadxException; import jadx.core.utils.exceptions.JadxException;
...@@ -18,9 +19,10 @@ public class TypeInference extends AbstractVisitor { ...@@ -18,9 +19,10 @@ public class TypeInference extends AbstractVisitor {
if (mth.isNoCode()) { if (mth.isNoCode()) {
return; return;
} }
DexNode dex = mth.dex();
for (SSAVar var : mth.getSVars()) { for (SSAVar var : mth.getSVars()) {
// inference variable type // inference variable type
ArgType type = processType(var); ArgType type = processType(dex, var);
if (type == null) { if (type == null) {
type = ArgType.UNKNOWN; type = ArgType.UNKNOWN;
} }
...@@ -40,7 +42,7 @@ public class TypeInference extends AbstractVisitor { ...@@ -40,7 +42,7 @@ public class TypeInference extends AbstractVisitor {
} }
} }
private static ArgType processType(SSAVar var) { private static ArgType processType(DexNode dex, SSAVar var) {
RegisterArg assign = var.getAssign(); RegisterArg assign = var.getAssign();
List<RegisterArg> useList = var.getUseList(); List<RegisterArg> useList = var.getUseList();
if (useList.isEmpty() || var.isTypeImmutable()) { if (useList.isEmpty() || var.isTypeImmutable()) {
...@@ -49,7 +51,7 @@ public class TypeInference extends AbstractVisitor { ...@@ -49,7 +51,7 @@ public class TypeInference extends AbstractVisitor {
ArgType type = assign.getType(); ArgType type = assign.getType();
for (RegisterArg arg : useList) { for (RegisterArg arg : useList) {
ArgType useType = arg.getType(); ArgType useType = arg.getType();
ArgType newType = ArgType.merge(type, useType); ArgType newType = ArgType.merge(dex, type, useType);
if (newType != null) { if (newType != null) {
type = newType; type = newType;
} }
......
...@@ -2,6 +2,8 @@ package jadx.tests.functional; ...@@ -2,6 +2,8 @@ package jadx.tests.functional;
import jadx.core.clsp.ClspGraph; import jadx.core.clsp.ClspGraph;
import jadx.core.dex.instructions.args.ArgType; import jadx.core.dex.instructions.args.ArgType;
import jadx.core.dex.nodes.DexNode;
import jadx.core.dex.nodes.RootNode;
import jadx.core.utils.exceptions.DecodeException; import jadx.core.utils.exceptions.DecodeException;
import java.io.IOException; import java.io.IOException;
...@@ -13,19 +15,25 @@ import static jadx.core.dex.instructions.args.ArgType.STRING; ...@@ -13,19 +15,25 @@ import static jadx.core.dex.instructions.args.ArgType.STRING;
import static jadx.core.dex.instructions.args.ArgType.object; import static jadx.core.dex.instructions.args.ArgType.object;
import static org.junit.Assert.assertFalse; import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertTrue; import static org.junit.Assert.assertTrue;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;
public class JadxClasspathTest { public class JadxClasspathTest {
private static final String JAVA_LANG_EXCEPTION = "java.lang.Exception"; private static final String JAVA_LANG_EXCEPTION = "java.lang.Exception";
private static final String JAVA_LANG_THROWABLE = "java.lang.Throwable"; private static final String JAVA_LANG_THROWABLE = "java.lang.Throwable";
ClspGraph clsp; private DexNode dex;
private ClspGraph clsp;
@Before @Before
public void initClsp() throws IOException, DecodeException { public void initClsp() throws IOException, DecodeException {
clsp = new ClspGraph(); clsp = new ClspGraph();
clsp.load(); clsp.load();
ArgType.setClsp(clsp); dex = mock(DexNode.class);
RootNode rootNode = mock(RootNode.class);
when(rootNode.getClsp()).thenReturn(clsp);
when(dex.root()).thenReturn(rootNode);
} }
@Test @Test
...@@ -36,9 +44,9 @@ public class JadxClasspathTest { ...@@ -36,9 +44,9 @@ public class JadxClasspathTest {
assertTrue(clsp.isImplements(JAVA_LANG_EXCEPTION, JAVA_LANG_THROWABLE)); assertTrue(clsp.isImplements(JAVA_LANG_EXCEPTION, JAVA_LANG_THROWABLE));
assertFalse(clsp.isImplements(JAVA_LANG_THROWABLE, JAVA_LANG_EXCEPTION)); assertFalse(clsp.isImplements(JAVA_LANG_THROWABLE, JAVA_LANG_EXCEPTION));
assertFalse(ArgType.isCastNeeded(objExc, objThr)); assertFalse(ArgType.isCastNeeded(dex, objExc, objThr));
assertTrue(ArgType.isCastNeeded(objThr, objExc)); assertTrue(ArgType.isCastNeeded(dex, objThr, objExc));
assertTrue(ArgType.isCastNeeded(ArgType.OBJECT, STRING)); assertTrue(ArgType.isCastNeeded(dex, ArgType.OBJECT, STRING));
} }
} }
...@@ -3,6 +3,8 @@ package jadx.tests.functional; ...@@ -3,6 +3,8 @@ package jadx.tests.functional;
import jadx.core.clsp.ClspGraph; import jadx.core.clsp.ClspGraph;
import jadx.core.dex.instructions.args.ArgType; import jadx.core.dex.instructions.args.ArgType;
import jadx.core.dex.instructions.args.PrimitiveType; import jadx.core.dex.instructions.args.PrimitiveType;
import jadx.core.dex.nodes.DexNode;
import jadx.core.dex.nodes.RootNode;
import jadx.core.utils.exceptions.DecodeException; import jadx.core.utils.exceptions.DecodeException;
import java.io.IOException; import java.io.IOException;
...@@ -27,14 +29,21 @@ import static jadx.core.dex.instructions.args.ArgType.unknown; ...@@ -27,14 +29,21 @@ import static jadx.core.dex.instructions.args.ArgType.unknown;
import static org.junit.Assert.assertNotNull; import static org.junit.Assert.assertNotNull;
import static org.junit.Assert.assertNull; import static org.junit.Assert.assertNull;
import static org.junit.Assert.assertTrue; import static org.junit.Assert.assertTrue;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;
public class TypeMergeTest { public class TypeMergeTest {
private DexNode dex;
@Before @Before
public void initClsp() throws IOException, DecodeException { public void initClsp() throws IOException, DecodeException {
ClspGraph clsp = new ClspGraph(); ClspGraph clsp = new ClspGraph();
clsp.load(); clsp.load();
ArgType.setClsp(clsp); dex = mock(DexNode.class);
RootNode rootNode = mock(RootNode.class);
when(rootNode.getClsp()).thenReturn(clsp);
when(dex.root()).thenReturn(rootNode);
} }
@Test @Test
...@@ -103,7 +112,7 @@ public class TypeMergeTest { ...@@ -103,7 +112,7 @@ public class TypeMergeTest {
} }
private void merge(ArgType t1, ArgType t2, ArgType exp) { private void merge(ArgType t1, ArgType t2, ArgType exp) {
ArgType res = ArgType.merge(t1, t2); ArgType res = ArgType.merge(dex, t1, t2);
String msg = format(t1, t2, exp, res); String msg = format(t1, t2, exp, res);
if (exp == null) { if (exp == null) {
assertNull("Incorrect accept: " + msg, res); assertNull("Incorrect accept: " + msg, res);
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment