Commit 4ce5cc84 authored by Skylot's avatar Skylot

fix: use multi-variable type search algorithm if type propagation is failed

parent 9b091b7c
...@@ -69,11 +69,24 @@ public class ClspGraph { ...@@ -69,11 +69,24 @@ public class ClspGraph {
return nClass; return nClass;
} }
/**
* @return {@code clsName} instanceof {@code implClsName}
*/
public boolean isImplements(String clsName, String implClsName) { public boolean isImplements(String clsName, String implClsName) {
Set<String> anc = getAncestors(clsName); Set<String> anc = getAncestors(clsName);
return anc.contains(implClsName); return anc.contains(implClsName);
} }
public List<String> getImplementations(String clsName) {
List<String> list = new ArrayList<>();
for (String cls : nameMap.keySet()) {
if (isImplements(cls, clsName)) {
list.add(cls);
}
}
return list;
}
public String getCommonAncestor(String clsName, String implClsName) { public String getCommonAncestor(String clsName, String implClsName) {
if (clsName.equals(implClsName)) { if (clsName.equals(implClsName)) {
return clsName; return clsName;
...@@ -104,7 +117,7 @@ public class ClspGraph { ...@@ -104,7 +117,7 @@ public class ClspGraph {
return null; return null;
} }
private Set<String> getAncestors(String clsName) { public Set<String> getAncestors(String clsName) {
Set<String> result = ancestorCache.get(clsName); Set<String> result = ancestorCache.get(clsName);
if (result != null) { if (result != null) {
return result; return result;
......
...@@ -63,12 +63,7 @@ public class AnnotationGen { ...@@ -63,12 +63,7 @@ public class AnnotationGen {
} }
for (Annotation a : aList.getAll()) { for (Annotation a : aList.getAll()) {
String aCls = a.getAnnotationClass(); String aCls = a.getAnnotationClass();
if (aCls.startsWith(Consts.DALVIK_ANNOTATION_PKG)) { if (!aCls.startsWith(Consts.DALVIK_ANNOTATION_PKG)) {
// skip
if (Consts.DEBUG) {
code.startLine("// " + a);
}
} else {
code.startLine(); code.startLine();
formatAnnotation(code, a); formatAnnotation(code, a);
} }
......
...@@ -701,7 +701,7 @@ public class InsnGen { ...@@ -701,7 +701,7 @@ public class InsnGen {
ArgType origType; ArgType origType;
List<RegisterArg> arguments = callMth.getArguments(false); List<RegisterArg> arguments = callMth.getArguments(false);
if (arguments == null || arguments.isEmpty()) { if (arguments == null || arguments.isEmpty()) {
mth.addComment("JADX WARN: used method not loaded: " + callMth + ", types can be incorrect"); mth.addComment("JADX INFO: used method not loaded: " + callMth + ", types can be incorrect");
origType = callMth.getMethodInfo().getArgumentsTypes().get(origPos); origType = callMth.getMethodInfo().getArgumentsTypes().get(origPos);
} else { } else {
origType = arguments.get(origPos).getInitType(); origType = arguments.get(origPos).getInitType();
......
package jadx.core.dex.instructions; package jadx.core.dex.instructions;
import java.util.Objects;
import jadx.core.dex.nodes.InsnNode; import jadx.core.dex.nodes.InsnNode;
import jadx.core.utils.InsnUtils; import jadx.core.utils.InsnUtils;
import jadx.core.utils.Utils;
public class IndexInsnNode extends InsnNode { public class IndexInsnNode extends InsnNode {
...@@ -30,11 +33,21 @@ public class IndexInsnNode extends InsnNode { ...@@ -30,11 +33,21 @@ public class IndexInsnNode extends InsnNode {
return false; return false;
} }
IndexInsnNode other = (IndexInsnNode) obj; IndexInsnNode other = (IndexInsnNode) obj;
return index == null ? other.index == null : index.equals(other.index); return Objects.equals(index, other.index);
} }
@Override @Override
public String toString() { public String toString() {
return super.toString() + " " + InsnUtils.indexToString(index); switch (insnType) {
case CAST:
case CHECK_CAST:
return InsnUtils.formatOffset(offset) + ": "
+ InsnUtils.insnTypeToString(insnType)
+ getResult() + " = (" + InsnUtils.indexToString(index) + ") "
+ Utils.listToString(getArguments());
default:
return super.toString() + " " + InsnUtils.indexToString(index);
}
} }
} }
...@@ -693,11 +693,19 @@ public class InsnDecoder { ...@@ -693,11 +693,19 @@ public class InsnDecoder {
} }
private InsnNode arith(DecodedInstruction insn, ArithOp op, ArgType type) { private InsnNode arith(DecodedInstruction insn, ArithOp op, ArgType type) {
return new ArithNode(insn, op, type, false); return new ArithNode(insn, op, fixTypeForBitOps(op, type), false);
} }
private InsnNode arithLit(DecodedInstruction insn, ArithOp op, ArgType type) { private InsnNode arithLit(DecodedInstruction insn, ArithOp op, ArgType type) {
return new ArithNode(insn, op, type, true); return new ArithNode(insn, op, fixTypeForBitOps(op, type), true);
}
private ArgType fixTypeForBitOps(ArithOp op, ArgType type) {
if (type == ArgType.INT
&& (op == ArithOp.AND || op == ArithOp.OR || op == ArithOp.XOR)) {
return ArgType.NARROW_NUMBERS_NO_FLOAT;
}
return type;
} }
private InsnNode neg(DecodedInstruction insn, ArgType type) { private InsnNode neg(DecodedInstruction insn, ArgType type) {
......
...@@ -62,6 +62,7 @@ public final class PhiInsn extends InsnNode { ...@@ -62,6 +62,7 @@ public final class PhiInsn extends InsnNode {
RegisterArg reg = (RegisterArg) arg; RegisterArg reg = (RegisterArg) arg;
if (super.removeArg(reg)) { if (super.removeArg(reg)) {
blockBinds.remove(reg); blockBinds.remove(reg);
reg.getSVar().removeUse(reg);
InstructionRemover.fixUsedInPhiFlag(reg); InstructionRemover.fixUsedInPhiFlag(reg);
return true; return true;
} }
...@@ -78,7 +79,9 @@ public final class PhiInsn extends InsnNode { ...@@ -78,7 +79,9 @@ public final class PhiInsn extends InsnNode {
throw new JadxRuntimeException("Unknown predecessor block by arg " + from + " in PHI: " + this); throw new JadxRuntimeException("Unknown predecessor block by arg " + from + " in PHI: " + this);
} }
if (removeArg(from)) { if (removeArg(from)) {
bindArg((RegisterArg) to, pred); RegisterArg reg = (RegisterArg) to;
bindArg(reg, pred);
reg.getSVar().setUsedInPhi(this);
} }
return true; return true;
} }
......
...@@ -47,6 +47,10 @@ public abstract class ArgType { ...@@ -47,6 +47,10 @@ public abstract class ArgType {
PrimitiveType.INT, PrimitiveType.FLOAT, PrimitiveType.INT, PrimitiveType.FLOAT,
PrimitiveType.SHORT, PrimitiveType.BYTE, PrimitiveType.CHAR); PrimitiveType.SHORT, PrimitiveType.BYTE, PrimitiveType.CHAR);
public static final ArgType NARROW_NUMBERS_NO_FLOAT = unknown(
PrimitiveType.INT, PrimitiveType.BOOLEAN,
PrimitiveType.SHORT, PrimitiveType.BYTE, PrimitiveType.CHAR);
public static final ArgType WIDE = unknown(PrimitiveType.LONG, PrimitiveType.DOUBLE); public static final ArgType WIDE = unknown(PrimitiveType.LONG, PrimitiveType.DOUBLE);
public static final ArgType INT_FLOAT = unknown(PrimitiveType.INT, PrimitiveType.FLOAT); public static final ArgType INT_FLOAT = unknown(PrimitiveType.INT, PrimitiveType.FLOAT);
......
...@@ -3,6 +3,7 @@ package jadx.core.dex.instructions.args; ...@@ -3,6 +3,7 @@ package jadx.core.dex.instructions.args;
import java.util.Objects; import java.util.Objects;
import org.jetbrains.annotations.NotNull; import org.jetbrains.annotations.NotNull;
import org.jetbrains.annotations.Nullable;
import jadx.core.dex.instructions.InsnType; import jadx.core.dex.instructions.InsnType;
import jadx.core.dex.instructions.PhiInsn; import jadx.core.dex.instructions.PhiInsn;
...@@ -118,7 +119,7 @@ public class RegisterArg extends InsnArg implements Named { ...@@ -118,7 +119,7 @@ public class RegisterArg extends InsnArg implements Named {
return duplicate(getRegNum(), sVar); return duplicate(getRegNum(), sVar);
} }
public RegisterArg duplicate(int regNum, SSAVar sVar) { public RegisterArg duplicate(int regNum, @Nullable SSAVar sVar) {
RegisterArg dup = new RegisterArg(regNum, getInitType()); RegisterArg dup = new RegisterArg(regNum, getInitType());
if (sVar != null) { if (sVar != null) {
dup.setSVar(sVar); dup.setSVar(sVar);
...@@ -140,6 +141,7 @@ public class RegisterArg extends InsnArg implements Named { ...@@ -140,6 +141,7 @@ public class RegisterArg extends InsnArg implements Named {
return InsnUtils.getConstValueByInsn(dex, parInsn); return InsnUtils.getConstValueByInsn(dex, parInsn);
} }
@Nullable
public InsnNode getAssignInsn() { public InsnNode getAssignInsn() {
if (sVar == null) { if (sVar == null) {
return null; return null;
...@@ -196,7 +198,7 @@ public class RegisterArg extends InsnArg implements Named { ...@@ -196,7 +198,7 @@ public class RegisterArg extends InsnArg implements Named {
sb.append("(r"); sb.append("(r");
sb.append(regNum); sb.append(regNum);
if (sVar != null) { if (sVar != null) {
sb.append(':').append(sVar.getVersion()); sb.append('v').append(sVar.getVersion());
} }
if (getName() != null) { if (getName() != null) {
sb.append(" '").append(getName()).append('\''); sb.append(" '").append(getName()).append('\'');
......
...@@ -163,7 +163,7 @@ public class SSAVar extends AttrNode { ...@@ -163,7 +163,7 @@ public class SSAVar extends AttrNode {
} }
public String toShortString() { public String toShortString() {
return "r" + regNum + ":" + version; return "r" + regNum + "v" + version;
} }
@Override @Override
......
...@@ -575,6 +575,10 @@ public class MethodNode extends LineAttrNode implements ILoadable, IDexNode { ...@@ -575,6 +575,10 @@ public class MethodNode extends LineAttrNode implements ILoadable, IDexNode {
return debugInfoOffset; return debugInfoOffset;
} }
public SSAVar makeNewSVar(int regNum, @NotNull RegisterArg assignArg) {
return makeNewSVar(regNum, getNextSVarVersion(regNum), assignArg);
}
public SSAVar makeNewSVar(int regNum, int version, @NotNull RegisterArg assignArg) { public SSAVar makeNewSVar(int regNum, int version, @NotNull RegisterArg assignArg) {
SSAVar var = new SSAVar(regNum, version, assignArg); SSAVar var = new SSAVar(regNum, version, assignArg);
if (sVars.isEmpty()) { if (sVars.isEmpty()) {
......
...@@ -59,6 +59,7 @@ public class BlockExceptionHandler extends AbstractVisitor { ...@@ -59,6 +59,7 @@ public class BlockExceptionHandler extends AbstractVisitor {
me.add(AFlag.DONT_INLINE); me.add(AFlag.DONT_INLINE);
resArg.add(AFlag.CUSTOM_DECLARE); resArg.add(AFlag.CUSTOM_DECLARE);
excHandler.setArg(resArg); excHandler.setArg(resArg);
me.addAttr(handlerAttr);
return; return;
} }
} }
......
...@@ -53,11 +53,24 @@ public class DebugInfoApplyVisitor extends AbstractVisitor { ...@@ -53,11 +53,24 @@ public class DebugInfoApplyVisitor extends AbstractVisitor {
applyDebugInfo(mth); applyDebugInfo(mth);
mth.remove(AType.LOCAL_VARS_DEBUG_INFO); mth.remove(AType.LOCAL_VARS_DEBUG_INFO);
} }
checkTypes(mth);
} catch (Exception e) { } catch (Exception e) {
LOG.error("Error to apply debug info: {}", ErrorsCounter.formatMsg(mth, e.getMessage()), e); LOG.error("Error to apply debug info: {}", ErrorsCounter.formatMsg(mth, e.getMessage()), e);
} }
} }
private static void checkTypes(MethodNode mth) {
if (mth.isNoCode() || mth.getSVars().isEmpty()) {
return;
}
mth.getSVars().forEach(var -> {
ArgType type = var.getTypeInfo().getType();
if (!type.isTypeKnown()) {
mth.addComment("JADX WARNING: type inference failed for: " + var.getDetailedVarInfo(mth));
}
});
}
private static void applyDebugInfo(MethodNode mth) { private static void applyDebugInfo(MethodNode mth) {
mth.getSVars().forEach(ssaVar -> collectVarDebugInfo(mth, ssaVar)); mth.getSVars().forEach(ssaVar -> collectVarDebugInfo(mth, ssaVar));
...@@ -80,6 +93,9 @@ public class DebugInfoApplyVisitor extends AbstractVisitor { ...@@ -80,6 +93,9 @@ public class DebugInfoApplyVisitor extends AbstractVisitor {
applyDebugInfo(mth, ssaVar, debugInfo.getRegType(), debugInfo.getName()); applyDebugInfo(mth, ssaVar, debugInfo.getRegType(), debugInfo.getName());
} else { } else {
LOG.warn("Multiple debug info for {}: {}", ssaVar, debugInfoSet); LOG.warn("Multiple debug info for {}: {}", ssaVar, debugInfoSet);
for (RegDebugInfoAttr debugInfo : debugInfoSet) {
applyDebugInfo(mth, ssaVar, debugInfo.getRegType(), debugInfo.getName());
}
} }
} }
...@@ -102,7 +118,7 @@ public class DebugInfoApplyVisitor extends AbstractVisitor { ...@@ -102,7 +118,7 @@ public class DebugInfoApplyVisitor extends AbstractVisitor {
int startAddr = localVar.getStartAddr(); int startAddr = localVar.getStartAddr();
int endAddr = localVar.getEndAddr(); int endAddr = localVar.getEndAddr();
if (isInside(startOffset, startAddr, endAddr) || isInside(endOffset, startAddr, endAddr)) { if (isInside(startOffset, startAddr, endAddr) || isInside(endOffset, startAddr, endAddr)) {
if (Consts.DEBUG && LOG.isDebugEnabled()) { if (Consts.DEBUG) {
LOG.debug("Apply debug info by offset for: {} to {}", ssaVar, localVar); LOG.debug("Apply debug info by offset for: {} to {}", ssaVar, localVar);
} }
applyDebugInfo(mth, ssaVar, localVar.getType(), localVar.getName()); applyDebugInfo(mth, ssaVar, localVar.getType(), localVar.getName());
...@@ -127,15 +143,15 @@ public class DebugInfoApplyVisitor extends AbstractVisitor { ...@@ -127,15 +143,15 @@ public class DebugInfoApplyVisitor extends AbstractVisitor {
} }
public static void applyDebugInfo(MethodNode mth, SSAVar ssaVar, ArgType type, String varName) { public static void applyDebugInfo(MethodNode mth, SSAVar ssaVar, ArgType type, String varName) {
if (NameMapper.isValidIdentifier(varName)) { TypeUpdateResult result = mth.root().getTypeUpdate().applyWithWiderAllow(ssaVar, type);
ssaVar.setName(varName);
}
TypeUpdateResult result = mth.root().getTypeUpdate().applyDebug(ssaVar, type);
if (result == TypeUpdateResult.REJECT) { if (result == TypeUpdateResult.REJECT) {
if (LOG.isDebugEnabled()) { if (LOG.isDebugEnabled()) {
LOG.debug("Reject debug info of type: {} and name: '{}' for {}, mth: {}", type, varName, ssaVar, mth); LOG.debug("Reject debug info of type: {} and name: '{}' for {}, mth: {}", type, varName, ssaVar, mth);
} }
} else { } else {
if (NameMapper.isValidIdentifier(varName)) {
ssaVar.setName(varName);
}
detachDebugInfo(ssaVar.getAssign()); detachDebugInfo(ssaVar.getAssign());
ssaVar.getUseList().forEach(DebugInfoApplyVisitor::detachDebugInfo); ssaVar.getUseList().forEach(DebugInfoApplyVisitor::detachDebugInfo);
} }
......
...@@ -56,7 +56,7 @@ public class DebugInfoParseVisitor extends AbstractVisitor { ...@@ -56,7 +56,7 @@ public class DebugInfoParseVisitor extends AbstractVisitor {
if (localVars.isEmpty()) { if (localVars.isEmpty()) {
return; return;
} }
if (Consts.DEBUG && LOG.isDebugEnabled()) { if (Consts.DEBUG) {
LOG.debug("Parsed debug info for {}: ", mth); LOG.debug("Parsed debug info for {}: ", mth);
localVars.forEach(v -> LOG.debug(" {}", v)); localVars.forEach(v -> LOG.debug(" {}", v));
} }
......
...@@ -38,6 +38,7 @@ import jadx.core.dex.visitors.JadxVisitor; ...@@ -38,6 +38,7 @@ import jadx.core.dex.visitors.JadxVisitor;
import jadx.core.dex.visitors.regions.variables.ProcessVariables; import jadx.core.dex.visitors.regions.variables.ProcessVariables;
import jadx.core.utils.BlockUtils; import jadx.core.utils.BlockUtils;
import jadx.core.utils.RegionUtils; import jadx.core.utils.RegionUtils;
import jadx.core.utils.exceptions.JadxOverflowException;
@JadxVisitor( @JadxVisitor(
name = "LoopRegionVisitor", name = "LoopRegionVisitor",
...@@ -112,8 +113,12 @@ public class LoopRegionVisitor extends AbstractVisitor implements IRegionVisitor ...@@ -112,8 +113,12 @@ public class LoopRegionVisitor extends AbstractVisitor implements IRegionVisitor
List<RegisterArg> args = new LinkedList<>(); List<RegisterArg> args = new LinkedList<>();
incrInsn.getRegisterArgs(args); incrInsn.getRegisterArgs(args);
for (RegisterArg iArg : args) { for (RegisterArg iArg : args) {
if (assignOnlyInLoop(mth, loopRegion, iArg)) { try {
return false; if (assignOnlyInLoop(mth, loopRegion, iArg)) {
return false;
}
} catch (StackOverflowError error) {
throw new JadxOverflowException("LoopRegionVisitor.assignOnlyInLoop endless recursion");
} }
} }
......
...@@ -111,7 +111,7 @@ public class EliminatePhiNodes extends AbstractVisitor { ...@@ -111,7 +111,7 @@ public class EliminatePhiNodes extends AbstractVisitor {
// all checks passed // all checks passed
RegisterArg newAssignArg = oldArg.duplicate(newRegNum, null); RegisterArg newAssignArg = oldArg.duplicate(newRegNum, null);
SSAVar newSVar = mth.makeNewSVar(newRegNum, mth.getNextSVarVersion(newRegNum), newAssignArg); SSAVar newSVar = mth.makeNewSVar(newRegNum, newAssignArg);
newSVar.setName(oldSVar.getName()); newSVar.setName(oldSVar.getName());
mth.root().getTypeUpdate().apply(newSVar, assignArg.getType()); mth.root().getTypeUpdate().apply(newSVar, assignArg.getType());
......
package jadx.core.dex.visitors.typeinference;
import java.util.ArrayList;
import java.util.List;
import jadx.core.dex.instructions.args.InsnArg;
import jadx.core.dex.instructions.args.RegisterArg;
import jadx.core.dex.instructions.args.SSAVar;
import jadx.core.dex.nodes.InsnNode;
import jadx.core.utils.Utils;
public abstract class AbstractTypeConstraint implements ITypeConstraint {
protected InsnNode insn;
protected List<SSAVar> relatedVars;
public AbstractTypeConstraint(InsnNode insn, InsnArg arg) {
this.insn = insn;
this.relatedVars = collectRelatedVars(insn, arg);
}
private List<SSAVar> collectRelatedVars(InsnNode insn, InsnArg arg) {
List<SSAVar> list = new ArrayList<>(insn.getArgsCount());
if (insn.getResult() == arg) {
for (InsnArg insnArg : insn.getArguments()) {
if (insnArg.isRegister()) {
list.add(((RegisterArg) insnArg).getSVar());
}
}
} else {
list.add(insn.getResult().getSVar());
for (InsnArg insnArg : insn.getArguments()) {
if (insnArg != arg && insnArg.isRegister()) {
list.add(((RegisterArg) insnArg).getSVar());
}
}
}
return list;
}
@Override
public List<SSAVar> getRelatedVars() {
return relatedVars;
}
@Override
public String toString() {
return "(" + insn.getType() + ":" + Utils.listToString(relatedVars, SSAVar::toShortString) + ")";
}
}
package jadx.core.dex.visitors.typeinference;
import java.util.List;
import jadx.core.dex.instructions.args.SSAVar;
public interface ITypeConstraint {
List<SSAVar> getRelatedVars();
boolean check(TypeSearchState state);
}
...@@ -30,4 +30,12 @@ public enum TypeCompareEnum { ...@@ -30,4 +30,12 @@ public enum TypeCompareEnum {
return this; return this;
} }
} }
public boolean isWider() {
return this == WIDER || this == WIDER_BY_GENERIC;
}
public boolean isNarrow() {
return this == NARROW || this == NARROW_BY_GENERIC;
}
} }
package jadx.core.dex.visitors.typeinference; package jadx.core.dex.visitors.typeinference;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.LinkedHashSet;
import java.util.List; import java.util.List;
import java.util.Map;
import java.util.Objects; import java.util.Objects;
import java.util.Optional; import java.util.Optional;
import java.util.Set; import java.util.Set;
...@@ -11,23 +13,33 @@ import org.slf4j.Logger; ...@@ -11,23 +13,33 @@ import org.slf4j.Logger;
import org.slf4j.LoggerFactory; import org.slf4j.LoggerFactory;
import jadx.core.Consts; import jadx.core.Consts;
import jadx.core.clsp.ClspGraph;
import jadx.core.dex.attributes.AFlag;
import jadx.core.dex.attributes.AType;
import jadx.core.dex.info.ClassInfo;
import jadx.core.dex.instructions.IndexInsnNode; import jadx.core.dex.instructions.IndexInsnNode;
import jadx.core.dex.instructions.InsnType;
import jadx.core.dex.instructions.PhiInsn;
import jadx.core.dex.instructions.args.ArgType; import jadx.core.dex.instructions.args.ArgType;
import jadx.core.dex.instructions.args.InsnArg;
import jadx.core.dex.instructions.args.LiteralArg; import jadx.core.dex.instructions.args.LiteralArg;
import jadx.core.dex.instructions.args.PrimitiveType; import jadx.core.dex.instructions.args.PrimitiveType;
import jadx.core.dex.instructions.args.RegisterArg; import jadx.core.dex.instructions.args.RegisterArg;
import jadx.core.dex.instructions.args.SSAVar; import jadx.core.dex.instructions.args.SSAVar;
import jadx.core.dex.nodes.BlockNode;
import jadx.core.dex.nodes.InsnNode; import jadx.core.dex.nodes.InsnNode;
import jadx.core.dex.nodes.MethodNode; import jadx.core.dex.nodes.MethodNode;
import jadx.core.dex.nodes.RootNode; import jadx.core.dex.nodes.RootNode;
import jadx.core.dex.trycatch.ExcHandlerAttr;
import jadx.core.dex.visitors.AbstractVisitor; import jadx.core.dex.visitors.AbstractVisitor;
import jadx.core.dex.visitors.ConstInlineVisitor; import jadx.core.dex.visitors.ConstInlineVisitor;
import jadx.core.dex.visitors.JadxVisitor; import jadx.core.dex.visitors.JadxVisitor;
import jadx.core.dex.visitors.ssa.SSATransform; import jadx.core.dex.visitors.ssa.SSATransform;
import jadx.core.utils.Utils;
@JadxVisitor( @JadxVisitor(
name = "Type Inference", name = "Type Inference",
desc = "Calculate best types for registers", desc = "Calculate best types for SSA variables",
runAfter = { runAfter = {
SSATransform.class, SSATransform.class,
ConstInlineVisitor.class ConstInlineVisitor.class
...@@ -48,65 +60,92 @@ public final class TypeInferenceVisitor extends AbstractVisitor { ...@@ -48,65 +60,92 @@ public final class TypeInferenceVisitor extends AbstractVisitor {
if (mth.isNoCode()) { if (mth.isNoCode()) {
return; return;
} }
// collect initial types from assign and usages // collect initial type bounds from assign and usages
mth.getSVars().forEach(this::attachBounds); mth.getSVars().forEach(this::attachBounds);
// start initial type changing mth.getSVars().forEach(this::mergePhiBounds);
// start initial type propagation, check types from bounds
mth.getSVars().forEach(this::setBestType); mth.getSVars().forEach(this::setBestType);
// try all possible types if var type is still unknown // try other types if type is still unknown
mth.getSVars().forEach(var -> { boolean resolved = true;
TypeInfo typeInfo = var.getTypeInfo(); for (SSAVar var : mth.getSVars()) {
ArgType type = typeInfo.getType(); ArgType type = var.getTypeInfo().getType();
if (type != null && !type.isTypeKnown()) { if (!type.isTypeKnown()
if (var.getAssign().isTypeImmutable()) { && !var.getAssign().isTypeImmutable()
mth.addComment("JADX WARNING: type rejected for immutable type: " + var.getDetailedVarInfo(mth)); && !tryDeduceType(mth, var, type)) {
} else { resolved = false;
boolean changed = tryAllTypes(var, type);
if (!changed) {
mth.addComment("JADX WARNING: type inference failed for: " + var.getDetailedVarInfo(mth));
}
}
} }
}); }
if (!resolved) {
for (SSAVar var : new ArrayList<>(mth.getSVars())) {
tryInsertAdditionalInsn(mth, var);
}
runMultiVariableSearch(mth);
}
}
private void runMultiVariableSearch(MethodNode mth) {
long startTime = System.currentTimeMillis();
TypeSearch typeSearch = new TypeSearch(mth);
boolean success;
try {
success = typeSearch.run();
} catch (Exception e) {
success = false;
mth.addWarn("Multi-variable type inference failed. Error: " + Utils.getStackTrace(e));
}
long time = System.currentTimeMillis() - startTime;
mth.addComment("JADX DEBUG: Multi-variable type inference result: " + (success ? "success" : "failure")
+ ", time: " + time + " ms");
} }
private void setBestType(SSAVar ssaVar) { private boolean setBestType(SSAVar ssaVar) {
try { try {
RegisterArg assignArg = ssaVar.getAssign(); RegisterArg assignArg = ssaVar.getAssign();
if (assignArg.isTypeImmutable()) { if (!assignArg.isTypeImmutable()) {
ArgType initType = assignArg.getInitType(); return calculateFromBounds(ssaVar);
TypeUpdateResult result = typeUpdate.apply(ssaVar, initType); }
if (Consts.DEBUG && result == TypeUpdateResult.REJECT && LOG.isDebugEnabled()) { ArgType initType = assignArg.getInitType();
LOG.debug("Initial immutable type set rejected: {} -> {}", ssaVar, initType); TypeUpdateResult result = typeUpdate.apply(ssaVar, initType);
if (result == TypeUpdateResult.REJECT) {
if (Consts.DEBUG) {
LOG.info("Initial immutable type set rejected: {} -> {}", ssaVar, initType);
} }
} else { return false;
calculateFromBounds(ssaVar);
} }
return true;
} catch (Exception e) { } catch (Exception e) {
LOG.error("Failed to calculate best type for var: {}", ssaVar); LOG.error("Failed to calculate best type for var: {}", ssaVar);
return false;
} }
} }
private void calculateFromBounds(SSAVar ssaVar) { private boolean calculateFromBounds(SSAVar ssaVar) {
TypeInfo typeInfo = ssaVar.getTypeInfo(); TypeInfo typeInfo = ssaVar.getTypeInfo();
Set<ITypeBound> bounds = typeInfo.getBounds(); Set<ITypeBound> bounds = typeInfo.getBounds();
Optional<ArgType> bestTypeOpt = selectBestTypeFromBounds(bounds); Optional<ArgType> bestTypeOpt = selectBestTypeFromBounds(bounds);
if (bestTypeOpt.isPresent()) { if (bestTypeOpt.isPresent()) {
ArgType candidateType = bestTypeOpt.get(); ArgType candidateType = bestTypeOpt.get();
TypeUpdateResult result = typeUpdate.apply(ssaVar, candidateType); TypeUpdateResult result = typeUpdate.apply(ssaVar, candidateType);
if (Consts.DEBUG && result == TypeUpdateResult.REJECT && LOG.isDebugEnabled()) { if (result == TypeUpdateResult.REJECT) {
if (ssaVar.getTypeInfo().getType().equals(candidateType)) { if (Consts.DEBUG) {
LOG.warn("Same type rejected: {} -> {}, bounds: {}", ssaVar, candidateType, bounds); if (ssaVar.getTypeInfo().getType().equals(candidateType)) {
} else { LOG.info("Same type rejected: {} -> {}, bounds: {}", ssaVar, candidateType, bounds);
LOG.debug("Type set rejected: {} -> {}, bounds: {}", ssaVar, candidateType, bounds); } else {
LOG.debug("Type set rejected: {} -> {}, bounds: {}", ssaVar, candidateType, bounds);
}
} }
return false;
} }
} else if (!bounds.isEmpty()) { return result == TypeUpdateResult.CHANGED;
LOG.warn("Failed to select best type from bounds: "); }
if (Consts.DEBUG) {
LOG.warn("Failed to select best type from bounds, count={} : ", bounds.size());
for (ITypeBound bound : bounds) { for (ITypeBound bound : bounds) {
LOG.warn(" {}", bound); LOG.warn(" {}", bound);
} }
} }
return false;
} }
private Optional<ArgType> selectBestTypeFromBounds(Set<ITypeBound> bounds) { private Optional<ArgType> selectBestTypeFromBounds(Set<ITypeBound> bounds) {
...@@ -118,37 +157,64 @@ public final class TypeInferenceVisitor extends AbstractVisitor { ...@@ -118,37 +157,64 @@ public final class TypeInferenceVisitor extends AbstractVisitor {
private void attachBounds(SSAVar var) { private void attachBounds(SSAVar var) {
TypeInfo typeInfo = var.getTypeInfo(); TypeInfo typeInfo = var.getTypeInfo();
typeInfo.getBounds().clear();
RegisterArg assign = var.getAssign(); RegisterArg assign = var.getAssign();
addBound(typeInfo, makeAssignBound(assign)); addAssignBound(typeInfo, assign);
for (RegisterArg regArg : var.getUseList()) { for (RegisterArg regArg : var.getUseList()) {
addBound(typeInfo, makeUseBound(regArg)); addBound(typeInfo, makeUseBound(regArg));
} }
} }
private void mergePhiBounds(SSAVar ssaVar) {
PhiInsn usedInPhi = ssaVar.getUsedInPhi();
if (usedInPhi != null) {
Set<ITypeBound> bounds = ssaVar.getTypeInfo().getBounds();
bounds.addAll(usedInPhi.getResult().getSVar().getTypeInfo().getBounds());
for (InsnArg arg : usedInPhi.getArguments()) {
bounds.addAll(((RegisterArg) arg).getSVar().getTypeInfo().getBounds());
}
}
}
private void addBound(TypeInfo typeInfo, ITypeBound bound) { private void addBound(TypeInfo typeInfo, ITypeBound bound) {
if (bound != null && bound.getType() != ArgType.UNKNOWN) { if (bound != null && bound.getType() != ArgType.UNKNOWN) {
typeInfo.getBounds().add(bound); typeInfo.getBounds().add(bound);
} }
} }
private ITypeBound makeAssignBound(RegisterArg assign) { private void addAssignBound(TypeInfo typeInfo, RegisterArg assign) {
InsnNode insn = assign.getParentInsn(); InsnNode insn = assign.getParentInsn();
if (insn == null || assign.isTypeImmutable()) { if (insn == null || assign.isTypeImmutable()) {
return new TypeBoundConst(BoundEnum.ASSIGN, assign.getInitType()); addBound(typeInfo, new TypeBoundConst(BoundEnum.ASSIGN, assign.getInitType()));
return;
} }
switch (insn.getType()) { switch (insn.getType()) {
case NEW_INSTANCE: case NEW_INSTANCE:
ArgType clsType = (ArgType) ((IndexInsnNode) insn).getIndex(); ArgType clsType = (ArgType) ((IndexInsnNode) insn).getIndex();
return new TypeBoundConst(BoundEnum.ASSIGN, clsType); addBound(typeInfo, new TypeBoundConst(BoundEnum.ASSIGN, clsType));
break;
case CONST: case CONST:
LiteralArg constLit = (LiteralArg) insn.getArg(0); LiteralArg constLit = (LiteralArg) insn.getArg(0);
return new TypeBoundConst(BoundEnum.ASSIGN, constLit.getType()); addBound(typeInfo, new TypeBoundConst(BoundEnum.ASSIGN, constLit.getType()));
break;
case MOVE_EXCEPTION:
ExcHandlerAttr excHandlerAttr = insn.get(AType.EXC_HANDLER);
if (excHandlerAttr != null) {
for (ClassInfo catchType : excHandlerAttr.getHandler().getCatchTypes()) {
addBound(typeInfo, new TypeBoundConst(BoundEnum.ASSIGN, catchType.getType()));
}
} else {
addBound(typeInfo, new TypeBoundConst(BoundEnum.ASSIGN, insn.getResult().getInitType()));
}
break;
default: default:
ArgType type = insn.getResult().getInitType(); ArgType type = insn.getResult().getInitType();
return new TypeBoundConst(BoundEnum.ASSIGN, type); addBound(typeInfo, new TypeBoundConst(BoundEnum.ASSIGN, type));
break;
} }
} }
...@@ -161,7 +227,7 @@ public final class TypeInferenceVisitor extends AbstractVisitor { ...@@ -161,7 +227,7 @@ public final class TypeInferenceVisitor extends AbstractVisitor {
return new TypeBoundConst(BoundEnum.USE, regArg.getInitType()); return new TypeBoundConst(BoundEnum.USE, regArg.getInitType());
} }
private boolean tryAllTypes(SSAVar var, ArgType type) { private boolean tryPossibleTypes(SSAVar var, ArgType type) {
List<ArgType> types = makePossibleTypesList(type); List<ArgType> types = makePossibleTypesList(type);
for (ArgType candidateType : types) { for (ArgType candidateType : types) {
TypeUpdateResult result = typeUpdate.apply(var, candidateType); TypeUpdateResult result = typeUpdate.apply(var, candidateType);
...@@ -180,8 +246,98 @@ public final class TypeInferenceVisitor extends AbstractVisitor { ...@@ -180,8 +246,98 @@ public final class TypeInferenceVisitor extends AbstractVisitor {
} }
} }
for (PrimitiveType possibleType : type.getPossibleTypes()) { for (PrimitiveType possibleType : type.getPossibleTypes()) {
if (possibleType == PrimitiveType.VOID) {
continue;
}
list.add(ArgType.convertFromPrimitiveType(possibleType)); list.add(ArgType.convertFromPrimitiveType(possibleType));
} }
return list; return list;
} }
private boolean tryDeduceType(MethodNode mth, SSAVar var, @Nullable ArgType type) {
// try best type from bounds again
if (setBestType(var)) {
return true;
}
// try all possible types (useful for primitives)
if (type != null && tryPossibleTypes(var, type)) {
return true;
}
// for objects try super types
if (tryWiderObjects(mth, var)) {
return true;
}
return false;
}
/**
* Add MOVE instruction before PHI in bound blocks to make 'soft' type link.
* This allows to use different types in blocks merged by PHI.
*/
private boolean tryInsertAdditionalInsn(MethodNode mth, SSAVar var) {
if (var.getTypeInfo().getType().isTypeKnown()) {
return false;
}
PhiInsn phiInsn = var.getUsedInPhi();
if (phiInsn == null) {
return false;
}
if (var.getUseCount() == 1) {
InsnNode assignInsn = var.getAssign().getAssignInsn();
if (assignInsn != null && assignInsn.getType() == InsnType.MOVE) {
return false;
}
}
for (Map.Entry<RegisterArg, BlockNode> entry : phiInsn.getBlockBinds().entrySet()) {
RegisterArg reg = entry.getKey();
if (reg.getSVar() == var) {
int regNum = reg.getRegNum();
RegisterArg resultArg = reg.duplicate(regNum, null);
SSAVar newSsaVar = mth.makeNewSVar(regNum, resultArg);
RegisterArg arg = reg.duplicate(regNum, var);
InsnNode moveInsn = new InsnNode(InsnType.MOVE, 1);
moveInsn.setResult(resultArg);
moveInsn.addArg(arg);
moveInsn.add(AFlag.SYNTHETIC);
entry.getValue().getInstructions().add(moveInsn);
phiInsn.replaceArg(reg, reg.duplicate(regNum, newSsaVar));
attachBounds(var);
for (InsnArg phiArg : phiInsn.getArguments()) {
attachBounds(((RegisterArg) phiArg).getSVar());
}
for (InsnArg phiArg : phiInsn.getArguments()) {
mergePhiBounds(((RegisterArg) phiArg).getSVar());
}
return true;
}
}
return false;
}
private boolean tryWiderObjects(MethodNode mth, SSAVar var) {
Set<ArgType> objTypes = new LinkedHashSet<>();
for (ITypeBound bound : var.getTypeInfo().getBounds()) {
ArgType boundType = bound.getType();
if (boundType.isTypeKnown() && boundType.isObject()) {
objTypes.add(boundType);
}
}
if (objTypes.isEmpty()) {
return false;
}
ClspGraph clsp = mth.root().getClsp();
for (ArgType objType : objTypes) {
for (String ancestor : clsp.getAncestors(objType.getObject())) {
ArgType ancestorType = ArgType.object(ancestor);
TypeUpdateResult result = typeUpdate.applyWithWiderAllow(var, ancestorType);
if (result == TypeUpdateResult.CHANGED) {
return true;
}
}
}
return false;
}
} }
...@@ -3,6 +3,8 @@ package jadx.core.dex.visitors.typeinference; ...@@ -3,6 +3,8 @@ package jadx.core.dex.visitors.typeinference;
import java.util.HashSet; import java.util.HashSet;
import java.util.Set; import java.util.Set;
import org.jetbrains.annotations.NotNull;
import jadx.core.dex.instructions.args.ArgType; import jadx.core.dex.instructions.args.ArgType;
public class TypeInfo { public class TypeInfo {
...@@ -10,6 +12,7 @@ public class TypeInfo { ...@@ -10,6 +12,7 @@ public class TypeInfo {
private final Set<ITypeBound> bounds = new HashSet<>(); private final Set<ITypeBound> bounds = new HashSet<>();
@NotNull
public ArgType getType() { public ArgType getType() {
return type; return type;
} }
......
package jadx.core.dex.visitors.typeinference;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.HashSet;
import java.util.List;
import java.util.Set;
import java.util.stream.Collectors;
import org.jetbrains.annotations.Nullable;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import jadx.core.Consts;
import jadx.core.dex.instructions.args.ArgType;
import jadx.core.dex.instructions.args.InsnArg;
import jadx.core.dex.instructions.args.PrimitiveType;
import jadx.core.dex.instructions.args.RegisterArg;
import jadx.core.dex.instructions.args.SSAVar;
import jadx.core.dex.nodes.InsnNode;
import jadx.core.dex.nodes.MethodNode;
import jadx.core.utils.exceptions.JadxRuntimeException;
/**
* Slow and memory consuming multi-variable type search algorithm.
* Used only if fast type propagation is failed for some variables.
* <p>
* Stages description:
* - find all possible candidate types within bounds
* - build dynamic constraint list for every variable
* - run search by checking all candidates
*/
public class TypeSearch {
private static final Logger LOG = LoggerFactory.getLogger(TypeSearch.class);
private static final int CANDIDATES_COUNT_LIMIT = 10;
private static final int SEARCH_ITERATION_LIMIT = 1_000_000;
private final MethodNode mth;
private final TypeSearchState state;
private final TypeCompare typeCompare;
private final TypeUpdate typeUpdate;
public TypeSearch(MethodNode mth) {
this.mth = mth;
this.state = new TypeSearchState(mth);
this.typeUpdate = mth.root().getTypeUpdate();
this.typeCompare = typeUpdate.getComparator();
}
public boolean run() {
mth.getSVars().forEach(this::fillTypeCandidates);
mth.getSVars().forEach(this::collectConstraints);
// quick search for variables without dependencies
state.getUnresolvedVars().forEach(this::resolveIndependentVariables);
boolean searchSuccess;
List<TypeSearchVarInfo> vars = state.getUnresolvedVars();
if (vars.isEmpty()) {
searchSuccess = true;
} else {
search(vars);
searchSuccess = fullCheck(vars);
if (Consts.DEBUG && !searchSuccess) {
LOG.warn("Multi-variable search failed in {}", mth);
}
}
boolean applySuccess = applyResolvedVars();
return searchSuccess && applySuccess;
}
private boolean applyResolvedVars() {
List<TypeSearchVarInfo> resolvedVars = state.getResolvedVars();
for (TypeSearchVarInfo var : resolvedVars) {
var.getVar().setType(var.getCurrentType());
}
boolean applySuccess = true;
for (TypeSearchVarInfo var : resolvedVars) {
TypeUpdateResult res = typeUpdate.applyWithWiderAllow(var.getVar(), var.getCurrentType());
if (res == TypeUpdateResult.REJECT) {
applySuccess = false;
}
}
return applySuccess;
}
private boolean search(List<TypeSearchVarInfo> vars) {
int len = vars.size();
if (Consts.DEBUG) {
LOG.debug("Run search for {} vars: ", len);
StringBuilder sb = new StringBuilder();
long count = 1;
for (TypeSearchVarInfo var : vars) {
LOG.debug(" {}", var);
int size = var.getCandidateTypes().size();
sb.append(" * ").append(size);
count *= size;
}
sb.append(" = ").append(count);
LOG.debug("--- count = {}, {}", count, sb);
}
// prepare vars
for (TypeSearchVarInfo var : vars) {
var.reset();
}
// check all types combinations
int n = 0;
int i = 0;
while (!fullCheck(vars)) {
TypeSearchVarInfo first = vars.get(i);
if (first.nextType()) {
int k = i + 1;
if (k >= len) {
return false;
}
TypeSearchVarInfo next = vars.get(k);
while (true) {
if (next.nextType()) {
k++;
if (k >= len) {
return false;
}
next = vars.get(k);
} else {
break;
}
}
}
n++;
if (n > SEARCH_ITERATION_LIMIT) {
return false;
}
}
// mark all vars as resolved
for (TypeSearchVarInfo var : vars) {
var.setTypeResolved(true);
}
return true;
}
private boolean resolveIndependentVariables(TypeSearchVarInfo varInfo) {
boolean allRelatedVarsResolved = varInfo.getConstraints().stream()
.flatMap(c -> c.getRelatedVars().stream())
.allMatch(v -> state.getVarInfo(v).isTypeResolved());
if (!allRelatedVarsResolved) {
return false;
}
// variable is independent, run single search
varInfo.reset();
do {
if (singleCheck(varInfo)) {
varInfo.setTypeResolved(true);
return true;
}
} while (!varInfo.nextType());
return false;
}
private boolean fullCheck(List<TypeSearchVarInfo> vars) {
for (TypeSearchVarInfo var : vars) {
if (!singleCheck(var)) {
return false;
}
}
return true;
}
private boolean singleCheck(TypeSearchVarInfo var) {
if (var.isTypeResolved()) {
return true;
}
for (ITypeConstraint constraint : var.getConstraints()) {
if (!constraint.check(state)) {
return false;
}
}
return true;
}
private void fillTypeCandidates(SSAVar ssaVar) {
TypeSearchVarInfo varInfo = state.getVarInfo(ssaVar);
ArgType currentType = ssaVar.getTypeInfo().getType();
if (currentType.isTypeKnown()) {
varInfo.setTypeResolved(true);
varInfo.setCurrentType(currentType);
varInfo.setCandidateTypes(Collections.emptyList());
return;
}
if (ssaVar.getAssign().isTypeImmutable()) {
ArgType initType = ssaVar.getAssign().getInitType();
varInfo.setTypeResolved(true);
varInfo.setCurrentType(initType);
varInfo.setCandidateTypes(Collections.emptyList());
return;
}
Set<ArgType> assigns = new HashSet<>();
Set<ArgType> uses = new HashSet<>();
Set<ITypeBound> bounds = ssaVar.getTypeInfo().getBounds();
for (ITypeBound bound : bounds) {
if (bound.getBound() == BoundEnum.ASSIGN) {
assigns.add(bound.getType());
} else {
uses.add(bound.getType());
}
}
Set<ArgType> candidateTypes = new HashSet<>();
addCandidateTypes(bounds, candidateTypes, assigns);
addCandidateTypes(bounds, candidateTypes, uses);
for (ArgType assignType : assigns) {
addCandidateTypes(bounds, candidateTypes, getWiderTypes(assignType));
}
for (ArgType useType : uses) {
addCandidateTypes(bounds, candidateTypes, getNarrowTypes(useType));
}
int size = candidateTypes.size();
if (size == 0) {
throw new JadxRuntimeException("No candidate types for var: " + ssaVar.getDetailedVarInfo(mth)
+ "\n assigns: " + assigns + "\n uses: " + uses);
}
if (size == 1) {
varInfo.setTypeResolved(true);
varInfo.setCurrentType(candidateTypes.iterator().next());
varInfo.setCandidateTypes(Collections.emptyList());
} else {
varInfo.setTypeResolved(false);
varInfo.setCurrentType(ArgType.UNKNOWN);
ArrayList<ArgType> types = new ArrayList<>(candidateTypes);
types.sort(typeCompare.getComparator());
varInfo.setCandidateTypes(Collections.unmodifiableList(types));
}
}
private void addCandidateTypes(Set<ITypeBound> bounds, Set<ArgType> collectedTypes, Collection<ArgType> candidateTypes) {
for (ArgType candidateType : candidateTypes) {
if (candidateType.isTypeKnown() && typeUpdate.inBounds(bounds, candidateType)) {
collectedTypes.add(candidateType);
if (collectedTypes.size() > CANDIDATES_COUNT_LIMIT) {
return;
}
}
}
}
private List<ArgType> getWiderTypes(ArgType type) {
if (type.isTypeKnown()) {
if (type.isObject()) {
Set<String> ancestors = mth.root().getClsp().getAncestors(type.getObject());
return ancestors.stream().map(ArgType::object).collect(Collectors.toList());
}
} else {
return expandUnknownType(type);
}
return Collections.emptyList();
}
private List<ArgType> getNarrowTypes(ArgType type) {
if (type.isTypeKnown()) {
if (type.isObject()) {
if (type.equals(ArgType.OBJECT)) {
// a lot of objects to return
return Collections.singletonList(ArgType.OBJECT);
}
List<String> impList = mth.root().getClsp().getImplementations(type.getObject());
return impList.stream().map(ArgType::object).collect(Collectors.toList());
}
} else {
return expandUnknownType(type);
}
return Collections.emptyList();
}
private List<ArgType> expandUnknownType(ArgType type) {
List<ArgType> list = new ArrayList<>();
for (PrimitiveType possibleType : type.getPossibleTypes()) {
list.add(ArgType.convertFromPrimitiveType(possibleType));
}
return list;
}
private void collectConstraints(SSAVar var) {
TypeSearchVarInfo varInfo = state.getVarInfo(var);
if (varInfo.isTypeResolved()) {
varInfo.setConstraints(Collections.emptyList());
return;
}
varInfo.setConstraints(new ArrayList<>());
addConstraint(varInfo, makeConstraint(var.getAssign()));
for (RegisterArg regArg : var.getUseList()) {
addConstraint(varInfo, makeConstraint(regArg));
}
}
public static ArgType getArgType(TypeSearchState state, InsnArg arg) {
if (arg.isRegister()) {
RegisterArg reg = (RegisterArg) arg;
return state.getVarInfo(reg.getSVar()).getCurrentType();
}
return arg.getType();
}
private void addConstraint(TypeSearchVarInfo varInfo, ITypeConstraint constraint) {
if (constraint != null) {
varInfo.getConstraints().add(constraint);
}
}
@Nullable
private ITypeConstraint makeConstraint(RegisterArg arg) {
InsnNode insn = arg.getParentInsn();
if (insn == null || arg.isTypeImmutable()) {
return null;
}
switch (insn.getType()) {
case MOVE:
return makeMoveConstraint(insn, arg);
case PHI:
return makePhiConstraint(insn, arg);
default:
return null;
}
}
@Nullable
private ITypeConstraint makeMoveConstraint(InsnNode insn, RegisterArg arg) {
if (!insn.getArg(0).isRegister()) {
return null;
}
return new AbstractTypeConstraint(insn, arg) {
@Override
public boolean check(TypeSearchState state) {
ArgType resType = getArgType(state, insn.getResult());
ArgType argType = getArgType(state, insn.getArg(0));
TypeCompareEnum res = typeCompare.compareTypes(resType, argType);
return res == TypeCompareEnum.EQUAL || res.isWider();
}
};
}
private ITypeConstraint makePhiConstraint(InsnNode insn, RegisterArg arg) {
return new AbstractTypeConstraint(insn, arg) {
@Override
public boolean check(TypeSearchState state) {
ArgType resType = getArgType(state, insn.getResult());
for (InsnArg insnArg : insn.getArguments()) {
ArgType argType = getArgType(state, insnArg);
if (!argType.equals(resType)) {
return false;
}
}
return true;
}
};
}
}
package jadx.core.dex.visitors.typeinference;
import java.util.ArrayList;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
import org.jetbrains.annotations.NotNull;
import jadx.core.dex.instructions.args.SSAVar;
import jadx.core.dex.nodes.MethodNode;
import jadx.core.utils.exceptions.JadxRuntimeException;
public class TypeSearchState {
private Map<SSAVar, TypeSearchVarInfo> varInfoMap;
public TypeSearchState(MethodNode mth) {
List<SSAVar> vars = mth.getSVars();
this.varInfoMap = new LinkedHashMap<>(vars.size());
for (SSAVar var : vars) {
varInfoMap.put(var, new TypeSearchVarInfo(var));
}
}
@NotNull
public TypeSearchVarInfo getVarInfo(SSAVar var) {
TypeSearchVarInfo varInfo = this.varInfoMap.get(var);
if (varInfo == null) {
throw new JadxRuntimeException("TypeSearchVarInfo not found in map for var: " + var);
}
return varInfo;
}
public List<TypeSearchVarInfo> getAllVars() {
return new ArrayList<>(varInfoMap.values());
}
public List<TypeSearchVarInfo> getUnresolvedVars() {
return varInfoMap.values().stream()
.filter(varInfo -> !varInfo.isTypeResolved())
.collect(Collectors.toList());
}
public List<TypeSearchVarInfo> getResolvedVars() {
return varInfoMap.values().stream()
.filter(TypeSearchVarInfo::isTypeResolved)
.collect(Collectors.toList());
}
}
package jadx.core.dex.visitors.typeinference;
import java.util.List;
import jadx.core.dex.instructions.args.ArgType;
import jadx.core.dex.instructions.args.SSAVar;
public class TypeSearchVarInfo {
private final SSAVar var;
private boolean typeResolved;
private ArgType currentType;
private List<ArgType> candidateTypes;
private int currentIndex = -1;
private List<ITypeConstraint> constraints;
public TypeSearchVarInfo(SSAVar var) {
this.var = var;
}
public void reset() {
if (typeResolved) {
return;
}
currentIndex = 0;
currentType = candidateTypes.get(0);
}
/**
* Switch {@code currentType} to next candidate
*
* @return true - if this is the first candidate
*/
public boolean nextType() {
if (typeResolved) {
return false;
}
int len = candidateTypes.size();
currentIndex = (currentIndex + 1) % len;
currentType = candidateTypes.get(currentIndex);
return currentIndex == 0;
}
public SSAVar getVar() {
return var;
}
public boolean isTypeResolved() {
return typeResolved;
}
public void setTypeResolved(boolean typeResolved) {
this.typeResolved = typeResolved;
}
public ArgType getCurrentType() {
return currentType;
}
public void setCurrentType(ArgType currentType) {
this.currentType = currentType;
}
public List<ArgType> getCandidateTypes() {
return candidateTypes;
}
public void setCandidateTypes(List<ArgType> candidateTypes) {
this.candidateTypes = candidateTypes;
}
public List<ITypeConstraint> getConstraints() {
return constraints;
}
public void setConstraints(List<ITypeConstraint> constraints) {
this.constraints = constraints;
}
@Override
public String toString() {
StringBuilder sb = new StringBuilder();
sb.append("TypeSearchVarInfo{");
sb.append(var.toShortString());
if (typeResolved) {
sb.append(", resolved type: ").append(currentType);
} else {
sb.append(", currentType=").append(currentType);
sb.append(", candidateTypes=").append(candidateTypes);
sb.append(", constraints=").append(constraints);
}
sb.append('}');
return sb.toString();
}
}
...@@ -34,22 +34,16 @@ public final class TypeUpdate { ...@@ -34,22 +34,16 @@ public final class TypeUpdate {
private final Map<InsnType, ITypeListener> listenerRegistry; private final Map<InsnType, ITypeListener> listenerRegistry;
private final TypeCompare comparator; private final TypeCompare comparator;
private ThreadLocal<Boolean> applyDebug = new ThreadLocal<>(); private ThreadLocal<Boolean> allowWider = new ThreadLocal<>();
public TypeUpdate(RootNode root) { public TypeUpdate(RootNode root) {
this.listenerRegistry = initListenerRegistry(); this.listenerRegistry = initListenerRegistry();
this.comparator = new TypeCompare(root); this.comparator = new TypeCompare(root);
} }
public TypeUpdateResult applyDebug(SSAVar ssaVar, ArgType candidateType) { /**
try { * Perform recursive type checking and type propagation for all related variables
applyDebug.set(true); */
return apply(ssaVar, candidateType);
} finally {
applyDebug.set(false);
}
}
public TypeUpdateResult apply(SSAVar ssaVar, ArgType candidateType) { public TypeUpdateResult apply(SSAVar ssaVar, ArgType candidateType) {
if (candidateType == null) { if (candidateType == null) {
return REJECT; return REJECT;
...@@ -71,10 +65,21 @@ public final class TypeUpdate { ...@@ -71,10 +65,21 @@ public final class TypeUpdate {
return CHANGED; return CHANGED;
} }
/**
* Allow wider types for apply from debug info and some special cases
*/
public TypeUpdateResult applyWithWiderAllow(SSAVar ssaVar, ArgType candidateType) {
try {
allowWider.set(true);
return apply(ssaVar, candidateType);
} finally {
allowWider.set(false);
}
}
private TypeUpdateResult updateTypeChecked(TypeUpdateInfo updateInfo, InsnArg arg, ArgType candidateType) { private TypeUpdateResult updateTypeChecked(TypeUpdateInfo updateInfo, InsnArg arg, ArgType candidateType) {
if (candidateType == null) { if (candidateType == null) {
LOG.warn("Reject null type update, arg: {}, info: {}", arg, updateInfo, new RuntimeException()); throw new JadxRuntimeException("Null type update for arg: " + arg);
return REJECT;
} }
ArgType currentType = arg.getType(); ArgType currentType = arg.getType();
if (Objects.equals(currentType, candidateType)) { if (Objects.equals(currentType, candidateType)) {
...@@ -82,15 +87,20 @@ public final class TypeUpdate { ...@@ -82,15 +87,20 @@ public final class TypeUpdate {
} }
TypeCompareEnum compareResult = comparator.compareTypes(candidateType, currentType); TypeCompareEnum compareResult = comparator.compareTypes(candidateType, currentType);
if (compareResult == TypeCompareEnum.CONFLICT) { if (compareResult == TypeCompareEnum.CONFLICT) {
if (Consts.DEBUG) {
LOG.debug("Type rejected for {} due to conflict: candidate={}, current={}", arg, candidateType, currentType);
}
return REJECT; return REJECT;
} }
if (arg.isTypeImmutable() && currentType != ArgType.UNKNOWN) { if (arg.isTypeImmutable() && currentType != ArgType.UNKNOWN) {
// don't changed type, conflict already rejected // don't changed type, conflict already rejected
return SAME; return SAME;
} }
if (compareResult == TypeCompareEnum.WIDER || compareResult == TypeCompareEnum.WIDER_BY_GENERIC) { if (compareResult.isWider()) {
// allow wider types for apply from debug info if (allowWider.get() != Boolean.TRUE) {
if (applyDebug.get() != Boolean.TRUE) { if (Consts.DEBUG) {
LOG.debug("Type rejected for {}: candidate={} is wider than current={}", arg, candidateType, currentType);
}
return REJECT; return REJECT;
} }
} }
...@@ -104,7 +114,7 @@ public final class TypeUpdate { ...@@ -104,7 +114,7 @@ public final class TypeUpdate {
private TypeUpdateResult updateTypeForSsaVar(TypeUpdateInfo updateInfo, SSAVar ssaVar, ArgType candidateType) { private TypeUpdateResult updateTypeForSsaVar(TypeUpdateInfo updateInfo, SSAVar ssaVar, ArgType candidateType) {
TypeInfo typeInfo = ssaVar.getTypeInfo(); TypeInfo typeInfo = ssaVar.getTypeInfo();
if (!inBounds(typeInfo.getBounds(), candidateType)) { if (!inBounds(typeInfo.getBounds(), candidateType)) {
if (Consts.DEBUG && LOG.isDebugEnabled()) { if (Consts.DEBUG) {
LOG.debug("Reject type '{}' for {} by bounds: {}", candidateType, ssaVar, typeInfo.getBounds()); LOG.debug("Reject type '{}' for {} by bounds: {}", candidateType, ssaVar, typeInfo.getBounds());
} }
return REJECT; return REJECT;
...@@ -138,7 +148,11 @@ public final class TypeUpdate { ...@@ -138,7 +148,11 @@ public final class TypeUpdate {
} }
updateInfo.requestUpdate(arg, candidateType); updateInfo.requestUpdate(arg, candidateType);
try { try {
return runListeners(updateInfo, arg, candidateType); TypeUpdateResult result = runListeners(updateInfo, arg, candidateType);
if (result == REJECT) {
updateInfo.rollbackUpdate(arg);
}
return result;
} catch (StackOverflowError overflow) { } catch (StackOverflowError overflow) {
throw new JadxOverflowException("Type update terminated with stack overflow, arg: " + arg); throw new JadxOverflowException("Type update terminated with stack overflow, arg: " + arg);
} }
...@@ -156,7 +170,7 @@ public final class TypeUpdate { ...@@ -156,7 +170,7 @@ public final class TypeUpdate {
return listener.update(updateInfo, insn, arg, candidateType); return listener.update(updateInfo, insn, arg, candidateType);
} }
private boolean inBounds(Set<ITypeBound> bounds, ArgType candidateType) { boolean inBounds(Set<ITypeBound> bounds, ArgType candidateType) {
for (ITypeBound bound : bounds) { for (ITypeBound bound : bounds) {
ArgType boundType = bound.getType(); ArgType boundType = bound.getType();
if (boundType != null && !checkBound(candidateType, bound, boundType)) { if (boundType != null && !checkBound(candidateType, bound, boundType)) {
...@@ -166,6 +180,14 @@ public final class TypeUpdate { ...@@ -166,6 +180,14 @@ public final class TypeUpdate {
return true; return true;
} }
private boolean inBounds(InsnArg arg, ArgType candidateType) {
if (arg.isRegister()) {
TypeInfo typeInfo = ((RegisterArg) arg).getSVar().getTypeInfo();
return inBounds(typeInfo.getBounds(), candidateType);
}
return arg.getType().equals(candidateType);
}
private boolean checkBound(ArgType candidateType, ITypeBound bound, ArgType boundType) { private boolean checkBound(ArgType candidateType, ITypeBound bound, ArgType boundType) {
TypeCompareEnum compareResult = comparator.compareTypes(candidateType, boundType); TypeCompareEnum compareResult = comparator.compareTypes(candidateType, boundType);
switch (compareResult) { switch (compareResult) {
...@@ -222,7 +244,7 @@ public final class TypeUpdate { ...@@ -222,7 +244,7 @@ public final class TypeUpdate {
private Map<InsnType, ITypeListener> initListenerRegistry() { private Map<InsnType, ITypeListener> initListenerRegistry() {
Map<InsnType, ITypeListener> registry = new EnumMap<>(InsnType.class); Map<InsnType, ITypeListener> registry = new EnumMap<>(InsnType.class);
registry.put(InsnType.CONST, this::sameFirstArgListener); registry.put(InsnType.CONST, this::sameFirstArgListener);
registry.put(InsnType.MOVE, this::sameFirstArgListener); registry.put(InsnType.MOVE, this::moveListener);
registry.put(InsnType.PHI, this::allSameListener); registry.put(InsnType.PHI, this::allSameListener);
registry.put(InsnType.MERGE, this::allSameListener); registry.put(InsnType.MERGE, this::allSameListener);
registry.put(InsnType.AGET, this::arrayGetListener); registry.put(InsnType.AGET, this::arrayGetListener);
...@@ -239,6 +261,27 @@ public final class TypeUpdate { ...@@ -239,6 +261,27 @@ public final class TypeUpdate {
return updateTypeChecked(updateInfo, changeArg, candidateType); return updateTypeChecked(updateInfo, changeArg, candidateType);
} }
private TypeUpdateResult moveListener(TypeUpdateInfo updateInfo, InsnNode insn, InsnArg arg, ArgType candidateType) {
boolean assignChanged = isAssign(insn, arg);
InsnArg changeArg = assignChanged ? insn.getArg(0) : insn.getResult();
TypeUpdateResult result = updateTypeChecked(updateInfo, changeArg, candidateType);
if (result == REJECT && changeArg.getType().isTypeKnown()) {
// allow result to be wider
if (assignChanged) {
TypeCompareEnum compareTypes = comparator.compareTypes(candidateType, changeArg.getType());
if (compareTypes.isWider() && inBounds(changeArg, candidateType)) {
return CHANGED;
}
} else {
TypeCompareEnum compareTypes = comparator.compareTypes(changeArg.getType(), candidateType);
if (compareTypes.isWider() && inBounds(changeArg, candidateType)) {
return CHANGED;
}
}
}
return result;
}
/** /**
* All args must have same types * All args must have same types
*/ */
......
...@@ -18,6 +18,10 @@ public class TypeUpdateInfo { ...@@ -18,6 +18,10 @@ public class TypeUpdateInfo {
return updates.containsKey(arg); return updates.containsKey(arg);
} }
public void rollbackUpdate(InsnArg arg) {
updates.remove(arg);
}
public Map<InsnArg, ArgType> getUpdates() { public Map<InsnArg, ArgType> getUpdates() {
return updates; return updates;
} }
......
...@@ -72,7 +72,7 @@ public abstract class BaseExternalTest extends IntegrationTest { ...@@ -72,7 +72,7 @@ public abstract class BaseExternalTest extends IntegrationTest {
int processed = 0; int processed = 0;
for (ClassNode classNode : root.getClasses(true)) { for (ClassNode classNode : root.getClasses(true)) {
String clsFullName = classNode.getClassInfo().getFullName(); String clsFullName = classNode.getClassInfo().getFullName();
if (isMatch(clsFullName, clsPattern)) { if (clsFullName.equals(clsPattern)) {
if (processCls(mthPattern, passes, classNode)) { if (processCls(mthPattern, passes, classNode)) {
processed++; processed++;
} }
......
package jadx.tests.integration.trycatch;
import java.lang.reflect.Constructor;
import java.lang.reflect.InvocationTargetException;
import org.junit.Test;
import jadx.core.dex.nodes.ClassNode;
import jadx.tests.api.IntegrationTest;
import static jadx.tests.api.utils.JadxMatchers.containsOne;
import static org.junit.Assert.assertThat;
public class TestMultiExceptionCatch2 extends IntegrationTest {
public static class TestCls {
public void test(Constructor constructor) {
try {
constructor.newInstance();
} catch (IllegalAccessException | InstantiationException | InvocationTargetException e) {
e.printStackTrace();
}
}
}
@Test
public void test() {
commonChecks();
}
@Test
public void testNoDebug() {
noDebugInfo();
commonChecks();
}
private void commonChecks() {
ClassNode cls = getClassNode(TestCls.class);
String code = cls.getCode().toString();
assertThat(code, containsOne("try {"));
assertThat(code, containsOne("} catch (IllegalAccessException | InstantiationException | InvocationTargetException e) {"));
assertThat(code, containsOne("e.printStackTrace();"));
// TODO: store vararg attribute for methods from classpath
// assertThat(code, containsOne("constructor.newInstance();"));
}
}
package jadx.tests.integration.types;
import org.junit.Test;
import jadx.core.dex.nodes.ClassNode;
import jadx.tests.api.IntegrationTest;
import static jadx.tests.api.utils.JadxMatchers.containsOne;
import static org.junit.Assert.assertThat;
public class TestTypeInheritance extends IntegrationTest {
public static class TestCls {
public interface IRoot {
}
public interface IBase extends IRoot {
}
public static class A implements IBase {
}
public static class B implements IBase {
public void b() {}
}
private static void test(boolean z) {
IBase impl;
if (z) {
impl = new A();
} else {
B b = new B();
b.b();
impl = b; // this move is removed in no-debug byte-code
}
useBase(impl);
useRoot(impl);
}
private static void useRoot(IRoot root) {}
private static void useBase(IBase base) {}
}
@Test
public void test() {
ClassNode cls = getClassNode(TestCls.class);
String code = cls.getCode().toString();
assertThat(code, containsOne("IBase impl;"));
assertThat(code, containsOne("impl = new A();"));
assertThat(code, containsOne("B b = new B();"));
assertThat(code, containsOne("impl = b;"));
}
@Test
public void testNoDebug() {
noDebugInfo();
getClassNode(TestCls.class);
}
}
package jadx.tests.integration.types;
import org.junit.Test;
import jadx.core.dex.nodes.ClassNode;
import jadx.tests.api.IntegrationTest;
import static jadx.tests.api.utils.JadxMatchers.containsOne;
import static org.junit.Assert.assertThat;
public class TestTypeResolver6 extends IntegrationTest {
public static class TestCls {
private final Object obj;
public TestCls(boolean b) {
this.obj = b ? this : makeObj();
}
public Object makeObj() {
return new Object();
}
}
@Test
public void test() {
ClassNode cls = getClassNode(TestCls.class);
String code = cls.getCode().toString();
assertThat(code, containsOne("this.obj = b ? this : makeObj();"));
}
@Test
public void testNoDebug() {
noDebugInfo();
getClassNode(TestCls.class);
}
}
package jadx.tests.integration.types;
import org.junit.Test;
import jadx.core.dex.nodes.ClassNode;
import jadx.tests.api.IntegrationTest;
import static jadx.tests.api.utils.JadxMatchers.containsOne;
import static org.junit.Assert.assertThat;
public class TestTypeResolver6a extends IntegrationTest {
public static class TestCls implements Runnable {
private final Runnable runnable;
public TestCls(boolean b) {
this.runnable = b ? this : makeRunnable();
}
public Runnable makeRunnable() {
return new Runnable() {
@Override
public void run() {
// do nothing
}
};
}
@Override
public void run() {
// do nothing
}
}
@Test
public void test() {
ClassNode cls = getClassNode(TestCls.class);
String code = cls.getCode().toString();
assertThat(code, containsOne("this.runnable = b ? this : makeRunnable();"));
}
@Test
public void testNoDebug() {
noDebugInfo();
getClassNode(TestCls.class);
}
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment