Commit 4ce5cc84 authored by Skylot's avatar Skylot

fix: use multi-variable type search algorithm if type propagation is failed

parent 9b091b7c
...@@ -69,11 +69,24 @@ public class ClspGraph { ...@@ -69,11 +69,24 @@ public class ClspGraph {
return nClass; return nClass;
} }
/**
* @return {@code clsName} instanceof {@code implClsName}
*/
public boolean isImplements(String clsName, String implClsName) { public boolean isImplements(String clsName, String implClsName) {
Set<String> anc = getAncestors(clsName); Set<String> anc = getAncestors(clsName);
return anc.contains(implClsName); return anc.contains(implClsName);
} }
public List<String> getImplementations(String clsName) {
List<String> list = new ArrayList<>();
for (String cls : nameMap.keySet()) {
if (isImplements(cls, clsName)) {
list.add(cls);
}
}
return list;
}
public String getCommonAncestor(String clsName, String implClsName) { public String getCommonAncestor(String clsName, String implClsName) {
if (clsName.equals(implClsName)) { if (clsName.equals(implClsName)) {
return clsName; return clsName;
...@@ -104,7 +117,7 @@ public class ClspGraph { ...@@ -104,7 +117,7 @@ public class ClspGraph {
return null; return null;
} }
private Set<String> getAncestors(String clsName) { public Set<String> getAncestors(String clsName) {
Set<String> result = ancestorCache.get(clsName); Set<String> result = ancestorCache.get(clsName);
if (result != null) { if (result != null) {
return result; return result;
......
...@@ -63,12 +63,7 @@ public class AnnotationGen { ...@@ -63,12 +63,7 @@ public class AnnotationGen {
} }
for (Annotation a : aList.getAll()) { for (Annotation a : aList.getAll()) {
String aCls = a.getAnnotationClass(); String aCls = a.getAnnotationClass();
if (aCls.startsWith(Consts.DALVIK_ANNOTATION_PKG)) { if (!aCls.startsWith(Consts.DALVIK_ANNOTATION_PKG)) {
// skip
if (Consts.DEBUG) {
code.startLine("// " + a);
}
} else {
code.startLine(); code.startLine();
formatAnnotation(code, a); formatAnnotation(code, a);
} }
......
...@@ -701,7 +701,7 @@ public class InsnGen { ...@@ -701,7 +701,7 @@ public class InsnGen {
ArgType origType; ArgType origType;
List<RegisterArg> arguments = callMth.getArguments(false); List<RegisterArg> arguments = callMth.getArguments(false);
if (arguments == null || arguments.isEmpty()) { if (arguments == null || arguments.isEmpty()) {
mth.addComment("JADX WARN: used method not loaded: " + callMth + ", types can be incorrect"); mth.addComment("JADX INFO: used method not loaded: " + callMth + ", types can be incorrect");
origType = callMth.getMethodInfo().getArgumentsTypes().get(origPos); origType = callMth.getMethodInfo().getArgumentsTypes().get(origPos);
} else { } else {
origType = arguments.get(origPos).getInitType(); origType = arguments.get(origPos).getInitType();
......
package jadx.core.dex.instructions; package jadx.core.dex.instructions;
import java.util.Objects;
import jadx.core.dex.nodes.InsnNode; import jadx.core.dex.nodes.InsnNode;
import jadx.core.utils.InsnUtils; import jadx.core.utils.InsnUtils;
import jadx.core.utils.Utils;
public class IndexInsnNode extends InsnNode { public class IndexInsnNode extends InsnNode {
...@@ -30,11 +33,21 @@ public class IndexInsnNode extends InsnNode { ...@@ -30,11 +33,21 @@ public class IndexInsnNode extends InsnNode {
return false; return false;
} }
IndexInsnNode other = (IndexInsnNode) obj; IndexInsnNode other = (IndexInsnNode) obj;
return index == null ? other.index == null : index.equals(other.index); return Objects.equals(index, other.index);
} }
@Override @Override
public String toString() { public String toString() {
return super.toString() + " " + InsnUtils.indexToString(index); switch (insnType) {
case CAST:
case CHECK_CAST:
return InsnUtils.formatOffset(offset) + ": "
+ InsnUtils.insnTypeToString(insnType)
+ getResult() + " = (" + InsnUtils.indexToString(index) + ") "
+ Utils.listToString(getArguments());
default:
return super.toString() + " " + InsnUtils.indexToString(index);
}
} }
} }
...@@ -693,11 +693,19 @@ public class InsnDecoder { ...@@ -693,11 +693,19 @@ public class InsnDecoder {
} }
private InsnNode arith(DecodedInstruction insn, ArithOp op, ArgType type) { private InsnNode arith(DecodedInstruction insn, ArithOp op, ArgType type) {
return new ArithNode(insn, op, type, false); return new ArithNode(insn, op, fixTypeForBitOps(op, type), false);
} }
private InsnNode arithLit(DecodedInstruction insn, ArithOp op, ArgType type) { private InsnNode arithLit(DecodedInstruction insn, ArithOp op, ArgType type) {
return new ArithNode(insn, op, type, true); return new ArithNode(insn, op, fixTypeForBitOps(op, type), true);
}
private ArgType fixTypeForBitOps(ArithOp op, ArgType type) {
if (type == ArgType.INT
&& (op == ArithOp.AND || op == ArithOp.OR || op == ArithOp.XOR)) {
return ArgType.NARROW_NUMBERS_NO_FLOAT;
}
return type;
} }
private InsnNode neg(DecodedInstruction insn, ArgType type) { private InsnNode neg(DecodedInstruction insn, ArgType type) {
......
...@@ -62,6 +62,7 @@ public final class PhiInsn extends InsnNode { ...@@ -62,6 +62,7 @@ public final class PhiInsn extends InsnNode {
RegisterArg reg = (RegisterArg) arg; RegisterArg reg = (RegisterArg) arg;
if (super.removeArg(reg)) { if (super.removeArg(reg)) {
blockBinds.remove(reg); blockBinds.remove(reg);
reg.getSVar().removeUse(reg);
InstructionRemover.fixUsedInPhiFlag(reg); InstructionRemover.fixUsedInPhiFlag(reg);
return true; return true;
} }
...@@ -78,7 +79,9 @@ public final class PhiInsn extends InsnNode { ...@@ -78,7 +79,9 @@ public final class PhiInsn extends InsnNode {
throw new JadxRuntimeException("Unknown predecessor block by arg " + from + " in PHI: " + this); throw new JadxRuntimeException("Unknown predecessor block by arg " + from + " in PHI: " + this);
} }
if (removeArg(from)) { if (removeArg(from)) {
bindArg((RegisterArg) to, pred); RegisterArg reg = (RegisterArg) to;
bindArg(reg, pred);
reg.getSVar().setUsedInPhi(this);
} }
return true; return true;
} }
......
...@@ -47,6 +47,10 @@ public abstract class ArgType { ...@@ -47,6 +47,10 @@ public abstract class ArgType {
PrimitiveType.INT, PrimitiveType.FLOAT, PrimitiveType.INT, PrimitiveType.FLOAT,
PrimitiveType.SHORT, PrimitiveType.BYTE, PrimitiveType.CHAR); PrimitiveType.SHORT, PrimitiveType.BYTE, PrimitiveType.CHAR);
public static final ArgType NARROW_NUMBERS_NO_FLOAT = unknown(
PrimitiveType.INT, PrimitiveType.BOOLEAN,
PrimitiveType.SHORT, PrimitiveType.BYTE, PrimitiveType.CHAR);
public static final ArgType WIDE = unknown(PrimitiveType.LONG, PrimitiveType.DOUBLE); public static final ArgType WIDE = unknown(PrimitiveType.LONG, PrimitiveType.DOUBLE);
public static final ArgType INT_FLOAT = unknown(PrimitiveType.INT, PrimitiveType.FLOAT); public static final ArgType INT_FLOAT = unknown(PrimitiveType.INT, PrimitiveType.FLOAT);
......
...@@ -3,6 +3,7 @@ package jadx.core.dex.instructions.args; ...@@ -3,6 +3,7 @@ package jadx.core.dex.instructions.args;
import java.util.Objects; import java.util.Objects;
import org.jetbrains.annotations.NotNull; import org.jetbrains.annotations.NotNull;
import org.jetbrains.annotations.Nullable;
import jadx.core.dex.instructions.InsnType; import jadx.core.dex.instructions.InsnType;
import jadx.core.dex.instructions.PhiInsn; import jadx.core.dex.instructions.PhiInsn;
...@@ -118,7 +119,7 @@ public class RegisterArg extends InsnArg implements Named { ...@@ -118,7 +119,7 @@ public class RegisterArg extends InsnArg implements Named {
return duplicate(getRegNum(), sVar); return duplicate(getRegNum(), sVar);
} }
public RegisterArg duplicate(int regNum, SSAVar sVar) { public RegisterArg duplicate(int regNum, @Nullable SSAVar sVar) {
RegisterArg dup = new RegisterArg(regNum, getInitType()); RegisterArg dup = new RegisterArg(regNum, getInitType());
if (sVar != null) { if (sVar != null) {
dup.setSVar(sVar); dup.setSVar(sVar);
...@@ -140,6 +141,7 @@ public class RegisterArg extends InsnArg implements Named { ...@@ -140,6 +141,7 @@ public class RegisterArg extends InsnArg implements Named {
return InsnUtils.getConstValueByInsn(dex, parInsn); return InsnUtils.getConstValueByInsn(dex, parInsn);
} }
@Nullable
public InsnNode getAssignInsn() { public InsnNode getAssignInsn() {
if (sVar == null) { if (sVar == null) {
return null; return null;
...@@ -196,7 +198,7 @@ public class RegisterArg extends InsnArg implements Named { ...@@ -196,7 +198,7 @@ public class RegisterArg extends InsnArg implements Named {
sb.append("(r"); sb.append("(r");
sb.append(regNum); sb.append(regNum);
if (sVar != null) { if (sVar != null) {
sb.append(':').append(sVar.getVersion()); sb.append('v').append(sVar.getVersion());
} }
if (getName() != null) { if (getName() != null) {
sb.append(" '").append(getName()).append('\''); sb.append(" '").append(getName()).append('\'');
......
...@@ -163,7 +163,7 @@ public class SSAVar extends AttrNode { ...@@ -163,7 +163,7 @@ public class SSAVar extends AttrNode {
} }
public String toShortString() { public String toShortString() {
return "r" + regNum + ":" + version; return "r" + regNum + "v" + version;
} }
@Override @Override
......
...@@ -575,6 +575,10 @@ public class MethodNode extends LineAttrNode implements ILoadable, IDexNode { ...@@ -575,6 +575,10 @@ public class MethodNode extends LineAttrNode implements ILoadable, IDexNode {
return debugInfoOffset; return debugInfoOffset;
} }
public SSAVar makeNewSVar(int regNum, @NotNull RegisterArg assignArg) {
return makeNewSVar(regNum, getNextSVarVersion(regNum), assignArg);
}
public SSAVar makeNewSVar(int regNum, int version, @NotNull RegisterArg assignArg) { public SSAVar makeNewSVar(int regNum, int version, @NotNull RegisterArg assignArg) {
SSAVar var = new SSAVar(regNum, version, assignArg); SSAVar var = new SSAVar(regNum, version, assignArg);
if (sVars.isEmpty()) { if (sVars.isEmpty()) {
......
...@@ -59,6 +59,7 @@ public class BlockExceptionHandler extends AbstractVisitor { ...@@ -59,6 +59,7 @@ public class BlockExceptionHandler extends AbstractVisitor {
me.add(AFlag.DONT_INLINE); me.add(AFlag.DONT_INLINE);
resArg.add(AFlag.CUSTOM_DECLARE); resArg.add(AFlag.CUSTOM_DECLARE);
excHandler.setArg(resArg); excHandler.setArg(resArg);
me.addAttr(handlerAttr);
return; return;
} }
} }
......
...@@ -53,11 +53,24 @@ public class DebugInfoApplyVisitor extends AbstractVisitor { ...@@ -53,11 +53,24 @@ public class DebugInfoApplyVisitor extends AbstractVisitor {
applyDebugInfo(mth); applyDebugInfo(mth);
mth.remove(AType.LOCAL_VARS_DEBUG_INFO); mth.remove(AType.LOCAL_VARS_DEBUG_INFO);
} }
checkTypes(mth);
} catch (Exception e) { } catch (Exception e) {
LOG.error("Error to apply debug info: {}", ErrorsCounter.formatMsg(mth, e.getMessage()), e); LOG.error("Error to apply debug info: {}", ErrorsCounter.formatMsg(mth, e.getMessage()), e);
} }
} }
private static void checkTypes(MethodNode mth) {
if (mth.isNoCode() || mth.getSVars().isEmpty()) {
return;
}
mth.getSVars().forEach(var -> {
ArgType type = var.getTypeInfo().getType();
if (!type.isTypeKnown()) {
mth.addComment("JADX WARNING: type inference failed for: " + var.getDetailedVarInfo(mth));
}
});
}
private static void applyDebugInfo(MethodNode mth) { private static void applyDebugInfo(MethodNode mth) {
mth.getSVars().forEach(ssaVar -> collectVarDebugInfo(mth, ssaVar)); mth.getSVars().forEach(ssaVar -> collectVarDebugInfo(mth, ssaVar));
...@@ -80,6 +93,9 @@ public class DebugInfoApplyVisitor extends AbstractVisitor { ...@@ -80,6 +93,9 @@ public class DebugInfoApplyVisitor extends AbstractVisitor {
applyDebugInfo(mth, ssaVar, debugInfo.getRegType(), debugInfo.getName()); applyDebugInfo(mth, ssaVar, debugInfo.getRegType(), debugInfo.getName());
} else { } else {
LOG.warn("Multiple debug info for {}: {}", ssaVar, debugInfoSet); LOG.warn("Multiple debug info for {}: {}", ssaVar, debugInfoSet);
for (RegDebugInfoAttr debugInfo : debugInfoSet) {
applyDebugInfo(mth, ssaVar, debugInfo.getRegType(), debugInfo.getName());
}
} }
} }
...@@ -102,7 +118,7 @@ public class DebugInfoApplyVisitor extends AbstractVisitor { ...@@ -102,7 +118,7 @@ public class DebugInfoApplyVisitor extends AbstractVisitor {
int startAddr = localVar.getStartAddr(); int startAddr = localVar.getStartAddr();
int endAddr = localVar.getEndAddr(); int endAddr = localVar.getEndAddr();
if (isInside(startOffset, startAddr, endAddr) || isInside(endOffset, startAddr, endAddr)) { if (isInside(startOffset, startAddr, endAddr) || isInside(endOffset, startAddr, endAddr)) {
if (Consts.DEBUG && LOG.isDebugEnabled()) { if (Consts.DEBUG) {
LOG.debug("Apply debug info by offset for: {} to {}", ssaVar, localVar); LOG.debug("Apply debug info by offset for: {} to {}", ssaVar, localVar);
} }
applyDebugInfo(mth, ssaVar, localVar.getType(), localVar.getName()); applyDebugInfo(mth, ssaVar, localVar.getType(), localVar.getName());
...@@ -127,15 +143,15 @@ public class DebugInfoApplyVisitor extends AbstractVisitor { ...@@ -127,15 +143,15 @@ public class DebugInfoApplyVisitor extends AbstractVisitor {
} }
public static void applyDebugInfo(MethodNode mth, SSAVar ssaVar, ArgType type, String varName) { public static void applyDebugInfo(MethodNode mth, SSAVar ssaVar, ArgType type, String varName) {
if (NameMapper.isValidIdentifier(varName)) { TypeUpdateResult result = mth.root().getTypeUpdate().applyWithWiderAllow(ssaVar, type);
ssaVar.setName(varName);
}
TypeUpdateResult result = mth.root().getTypeUpdate().applyDebug(ssaVar, type);
if (result == TypeUpdateResult.REJECT) { if (result == TypeUpdateResult.REJECT) {
if (LOG.isDebugEnabled()) { if (LOG.isDebugEnabled()) {
LOG.debug("Reject debug info of type: {} and name: '{}' for {}, mth: {}", type, varName, ssaVar, mth); LOG.debug("Reject debug info of type: {} and name: '{}' for {}, mth: {}", type, varName, ssaVar, mth);
} }
} else { } else {
if (NameMapper.isValidIdentifier(varName)) {
ssaVar.setName(varName);
}
detachDebugInfo(ssaVar.getAssign()); detachDebugInfo(ssaVar.getAssign());
ssaVar.getUseList().forEach(DebugInfoApplyVisitor::detachDebugInfo); ssaVar.getUseList().forEach(DebugInfoApplyVisitor::detachDebugInfo);
} }
......
...@@ -56,7 +56,7 @@ public class DebugInfoParseVisitor extends AbstractVisitor { ...@@ -56,7 +56,7 @@ public class DebugInfoParseVisitor extends AbstractVisitor {
if (localVars.isEmpty()) { if (localVars.isEmpty()) {
return; return;
} }
if (Consts.DEBUG && LOG.isDebugEnabled()) { if (Consts.DEBUG) {
LOG.debug("Parsed debug info for {}: ", mth); LOG.debug("Parsed debug info for {}: ", mth);
localVars.forEach(v -> LOG.debug(" {}", v)); localVars.forEach(v -> LOG.debug(" {}", v));
} }
......
...@@ -38,6 +38,7 @@ import jadx.core.dex.visitors.JadxVisitor; ...@@ -38,6 +38,7 @@ import jadx.core.dex.visitors.JadxVisitor;
import jadx.core.dex.visitors.regions.variables.ProcessVariables; import jadx.core.dex.visitors.regions.variables.ProcessVariables;
import jadx.core.utils.BlockUtils; import jadx.core.utils.BlockUtils;
import jadx.core.utils.RegionUtils; import jadx.core.utils.RegionUtils;
import jadx.core.utils.exceptions.JadxOverflowException;
@JadxVisitor( @JadxVisitor(
name = "LoopRegionVisitor", name = "LoopRegionVisitor",
...@@ -112,8 +113,12 @@ public class LoopRegionVisitor extends AbstractVisitor implements IRegionVisitor ...@@ -112,8 +113,12 @@ public class LoopRegionVisitor extends AbstractVisitor implements IRegionVisitor
List<RegisterArg> args = new LinkedList<>(); List<RegisterArg> args = new LinkedList<>();
incrInsn.getRegisterArgs(args); incrInsn.getRegisterArgs(args);
for (RegisterArg iArg : args) { for (RegisterArg iArg : args) {
if (assignOnlyInLoop(mth, loopRegion, iArg)) { try {
return false; if (assignOnlyInLoop(mth, loopRegion, iArg)) {
return false;
}
} catch (StackOverflowError error) {
throw new JadxOverflowException("LoopRegionVisitor.assignOnlyInLoop endless recursion");
} }
} }
......
...@@ -111,7 +111,7 @@ public class EliminatePhiNodes extends AbstractVisitor { ...@@ -111,7 +111,7 @@ public class EliminatePhiNodes extends AbstractVisitor {
// all checks passed // all checks passed
RegisterArg newAssignArg = oldArg.duplicate(newRegNum, null); RegisterArg newAssignArg = oldArg.duplicate(newRegNum, null);
SSAVar newSVar = mth.makeNewSVar(newRegNum, mth.getNextSVarVersion(newRegNum), newAssignArg); SSAVar newSVar = mth.makeNewSVar(newRegNum, newAssignArg);
newSVar.setName(oldSVar.getName()); newSVar.setName(oldSVar.getName());
mth.root().getTypeUpdate().apply(newSVar, assignArg.getType()); mth.root().getTypeUpdate().apply(newSVar, assignArg.getType());
......
package jadx.core.dex.visitors.typeinference;
import java.util.ArrayList;
import java.util.List;
import jadx.core.dex.instructions.args.InsnArg;
import jadx.core.dex.instructions.args.RegisterArg;
import jadx.core.dex.instructions.args.SSAVar;
import jadx.core.dex.nodes.InsnNode;
import jadx.core.utils.Utils;
public abstract class AbstractTypeConstraint implements ITypeConstraint {
protected InsnNode insn;
protected List<SSAVar> relatedVars;
public AbstractTypeConstraint(InsnNode insn, InsnArg arg) {
this.insn = insn;
this.relatedVars = collectRelatedVars(insn, arg);
}
private List<SSAVar> collectRelatedVars(InsnNode insn, InsnArg arg) {
List<SSAVar> list = new ArrayList<>(insn.getArgsCount());
if (insn.getResult() == arg) {
for (InsnArg insnArg : insn.getArguments()) {
if (insnArg.isRegister()) {
list.add(((RegisterArg) insnArg).getSVar());
}
}
} else {
list.add(insn.getResult().getSVar());
for (InsnArg insnArg : insn.getArguments()) {
if (insnArg != arg && insnArg.isRegister()) {
list.add(((RegisterArg) insnArg).getSVar());
}
}
}
return list;
}
@Override
public List<SSAVar> getRelatedVars() {
return relatedVars;
}
@Override
public String toString() {
return "(" + insn.getType() + ":" + Utils.listToString(relatedVars, SSAVar::toShortString) + ")";
}
}
package jadx.core.dex.visitors.typeinference;
import java.util.List;
import jadx.core.dex.instructions.args.SSAVar;
public interface ITypeConstraint {
List<SSAVar> getRelatedVars();
boolean check(TypeSearchState state);
}
...@@ -30,4 +30,12 @@ public enum TypeCompareEnum { ...@@ -30,4 +30,12 @@ public enum TypeCompareEnum {
return this; return this;
} }
} }
public boolean isWider() {
return this == WIDER || this == WIDER_BY_GENERIC;
}
public boolean isNarrow() {
return this == NARROW || this == NARROW_BY_GENERIC;
}
} }
...@@ -3,6 +3,8 @@ package jadx.core.dex.visitors.typeinference; ...@@ -3,6 +3,8 @@ package jadx.core.dex.visitors.typeinference;
import java.util.HashSet; import java.util.HashSet;
import java.util.Set; import java.util.Set;
import org.jetbrains.annotations.NotNull;
import jadx.core.dex.instructions.args.ArgType; import jadx.core.dex.instructions.args.ArgType;
public class TypeInfo { public class TypeInfo {
...@@ -10,6 +12,7 @@ public class TypeInfo { ...@@ -10,6 +12,7 @@ public class TypeInfo {
private final Set<ITypeBound> bounds = new HashSet<>(); private final Set<ITypeBound> bounds = new HashSet<>();
@NotNull
public ArgType getType() { public ArgType getType() {
return type; return type;
} }
......
package jadx.core.dex.visitors.typeinference;
import java.util.ArrayList;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
import org.jetbrains.annotations.NotNull;
import jadx.core.dex.instructions.args.SSAVar;
import jadx.core.dex.nodes.MethodNode;
import jadx.core.utils.exceptions.JadxRuntimeException;
public class TypeSearchState {
private Map<SSAVar, TypeSearchVarInfo> varInfoMap;
public TypeSearchState(MethodNode mth) {
List<SSAVar> vars = mth.getSVars();
this.varInfoMap = new LinkedHashMap<>(vars.size());
for (SSAVar var : vars) {
varInfoMap.put(var, new TypeSearchVarInfo(var));
}
}
@NotNull
public TypeSearchVarInfo getVarInfo(SSAVar var) {
TypeSearchVarInfo varInfo = this.varInfoMap.get(var);
if (varInfo == null) {
throw new JadxRuntimeException("TypeSearchVarInfo not found in map for var: " + var);
}
return varInfo;
}
public List<TypeSearchVarInfo> getAllVars() {
return new ArrayList<>(varInfoMap.values());
}
public List<TypeSearchVarInfo> getUnresolvedVars() {
return varInfoMap.values().stream()
.filter(varInfo -> !varInfo.isTypeResolved())
.collect(Collectors.toList());
}
public List<TypeSearchVarInfo> getResolvedVars() {
return varInfoMap.values().stream()
.filter(TypeSearchVarInfo::isTypeResolved)
.collect(Collectors.toList());
}
}
package jadx.core.dex.visitors.typeinference;
import java.util.List;
import jadx.core.dex.instructions.args.ArgType;
import jadx.core.dex.instructions.args.SSAVar;
public class TypeSearchVarInfo {
private final SSAVar var;
private boolean typeResolved;
private ArgType currentType;
private List<ArgType> candidateTypes;
private int currentIndex = -1;
private List<ITypeConstraint> constraints;
public TypeSearchVarInfo(SSAVar var) {
this.var = var;
}
public void reset() {
if (typeResolved) {
return;
}
currentIndex = 0;
currentType = candidateTypes.get(0);
}
/**
* Switch {@code currentType} to next candidate
*
* @return true - if this is the first candidate
*/
public boolean nextType() {
if (typeResolved) {
return false;
}
int len = candidateTypes.size();
currentIndex = (currentIndex + 1) % len;
currentType = candidateTypes.get(currentIndex);
return currentIndex == 0;
}
public SSAVar getVar() {
return var;
}
public boolean isTypeResolved() {
return typeResolved;
}
public void setTypeResolved(boolean typeResolved) {
this.typeResolved = typeResolved;
}
public ArgType getCurrentType() {
return currentType;
}
public void setCurrentType(ArgType currentType) {
this.currentType = currentType;
}
public List<ArgType> getCandidateTypes() {
return candidateTypes;
}
public void setCandidateTypes(List<ArgType> candidateTypes) {
this.candidateTypes = candidateTypes;
}
public List<ITypeConstraint> getConstraints() {
return constraints;
}
public void setConstraints(List<ITypeConstraint> constraints) {
this.constraints = constraints;
}
@Override
public String toString() {
StringBuilder sb = new StringBuilder();
sb.append("TypeSearchVarInfo{");
sb.append(var.toShortString());
if (typeResolved) {
sb.append(", resolved type: ").append(currentType);
} else {
sb.append(", currentType=").append(currentType);
sb.append(", candidateTypes=").append(candidateTypes);
sb.append(", constraints=").append(constraints);
}
sb.append('}');
return sb.toString();
}
}
...@@ -34,22 +34,16 @@ public final class TypeUpdate { ...@@ -34,22 +34,16 @@ public final class TypeUpdate {
private final Map<InsnType, ITypeListener> listenerRegistry; private final Map<InsnType, ITypeListener> listenerRegistry;
private final TypeCompare comparator; private final TypeCompare comparator;
private ThreadLocal<Boolean> applyDebug = new ThreadLocal<>(); private ThreadLocal<Boolean> allowWider = new ThreadLocal<>();
public TypeUpdate(RootNode root) { public TypeUpdate(RootNode root) {
this.listenerRegistry = initListenerRegistry(); this.listenerRegistry = initListenerRegistry();
this.comparator = new TypeCompare(root); this.comparator = new TypeCompare(root);
} }
public TypeUpdateResult applyDebug(SSAVar ssaVar, ArgType candidateType) { /**
try { * Perform recursive type checking and type propagation for all related variables
applyDebug.set(true); */
return apply(ssaVar, candidateType);
} finally {
applyDebug.set(false);
}
}
public TypeUpdateResult apply(SSAVar ssaVar, ArgType candidateType) { public TypeUpdateResult apply(SSAVar ssaVar, ArgType candidateType) {
if (candidateType == null) { if (candidateType == null) {
return REJECT; return REJECT;
...@@ -71,10 +65,21 @@ public final class TypeUpdate { ...@@ -71,10 +65,21 @@ public final class TypeUpdate {
return CHANGED; return CHANGED;
} }
/**
* Allow wider types for apply from debug info and some special cases
*/
public TypeUpdateResult applyWithWiderAllow(SSAVar ssaVar, ArgType candidateType) {
try {
allowWider.set(true);
return apply(ssaVar, candidateType);
} finally {
allowWider.set(false);
}
}
private TypeUpdateResult updateTypeChecked(TypeUpdateInfo updateInfo, InsnArg arg, ArgType candidateType) { private TypeUpdateResult updateTypeChecked(TypeUpdateInfo updateInfo, InsnArg arg, ArgType candidateType) {
if (candidateType == null) { if (candidateType == null) {
LOG.warn("Reject null type update, arg: {}, info: {}", arg, updateInfo, new RuntimeException()); throw new JadxRuntimeException("Null type update for arg: " + arg);
return REJECT;
} }
ArgType currentType = arg.getType(); ArgType currentType = arg.getType();
if (Objects.equals(currentType, candidateType)) { if (Objects.equals(currentType, candidateType)) {
...@@ -82,15 +87,20 @@ public final class TypeUpdate { ...@@ -82,15 +87,20 @@ public final class TypeUpdate {
} }
TypeCompareEnum compareResult = comparator.compareTypes(candidateType, currentType); TypeCompareEnum compareResult = comparator.compareTypes(candidateType, currentType);
if (compareResult == TypeCompareEnum.CONFLICT) { if (compareResult == TypeCompareEnum.CONFLICT) {
if (Consts.DEBUG) {
LOG.debug("Type rejected for {} due to conflict: candidate={}, current={}", arg, candidateType, currentType);
}
return REJECT; return REJECT;
} }
if (arg.isTypeImmutable() && currentType != ArgType.UNKNOWN) { if (arg.isTypeImmutable() && currentType != ArgType.UNKNOWN) {
// don't changed type, conflict already rejected // don't changed type, conflict already rejected
return SAME; return SAME;
} }
if (compareResult == TypeCompareEnum.WIDER || compareResult == TypeCompareEnum.WIDER_BY_GENERIC) { if (compareResult.isWider()) {
// allow wider types for apply from debug info if (allowWider.get() != Boolean.TRUE) {
if (applyDebug.get() != Boolean.TRUE) { if (Consts.DEBUG) {
LOG.debug("Type rejected for {}: candidate={} is wider than current={}", arg, candidateType, currentType);
}
return REJECT; return REJECT;
} }
} }
...@@ -104,7 +114,7 @@ public final class TypeUpdate { ...@@ -104,7 +114,7 @@ public final class TypeUpdate {
private TypeUpdateResult updateTypeForSsaVar(TypeUpdateInfo updateInfo, SSAVar ssaVar, ArgType candidateType) { private TypeUpdateResult updateTypeForSsaVar(TypeUpdateInfo updateInfo, SSAVar ssaVar, ArgType candidateType) {
TypeInfo typeInfo = ssaVar.getTypeInfo(); TypeInfo typeInfo = ssaVar.getTypeInfo();
if (!inBounds(typeInfo.getBounds(), candidateType)) { if (!inBounds(typeInfo.getBounds(), candidateType)) {
if (Consts.DEBUG && LOG.isDebugEnabled()) { if (Consts.DEBUG) {
LOG.debug("Reject type '{}' for {} by bounds: {}", candidateType, ssaVar, typeInfo.getBounds()); LOG.debug("Reject type '{}' for {} by bounds: {}", candidateType, ssaVar, typeInfo.getBounds());
} }
return REJECT; return REJECT;
...@@ -138,7 +148,11 @@ public final class TypeUpdate { ...@@ -138,7 +148,11 @@ public final class TypeUpdate {
} }
updateInfo.requestUpdate(arg, candidateType); updateInfo.requestUpdate(arg, candidateType);
try { try {
return runListeners(updateInfo, arg, candidateType); TypeUpdateResult result = runListeners(updateInfo, arg, candidateType);
if (result == REJECT) {
updateInfo.rollbackUpdate(arg);
}
return result;
} catch (StackOverflowError overflow) { } catch (StackOverflowError overflow) {
throw new JadxOverflowException("Type update terminated with stack overflow, arg: " + arg); throw new JadxOverflowException("Type update terminated with stack overflow, arg: " + arg);
} }
...@@ -156,7 +170,7 @@ public final class TypeUpdate { ...@@ -156,7 +170,7 @@ public final class TypeUpdate {
return listener.update(updateInfo, insn, arg, candidateType); return listener.update(updateInfo, insn, arg, candidateType);
} }
private boolean inBounds(Set<ITypeBound> bounds, ArgType candidateType) { boolean inBounds(Set<ITypeBound> bounds, ArgType candidateType) {
for (ITypeBound bound : bounds) { for (ITypeBound bound : bounds) {
ArgType boundType = bound.getType(); ArgType boundType = bound.getType();
if (boundType != null && !checkBound(candidateType, bound, boundType)) { if (boundType != null && !checkBound(candidateType, bound, boundType)) {
...@@ -166,6 +180,14 @@ public final class TypeUpdate { ...@@ -166,6 +180,14 @@ public final class TypeUpdate {
return true; return true;
} }
private boolean inBounds(InsnArg arg, ArgType candidateType) {
if (arg.isRegister()) {
TypeInfo typeInfo = ((RegisterArg) arg).getSVar().getTypeInfo();
return inBounds(typeInfo.getBounds(), candidateType);
}
return arg.getType().equals(candidateType);
}
private boolean checkBound(ArgType candidateType, ITypeBound bound, ArgType boundType) { private boolean checkBound(ArgType candidateType, ITypeBound bound, ArgType boundType) {
TypeCompareEnum compareResult = comparator.compareTypes(candidateType, boundType); TypeCompareEnum compareResult = comparator.compareTypes(candidateType, boundType);
switch (compareResult) { switch (compareResult) {
...@@ -222,7 +244,7 @@ public final class TypeUpdate { ...@@ -222,7 +244,7 @@ public final class TypeUpdate {
private Map<InsnType, ITypeListener> initListenerRegistry() { private Map<InsnType, ITypeListener> initListenerRegistry() {
Map<InsnType, ITypeListener> registry = new EnumMap<>(InsnType.class); Map<InsnType, ITypeListener> registry = new EnumMap<>(InsnType.class);
registry.put(InsnType.CONST, this::sameFirstArgListener); registry.put(InsnType.CONST, this::sameFirstArgListener);
registry.put(InsnType.MOVE, this::sameFirstArgListener); registry.put(InsnType.MOVE, this::moveListener);
registry.put(InsnType.PHI, this::allSameListener); registry.put(InsnType.PHI, this::allSameListener);
registry.put(InsnType.MERGE, this::allSameListener); registry.put(InsnType.MERGE, this::allSameListener);
registry.put(InsnType.AGET, this::arrayGetListener); registry.put(InsnType.AGET, this::arrayGetListener);
...@@ -239,6 +261,27 @@ public final class TypeUpdate { ...@@ -239,6 +261,27 @@ public final class TypeUpdate {
return updateTypeChecked(updateInfo, changeArg, candidateType); return updateTypeChecked(updateInfo, changeArg, candidateType);
} }
private TypeUpdateResult moveListener(TypeUpdateInfo updateInfo, InsnNode insn, InsnArg arg, ArgType candidateType) {
boolean assignChanged = isAssign(insn, arg);
InsnArg changeArg = assignChanged ? insn.getArg(0) : insn.getResult();
TypeUpdateResult result = updateTypeChecked(updateInfo, changeArg, candidateType);
if (result == REJECT && changeArg.getType().isTypeKnown()) {
// allow result to be wider
if (assignChanged) {
TypeCompareEnum compareTypes = comparator.compareTypes(candidateType, changeArg.getType());
if (compareTypes.isWider() && inBounds(changeArg, candidateType)) {
return CHANGED;
}
} else {
TypeCompareEnum compareTypes = comparator.compareTypes(changeArg.getType(), candidateType);
if (compareTypes.isWider() && inBounds(changeArg, candidateType)) {
return CHANGED;
}
}
}
return result;
}
/** /**
* All args must have same types * All args must have same types
*/ */
......
...@@ -18,6 +18,10 @@ public class TypeUpdateInfo { ...@@ -18,6 +18,10 @@ public class TypeUpdateInfo {
return updates.containsKey(arg); return updates.containsKey(arg);
} }
public void rollbackUpdate(InsnArg arg) {
updates.remove(arg);
}
public Map<InsnArg, ArgType> getUpdates() { public Map<InsnArg, ArgType> getUpdates() {
return updates; return updates;
} }
......
...@@ -72,7 +72,7 @@ public abstract class BaseExternalTest extends IntegrationTest { ...@@ -72,7 +72,7 @@ public abstract class BaseExternalTest extends IntegrationTest {
int processed = 0; int processed = 0;
for (ClassNode classNode : root.getClasses(true)) { for (ClassNode classNode : root.getClasses(true)) {
String clsFullName = classNode.getClassInfo().getFullName(); String clsFullName = classNode.getClassInfo().getFullName();
if (isMatch(clsFullName, clsPattern)) { if (clsFullName.equals(clsPattern)) {
if (processCls(mthPattern, passes, classNode)) { if (processCls(mthPattern, passes, classNode)) {
processed++; processed++;
} }
......
package jadx.tests.integration.trycatch;
import java.lang.reflect.Constructor;
import java.lang.reflect.InvocationTargetException;
import org.junit.Test;
import jadx.core.dex.nodes.ClassNode;
import jadx.tests.api.IntegrationTest;
import static jadx.tests.api.utils.JadxMatchers.containsOne;
import static org.junit.Assert.assertThat;
public class TestMultiExceptionCatch2 extends IntegrationTest {
public static class TestCls {
public void test(Constructor constructor) {
try {
constructor.newInstance();
} catch (IllegalAccessException | InstantiationException | InvocationTargetException e) {
e.printStackTrace();
}
}
}
@Test
public void test() {
commonChecks();
}
@Test
public void testNoDebug() {
noDebugInfo();
commonChecks();
}
private void commonChecks() {
ClassNode cls = getClassNode(TestCls.class);
String code = cls.getCode().toString();
assertThat(code, containsOne("try {"));
assertThat(code, containsOne("} catch (IllegalAccessException | InstantiationException | InvocationTargetException e) {"));
assertThat(code, containsOne("e.printStackTrace();"));
// TODO: store vararg attribute for methods from classpath
// assertThat(code, containsOne("constructor.newInstance();"));
}
}
package jadx.tests.integration.types;
import org.junit.Test;
import jadx.core.dex.nodes.ClassNode;
import jadx.tests.api.IntegrationTest;
import static jadx.tests.api.utils.JadxMatchers.containsOne;
import static org.junit.Assert.assertThat;
public class TestTypeInheritance extends IntegrationTest {
public static class TestCls {
public interface IRoot {
}
public interface IBase extends IRoot {
}
public static class A implements IBase {
}
public static class B implements IBase {
public void b() {}
}
private static void test(boolean z) {
IBase impl;
if (z) {
impl = new A();
} else {
B b = new B();
b.b();
impl = b; // this move is removed in no-debug byte-code
}
useBase(impl);
useRoot(impl);
}
private static void useRoot(IRoot root) {}
private static void useBase(IBase base) {}
}
@Test
public void test() {
ClassNode cls = getClassNode(TestCls.class);
String code = cls.getCode().toString();
assertThat(code, containsOne("IBase impl;"));
assertThat(code, containsOne("impl = new A();"));
assertThat(code, containsOne("B b = new B();"));
assertThat(code, containsOne("impl = b;"));
}
@Test
public void testNoDebug() {
noDebugInfo();
getClassNode(TestCls.class);
}
}
package jadx.tests.integration.types;
import org.junit.Test;
import jadx.core.dex.nodes.ClassNode;
import jadx.tests.api.IntegrationTest;
import static jadx.tests.api.utils.JadxMatchers.containsOne;
import static org.junit.Assert.assertThat;
public class TestTypeResolver6 extends IntegrationTest {
public static class TestCls {
private final Object obj;
public TestCls(boolean b) {
this.obj = b ? this : makeObj();
}
public Object makeObj() {
return new Object();
}
}
@Test
public void test() {
ClassNode cls = getClassNode(TestCls.class);
String code = cls.getCode().toString();
assertThat(code, containsOne("this.obj = b ? this : makeObj();"));
}
@Test
public void testNoDebug() {
noDebugInfo();
getClassNode(TestCls.class);
}
}
package jadx.tests.integration.types;
import org.junit.Test;
import jadx.core.dex.nodes.ClassNode;
import jadx.tests.api.IntegrationTest;
import static jadx.tests.api.utils.JadxMatchers.containsOne;
import static org.junit.Assert.assertThat;
public class TestTypeResolver6a extends IntegrationTest {
public static class TestCls implements Runnable {
private final Runnable runnable;
public TestCls(boolean b) {
this.runnable = b ? this : makeRunnable();
}
public Runnable makeRunnable() {
return new Runnable() {
@Override
public void run() {
// do nothing
}
};
}
@Override
public void run() {
// do nothing
}
}
@Test
public void test() {
ClassNode cls = getClassNode(TestCls.class);
String code = cls.getCode().toString();
assertThat(code, containsOne("this.runnable = b ? this : makeRunnable();"));
}
@Test
public void testNoDebug() {
noDebugInfo();
getClassNode(TestCls.class);
}
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment