Commit 3de04cb6 authored by Skylot's avatar Skylot

refactor: use flags to mark registers with immutable type

parent 68d074ae
...@@ -258,8 +258,8 @@ public class ClassGen { ...@@ -258,8 +258,8 @@ public class ClassGen {
addMethod(code, mth); addMethod(code, mth);
} catch (Exception e) { } catch (Exception e) {
code.newLine().add("/*"); code.newLine().add("/*");
code.newLine().add(ErrorsCounter.methodError(mth, "Method generation error", e)); code.newLine().addMultiLine(ErrorsCounter.methodError(mth, "Method generation error", e));
code.newLine().add(Utils.getStackTrace(e)); code.newLine().addMultiLine(Utils.getStackTrace(e));
code.newLine().add("*/"); code.newLine().add("*/");
code.setIndent(savedIndent); code.setIndent(savedIndent);
} }
......
...@@ -26,7 +26,16 @@ public enum AFlag { ...@@ -26,7 +26,16 @@ public enum AFlag {
ANONYMOUS_CLASS, ANONYMOUS_CLASS,
THIS, THIS,
METHOD_ARGUMENT, // RegisterArg attribute for method arguments
/**
* RegisterArg attribute for method arguments
*/
METHOD_ARGUMENT,
/**
* Type of RegisterArg or SSAVar can't be changed
*/
IMMUTABLE_TYPE,
CUSTOM_DECLARE, // variable for this register don't need declaration CUSTOM_DECLARE, // variable for this register don't need declaration
DECLARE_VAR, DECLARE_VAR,
......
...@@ -28,7 +28,7 @@ public final class PhiInsn extends InsnNode { ...@@ -28,7 +28,7 @@ public final class PhiInsn extends InsnNode {
} }
public RegisterArg bindArg(BlockNode pred) { public RegisterArg bindArg(BlockNode pred) {
RegisterArg arg = InsnArg.reg(getResult().getRegNum(), getResult().getType()); RegisterArg arg = InsnArg.reg(getResult().getRegNum(), getResult().getInitType());
bindArg(arg, pred); bindArg(arg, pred);
return arg; return arg;
} }
......
...@@ -6,7 +6,7 @@ import java.util.List; ...@@ -6,7 +6,7 @@ import java.util.List;
public class CodeVar { public class CodeVar {
private String name; private String name;
private ArgType type; // nullable before type inference, set only for immutable types private ArgType type; // before type inference can be null and set only for immutable types
private List<SSAVar> ssaVars = new ArrayList<>(3); private List<SSAVar> ssaVars = new ArrayList<>(3);
private boolean isFinal; private boolean isFinal;
......
...@@ -3,6 +3,7 @@ package jadx.core.dex.instructions.args; ...@@ -3,6 +3,7 @@ package jadx.core.dex.instructions.args;
import org.jetbrains.annotations.Nullable; import org.jetbrains.annotations.Nullable;
import jadx.core.dex.info.FieldInfo; import jadx.core.dex.info.FieldInfo;
import jadx.core.utils.exceptions.JadxRuntimeException;
// TODO: don't extend RegisterArg (now used as a result of instruction) // TODO: don't extend RegisterArg (now used as a result of instruction)
public final class FieldArg extends RegisterArg { public final class FieldArg extends RegisterArg {
...@@ -13,7 +14,7 @@ public final class FieldArg extends RegisterArg { ...@@ -13,7 +14,7 @@ public final class FieldArg extends RegisterArg {
private final InsnArg instArg; private final InsnArg instArg;
public FieldArg(FieldInfo field, @Nullable InsnArg reg) { public FieldArg(FieldInfo field, @Nullable InsnArg reg) {
super(-1); super(-1, field.getType());
this.instArg = reg; this.instArg = reg;
this.field = field; this.field = field;
} }
...@@ -41,8 +42,18 @@ public final class FieldArg extends RegisterArg { ...@@ -41,8 +42,18 @@ public final class FieldArg extends RegisterArg {
} }
@Override @Override
public void setType(ArgType type) { public ArgType getType() {
this.type = type; return this.field.getType();
}
@Override
public ArgType getInitType() {
return this.field.getType();
}
@Override
public void setType(ArgType newType) {
throw new JadxRuntimeException("Can't set type for FieldArg");
} }
@Override @Override
......
...@@ -38,16 +38,20 @@ public abstract class InsnArg extends Typed { ...@@ -38,16 +38,20 @@ public abstract class InsnArg extends Typed {
return reg(InsnUtils.getArg(insn, argNum), type); return reg(InsnUtils.getArg(insn, argNum), type);
} }
public static TypeImmutableArg typeImmutableReg(int regNum, ArgType type) { public static RegisterArg typeImmutableReg(DecodedInstruction insn, int argNum, ArgType type) {
return new TypeImmutableArg(regNum, type); return typeImmutableReg(InsnUtils.getArg(insn, argNum), type);
} }
public static TypeImmutableArg typeImmutableReg(DecodedInstruction insn, int argNum, ArgType type) { public static RegisterArg typeImmutableReg(int regNum, ArgType type) {
return typeImmutableReg(InsnUtils.getArg(insn, argNum), type); return reg(regNum, type, true);
} }
public static RegisterArg reg(int regNum, ArgType type, boolean typeImmutable) { public static RegisterArg reg(int regNum, ArgType type, boolean typeImmutable) {
return typeImmutable ? new TypeImmutableArg(regNum, type) : new RegisterArg(regNum, type); RegisterArg reg = new RegisterArg(regNum, type);
if (typeImmutable) {
reg.add(AFlag.IMMUTABLE_TYPE);
}
return reg;
} }
public static LiteralArg lit(long literal, ArgType type) { public static LiteralArg lit(long literal, ArgType type) {
......
...@@ -4,24 +4,25 @@ import java.util.Objects; ...@@ -4,24 +4,25 @@ import java.util.Objects;
import org.jetbrains.annotations.NotNull; import org.jetbrains.annotations.NotNull;
import org.jetbrains.annotations.Nullable; import org.jetbrains.annotations.Nullable;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import jadx.core.dex.attributes.AFlag;
import jadx.core.dex.nodes.DexNode; import jadx.core.dex.nodes.DexNode;
import jadx.core.dex.nodes.InsnNode; import jadx.core.dex.nodes.InsnNode;
import jadx.core.utils.InsnUtils; import jadx.core.utils.InsnUtils;
import jadx.core.utils.exceptions.JadxRuntimeException;
public class RegisterArg extends InsnArg implements Named { public class RegisterArg extends InsnArg implements Named {
private static final Logger LOG = LoggerFactory.getLogger(RegisterArg.class);
public static final String THIS_ARG_NAME = "this"; public static final String THIS_ARG_NAME = "this";
protected final int regNum; protected final int regNum;
// not null after SSATransform pass // not null after SSATransform pass
private SSAVar sVar; private SSAVar sVar;
public RegisterArg(int rn) {
this.regNum = rn;
}
public RegisterArg(int rn, ArgType type) { public RegisterArg(int rn, ArgType type) {
this.type = type; this.type = type; // initial type, not changing, can be unknown
this.regNum = rn; this.regNum = rn;
} }
...@@ -35,19 +36,28 @@ public class RegisterArg extends InsnArg implements Named { ...@@ -35,19 +36,28 @@ public class RegisterArg extends InsnArg implements Named {
} }
@Override @Override
public void setType(ArgType type) { public void setType(ArgType newType) {
if (sVar != null) { if (sVar == null) {
sVar.setType(type); throw new JadxRuntimeException("Can't change type for register without SSA variable: " + this);
}
if (contains(AFlag.IMMUTABLE_TYPE)) {
if (!type.isTypeKnown()) {
throw new JadxRuntimeException("Unknown immutable type '" + type + "' in " + this);
}
if (!type.equals(newType)) {
LOG.warn("JADX WARNING: Can't change immutable type from '{}' to '{}' for {}", type, newType, this);
return;
} }
} }
sVar.setType(newType);
}
@Override @Override
public ArgType getType() { public ArgType getType() {
SSAVar ssaVar = this.sVar; if (sVar != null) {
if (ssaVar != null) { return sVar.getTypeInfo().getType();
return ssaVar.getTypeInfo().getType();
} }
return ArgType.UNKNOWN; throw new JadxRuntimeException("Register type unknown, SSA variable not initialized: r" + regNum);
} }
public ArgType getInitType() { public ArgType getInitType() {
...@@ -56,14 +66,7 @@ public class RegisterArg extends InsnArg implements Named { ...@@ -56,14 +66,7 @@ public class RegisterArg extends InsnArg implements Named {
@Override @Override
public boolean isTypeImmutable() { public boolean isTypeImmutable() {
if (sVar != null) { return contains(AFlag.IMMUTABLE_TYPE) || (sVar != null && sVar.contains(AFlag.IMMUTABLE_TYPE));
RegisterArg assign = sVar.getAssign();
if (assign == this) {
return false;
}
return assign.isTypeImmutable();
}
return false;
} }
public SSAVar getSVar() { public SSAVar getSVar() {
...@@ -72,6 +75,9 @@ public class RegisterArg extends InsnArg implements Named { ...@@ -72,6 +75,9 @@ public class RegisterArg extends InsnArg implements Named {
void setSVar(@NotNull SSAVar sVar) { void setSVar(@NotNull SSAVar sVar) {
this.sVar = sVar; this.sVar = sVar;
if (contains(AFlag.IMMUTABLE_TYPE)) {
sVar.add(AFlag.IMMUTABLE_TYPE);
}
} }
public String getName() { public String getName() {
...@@ -98,21 +104,6 @@ public class RegisterArg extends InsnArg implements Named { ...@@ -98,21 +104,6 @@ public class RegisterArg extends InsnArg implements Named {
return n.equals(((Named) arg).getName()); return n.equals(((Named) arg).getName());
} }
public void mergeName(InsnArg arg) {
if (arg instanceof Named) {
Named otherArg = (Named) arg;
String otherName = otherArg.getName();
String name = getName();
if (!Objects.equals(name, otherName)) {
if (name == null) {
setName(otherName);
} else if (otherName == null) {
otherArg.setName(name);
}
}
}
}
@Override @Override
public RegisterArg duplicate() { public RegisterArg duplicate() {
return duplicate(getRegNum(), sVar); return duplicate(getRegNum(), sVar);
...@@ -146,8 +137,8 @@ public class RegisterArg extends InsnArg implements Named { ...@@ -146,8 +137,8 @@ public class RegisterArg extends InsnArg implements Named {
return sVar.getAssign().getParentInsn(); return sVar.getAssign().getParentInsn();
} }
public boolean equalRegister(RegisterArg arg) { public boolean equalRegisterAndType(RegisterArg arg) {
return regNum == arg.regNum; return regNum == arg.regNum && type.equals(arg.type);
} }
public boolean sameRegAndSVar(InsnArg arg) { public boolean sameRegAndSVar(InsnArg arg) {
...@@ -159,10 +150,6 @@ public class RegisterArg extends InsnArg implements Named { ...@@ -159,10 +150,6 @@ public class RegisterArg extends InsnArg implements Named {
&& Objects.equals(sVar, reg.getSVar()); && Objects.equals(sVar, reg.getSVar());
} }
public boolean equalRegisterAndType(RegisterArg arg) {
return regNum == arg.regNum && type.equals(arg.type);
}
public boolean sameCodeVar(RegisterArg arg) { public boolean sameCodeVar(RegisterArg arg) {
return this.getSVar().getCodeVar() == arg.getSVar().getCodeVar(); return this.getSVar().getCodeVar() == arg.getSVar().getCodeVar();
} }
...@@ -182,7 +169,6 @@ public class RegisterArg extends InsnArg implements Named { ...@@ -182,7 +169,6 @@ public class RegisterArg extends InsnArg implements Named {
} }
RegisterArg other = (RegisterArg) obj; RegisterArg other = (RegisterArg) obj;
return regNum == other.regNum return regNum == other.regNum
&& Objects.equals(getType(), other.getType())
&& Objects.equals(sVar, other.getSVar()); && Objects.equals(sVar, other.getSVar());
} }
...@@ -197,16 +183,18 @@ public class RegisterArg extends InsnArg implements Named { ...@@ -197,16 +183,18 @@ public class RegisterArg extends InsnArg implements Named {
if (getName() != null) { if (getName() != null) {
sb.append(" '").append(getName()).append('\''); sb.append(" '").append(getName()).append('\'');
} }
ArgType type = getType(); ArgType type = sVar != null ? getType() : null;
if (type != null) {
sb.append(' ').append(type); sb.append(' ').append(type);
}
ArgType initType = getInitType(); ArgType initType = getInitType();
if (!type.equals(initType) && !type.isTypeKnown()) { if (type == null || (!type.equals(initType) && !type.isTypeKnown())) {
sb.append(" I:").append(initType); sb.append(" I:").append(initType);
} }
if (!isAttrStorageEmpty()) { if (!isAttrStorageEmpty()) {
sb.append(' ').append(getAttributesString()); sb.append(' ').append(getAttributesString());
} }
sb.append(")"); sb.append(')');
return sb.toString(); return sb.toString();
} }
} }
...@@ -29,7 +29,7 @@ public class SSAVar extends AttrNode { ...@@ -29,7 +29,7 @@ public class SSAVar extends AttrNode {
private TypeInfo typeInfo = new TypeInfo(); private TypeInfo typeInfo = new TypeInfo();
@Nullable("Set in EliminatePhiNodes pass") @Nullable("Set in InitCodeVariables pass")
private CodeVar codeVar; private CodeVar codeVar;
public SSAVar(int regNum, int v, @NotNull RegisterArg assign) { public SSAVar(int regNum, int v, @NotNull RegisterArg assign) {
...@@ -65,7 +65,8 @@ public class SSAVar extends AttrNode { ...@@ -65,7 +65,8 @@ public class SSAVar extends AttrNode {
return useList.size(); return useList.size();
} }
public void setType(ArgType type) { // must be used only from RegisterArg#setType()
void setType(ArgType type) {
typeInfo.setType(type); typeInfo.setType(type);
if (codeVar != null) { if (codeVar != null) {
codeVar.setType(type); codeVar.setType(type);
...@@ -139,9 +140,6 @@ public class SSAVar extends AttrNode { ...@@ -139,9 +140,6 @@ public class SSAVar extends AttrNode {
public void setCodeVar(@NotNull CodeVar codeVar) { public void setCodeVar(@NotNull CodeVar codeVar) {
this.codeVar = codeVar; this.codeVar = codeVar;
if (codeVar.getType() != null && !typeInfo.getType().equals(codeVar.getType())) {
throw new JadxRuntimeException("Unmached types for SSA and Code variables: " + this + " and " + codeVar);
}
codeVar.addSsaVar(this); codeVar.addSsaVar(this);
} }
......
package jadx.core.dex.instructions.args;
import java.util.Objects;
import jadx.core.utils.exceptions.JadxRuntimeException;
public class TypeImmutableArg extends RegisterArg {
public TypeImmutableArg(int rn, ArgType type) {
super(rn, type);
}
@Override
public boolean isTypeImmutable() {
return true;
}
@Override
public void setType(ArgType type) {
// allow set only initial type
if (Objects.equals(this.type, type)) {
super.setType(type);
} else {
throw new JadxRuntimeException("Can't change arg with immutable type");
}
}
@Override
public RegisterArg duplicate() {
return duplicate(getRegNum(), getSVar());
}
@Override
public RegisterArg duplicate(int regNum, SSAVar sVar) {
RegisterArg dup = new TypeImmutableArg(regNum, getInitType());
if (sVar != null) {
dup.setSVar(sVar);
}
dup.copyAttributesFrom(this);
return dup;
}
}
...@@ -33,7 +33,6 @@ import jadx.core.dex.instructions.args.ArgType; ...@@ -33,7 +33,6 @@ import jadx.core.dex.instructions.args.ArgType;
import jadx.core.dex.instructions.args.InsnArg; import jadx.core.dex.instructions.args.InsnArg;
import jadx.core.dex.instructions.args.RegisterArg; import jadx.core.dex.instructions.args.RegisterArg;
import jadx.core.dex.instructions.args.SSAVar; import jadx.core.dex.instructions.args.SSAVar;
import jadx.core.dex.instructions.args.TypeImmutableArg;
import jadx.core.dex.nodes.parser.SignatureParser; import jadx.core.dex.nodes.parser.SignatureParser;
import jadx.core.dex.regions.Region; import jadx.core.dex.regions.Region;
import jadx.core.dex.trycatch.ExcHandlerAttr; import jadx.core.dex.trycatch.ExcHandlerAttr;
...@@ -220,8 +219,9 @@ public class MethodNode extends LineAttrNode implements ILoadable, ICodeNode { ...@@ -220,8 +219,9 @@ public class MethodNode extends LineAttrNode implements ILoadable, ICodeNode {
if (accFlags.isStatic()) { if (accFlags.isStatic()) {
thisArg = null; thisArg = null;
} else { } else {
TypeImmutableArg arg = InsnArg.typeImmutableReg(pos - 1, parentClass.getClassInfo().getType()); RegisterArg arg = InsnArg.reg(pos - 1, parentClass.getClassInfo().getType());
arg.add(AFlag.THIS); arg.add(AFlag.THIS);
arg.add(AFlag.IMMUTABLE_TYPE);
thisArg = arg; thisArg = arg;
} }
if (args.isEmpty()) { if (args.isEmpty()) {
...@@ -230,8 +230,9 @@ public class MethodNode extends LineAttrNode implements ILoadable, ICodeNode { ...@@ -230,8 +230,9 @@ public class MethodNode extends LineAttrNode implements ILoadable, ICodeNode {
} }
argsList = new ArrayList<>(args.size()); argsList = new ArrayList<>(args.size());
for (ArgType arg : args) { for (ArgType arg : args) {
TypeImmutableArg regArg = InsnArg.typeImmutableReg(pos, arg); RegisterArg regArg = InsnArg.reg(pos, arg);
regArg.add(AFlag.METHOD_ARGUMENT); regArg.add(AFlag.METHOD_ARGUMENT);
regArg.add(AFlag.IMMUTABLE_TYPE);
argsList.add(regArg); argsList.add(regArg);
pos += arg.getRegCount(); pos += arg.getRegCount();
} }
......
...@@ -78,8 +78,9 @@ public class InitCodeVariables extends AbstractVisitor { ...@@ -78,8 +78,9 @@ public class InitCodeVariables extends AbstractVisitor {
private static void setCodeVarType(CodeVar codeVar, Set<SSAVar> vars) { private static void setCodeVarType(CodeVar codeVar, Set<SSAVar> vars) {
if (vars.size() > 1) { if (vars.size() > 1) {
List<ArgType> imTypes = vars.stream() List<ArgType> imTypes = vars.stream()
.filter(var -> var.contains(AFlag.METHOD_ARGUMENT)) .filter(var -> var.contains(AFlag.IMMUTABLE_TYPE))
.map(var -> var.getTypeInfo().getType()) .map(var -> var.getTypeInfo().getType())
.filter(ArgType::isTypeKnown)
.distinct() .distinct()
.collect(Collectors.toList()); .collect(Collectors.toList());
int imCount = imTypes.size(); int imCount = imTypes.size();
......
...@@ -299,9 +299,6 @@ public class SimplifyVisitor extends AbstractVisitor { ...@@ -299,9 +299,6 @@ public class SimplifyVisitor extends AbstractVisitor {
} }
} }
FieldArg fArg = new FieldArg(field, reg); FieldArg fArg = new FieldArg(field, reg);
if (reg != null) {
fArg.setType(get.getArg(0).getType());
}
if (wrapType == InsnType.ARITH) { if (wrapType == InsnType.ARITH) {
ArithNode ar = (ArithNode) wrap; ArithNode ar = (ArithNode) wrap;
return new ArithNode(ar.getOp(), fArg, ar.getArg(1)); return new ArithNode(ar.getOp(), fArg, ar.getArg(1));
......
...@@ -8,6 +8,7 @@ import java.util.Iterator; ...@@ -8,6 +8,7 @@ import java.util.Iterator;
import java.util.LinkedList; import java.util.LinkedList;
import java.util.List; import java.util.List;
import org.jetbrains.annotations.Nullable;
import org.slf4j.Logger; import org.slf4j.Logger;
import org.slf4j.LoggerFactory; import org.slf4j.LoggerFactory;
...@@ -16,6 +17,7 @@ import jadx.core.dex.attributes.AType; ...@@ -16,6 +17,7 @@ import jadx.core.dex.attributes.AType;
import jadx.core.dex.attributes.nodes.LoopInfo; import jadx.core.dex.attributes.nodes.LoopInfo;
import jadx.core.dex.instructions.InsnType; import jadx.core.dex.instructions.InsnType;
import jadx.core.dex.instructions.args.InsnArg; import jadx.core.dex.instructions.args.InsnArg;
import jadx.core.dex.instructions.args.LiteralArg;
import jadx.core.dex.instructions.args.RegisterArg; import jadx.core.dex.instructions.args.RegisterArg;
import jadx.core.dex.nodes.BlockNode; import jadx.core.dex.nodes.BlockNode;
import jadx.core.dex.nodes.Edge; import jadx.core.dex.nodes.Edge;
...@@ -180,7 +182,43 @@ public class BlockProcessor extends AbstractVisitor { ...@@ -180,7 +182,43 @@ public class BlockProcessor extends AbstractVisitor {
} }
private static boolean isSame(InsnNode insn, InsnNode curInsn) { private static boolean isSame(InsnNode insn, InsnNode curInsn) {
return /*insn.getType() == InsnType.MOVE &&*/ insn.isDeepEquals(curInsn) && insn.canReorder(); return isInsnsEquals(insn, curInsn) && insn.canReorder();
}
private static boolean isInsnsEquals(InsnNode insn, InsnNode otherInsn) {
if (insn == otherInsn) {
return true;
}
if (insn.isSame(otherInsn)
&& sameArgs(insn.getResult(), otherInsn.getResult())) {
int argsCount = insn.getArgsCount();
for (int i = 0; i < argsCount; i++) {
if (!sameArgs(insn.getArg(i), otherInsn.getArg(i))) {
return false;
}
}
return true;
}
return false;
}
private static boolean sameArgs(@Nullable InsnArg arg, @Nullable InsnArg otherArg) {
if (arg == otherArg) {
return true;
}
if (arg == null || otherArg == null) {
return false;
}
if (arg.getClass().equals(otherArg.getClass())) {
if (arg.isRegister()) {
return ((RegisterArg) arg).getRegNum() == ((RegisterArg) otherArg).getRegNum();
}
if (arg.isLiteral()) {
return ((LiteralArg) arg).getLiteral() == ((LiteralArg) otherArg).getLiteral();
}
throw new JadxRuntimeException("Unexpected InsnArg types: " + arg + " and " + otherArg);
}
return false;
} }
private static InsnNode getInsnsFromEnd(BlockNode block, int number) { private static InsnNode getInsnsFromEnd(BlockNode block, int number) {
......
...@@ -15,6 +15,7 @@ import org.slf4j.LoggerFactory; ...@@ -15,6 +15,7 @@ import org.slf4j.LoggerFactory;
import jadx.core.dex.attributes.AFlag; import jadx.core.dex.attributes.AFlag;
import jadx.core.dex.attributes.AType; import jadx.core.dex.attributes.AType;
import jadx.core.dex.attributes.nodes.DeclareVariablesAttr; import jadx.core.dex.attributes.nodes.DeclareVariablesAttr;
import jadx.core.dex.instructions.args.ArgType;
import jadx.core.dex.instructions.args.CodeVar; import jadx.core.dex.instructions.args.CodeVar;
import jadx.core.dex.instructions.args.RegisterArg; import jadx.core.dex.instructions.args.RegisterArg;
import jadx.core.dex.instructions.args.SSAVar; import jadx.core.dex.instructions.args.SSAVar;
...@@ -42,6 +43,7 @@ public class ProcessVariables extends AbstractVisitor { ...@@ -42,6 +43,7 @@ public class ProcessVariables extends AbstractVisitor {
if (codeVars.isEmpty()) { if (codeVars.isEmpty()) {
return; return;
} }
checkCodeVars(mth, codeVars);
// TODO: reduce code vars by name if debug info applied. Need checks for variable scopes before reduce // TODO: reduce code vars by name if debug info applied. Need checks for variable scopes before reduce
// collect all variables usage // collect all variables usage
...@@ -59,6 +61,29 @@ public class ProcessVariables extends AbstractVisitor { ...@@ -59,6 +61,29 @@ public class ProcessVariables extends AbstractVisitor {
} }
} }
private void checkCodeVars(MethodNode mth, List<CodeVar> codeVars) {
int unknownTypesCount = 0;
for (CodeVar codeVar : codeVars) {
codeVar.getSsaVars().stream()
.filter(ssaVar -> ssaVar.contains(AFlag.IMMUTABLE_TYPE))
.forEach(ssaVar -> {
ArgType ssaType = ssaVar.getAssign().getInitType();
if (ssaType.isTypeKnown() && !ssaType.equals(codeVar.getType())) {
mth.addWarn("Incorrect type for immutable var: ssa=" + ssaType
+ ", code=" + codeVar.getType()
+ ", for " + ssaVar.getDetailedVarInfo(mth));
}
});
if (codeVar.getType() == null) {
codeVar.setType(ArgType.UNKNOWN);
unknownTypesCount++;
}
}
if (unknownTypesCount != 0) {
mth.addWarn("Unknown variable types count: " + unknownTypesCount);
}
}
private void declareVar(MethodNode mth, CodeVar codeVar, List<VarUsage> usageList) { private void declareVar(MethodNode mth, CodeVar codeVar, List<VarUsage> usageList) {
if (codeVar.isDeclared()) { if (codeVar.isDeclared()) {
return; return;
......
...@@ -2,6 +2,7 @@ package jadx.core.dex.visitors.ssa; ...@@ -2,6 +2,7 @@ package jadx.core.dex.visitors.ssa;
import java.util.Arrays; import java.util.Arrays;
import jadx.core.dex.attributes.AFlag;
import jadx.core.dex.instructions.args.RegisterArg; import jadx.core.dex.instructions.args.RegisterArg;
import jadx.core.dex.instructions.args.SSAVar; import jadx.core.dex.instructions.args.SSAVar;
import jadx.core.dex.nodes.BlockNode; import jadx.core.dex.nodes.BlockNode;
...@@ -22,7 +23,8 @@ final class RenameState { ...@@ -22,7 +23,8 @@ final class RenameState {
new int[regsCount] new int[regsCount]
); );
for (RegisterArg arg : mth.getArguments(true)) { for (RegisterArg arg : mth.getArguments(true)) {
state.startVar(arg); SSAVar ssaVar = state.startVar(arg);
ssaVar.add(AFlag.METHOD_ARGUMENT);
} }
return state; return state;
} }
...@@ -51,9 +53,11 @@ final class RenameState { ...@@ -51,9 +53,11 @@ final class RenameState {
return vars[regNum]; return vars[regNum];
} }
public void startVar(RegisterArg regArg) { public SSAVar startVar(RegisterArg regArg) {
int regNum = regArg.getRegNum(); int regNum = regArg.getRegNum();
int version = versions[regNum]++; int version = versions[regNum]++;
vars[regNum] = mth.makeNewSVar(regNum, version, regArg); SSAVar ssaVar = mth.makeNewSVar(regNum, version, regArg);
vars[regNum] = ssaVar;
return ssaVar;
} }
} }
...@@ -64,7 +64,9 @@ public final class TypeInferenceVisitor extends AbstractVisitor { ...@@ -64,7 +64,9 @@ public final class TypeInferenceVisitor extends AbstractVisitor {
// collect initial type bounds from assign and usages // collect initial type bounds from assign and usages
mth.getSVars().forEach(this::attachBounds); mth.getSVars().forEach(this::attachBounds);
mth.getSVars().forEach(this::mergePhiBounds); mth.getSVars().forEach(this::mergePhiBounds);
// start initial type propagation, check types from bounds
// start initial type propagation
mth.getSVars().forEach(this::setImmutableType);
mth.getSVars().forEach(this::setBestType); mth.getSVars().forEach(this::setBestType);
// try other types if type is still unknown // try other types if type is still unknown
...@@ -100,7 +102,7 @@ public final class TypeInferenceVisitor extends AbstractVisitor { ...@@ -100,7 +102,7 @@ public final class TypeInferenceVisitor extends AbstractVisitor {
+ ", time: " + time + " ms"); + ", time: " + time + " ms");
} }
private boolean setBestType(SSAVar ssaVar) { private boolean setImmutableType(SSAVar ssaVar) {
try { try {
ArgType codeVarType = ssaVar.getCodeVar().getType(); ArgType codeVarType = ssaVar.getCodeVar().getType();
if (codeVarType != null) { if (codeVarType != null) {
...@@ -110,9 +112,25 @@ public final class TypeInferenceVisitor extends AbstractVisitor { ...@@ -110,9 +112,25 @@ public final class TypeInferenceVisitor extends AbstractVisitor {
if (assignArg.isTypeImmutable()) { if (assignArg.isTypeImmutable()) {
return applyImmutableType(ssaVar, assignArg.getInitType()); return applyImmutableType(ssaVar, assignArg.getInitType());
} }
if (ssaVar.contains(AFlag.IMMUTABLE_TYPE)) {
for (RegisterArg arg : ssaVar.getUseList()) {
if (arg.isTypeImmutable()) {
return applyImmutableType(ssaVar, arg.getInitType());
}
}
}
return false;
} catch (Exception e) {
LOG.error("Failed to set immutable type for var: {}", ssaVar, e);
return false;
}
}
private boolean setBestType(SSAVar ssaVar) {
try {
return calculateFromBounds(ssaVar); return calculateFromBounds(ssaVar);
} catch (Exception e) { } catch (Exception e) {
LOG.error("Failed to calculate best type for var: {}", ssaVar); LOG.error("Failed to calculate best type for var: {}", ssaVar, e);
return false; return false;
} }
} }
......
...@@ -3,7 +3,7 @@ package jadx.core.dex.visitors.typeinference; ...@@ -3,7 +3,7 @@ package jadx.core.dex.visitors.typeinference;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.Collection; import java.util.Collection;
import java.util.Collections; import java.util.Collections;
import java.util.HashSet; import java.util.LinkedHashSet;
import java.util.List; import java.util.List;
import java.util.Set; import java.util.Set;
import java.util.stream.Collectors; import java.util.stream.Collectors;
...@@ -75,7 +75,12 @@ public class TypeSearch { ...@@ -75,7 +75,12 @@ public class TypeSearch {
private boolean applyResolvedVars() { private boolean applyResolvedVars() {
List<TypeSearchVarInfo> resolvedVars = state.getResolvedVars(); List<TypeSearchVarInfo> resolvedVars = state.getResolvedVars();
for (TypeSearchVarInfo var : resolvedVars) { for (TypeSearchVarInfo var : resolvedVars) {
var.getVar().setType(var.getCurrentType()); SSAVar ssaVar = var.getVar();
ArgType resolvedType = var.getCurrentType();
ssaVar.getAssign().setType(resolvedType);
for (RegisterArg arg : ssaVar.getUseList()) {
arg.setType(resolvedType);
}
} }
boolean applySuccess = true; boolean applySuccess = true;
for (TypeSearchVarInfo var : resolvedVars) { for (TypeSearchVarInfo var : resolvedVars) {
...@@ -199,8 +204,8 @@ public class TypeSearch { ...@@ -199,8 +204,8 @@ public class TypeSearch {
return; return;
} }
Set<ArgType> assigns = new HashSet<>(); Set<ArgType> assigns = new LinkedHashSet<>();
Set<ArgType> uses = new HashSet<>(); Set<ArgType> uses = new LinkedHashSet<>();
Set<ITypeBound> bounds = ssaVar.getTypeInfo().getBounds(); Set<ITypeBound> bounds = ssaVar.getTypeInfo().getBounds();
for (ITypeBound bound : bounds) { for (ITypeBound bound : bounds) {
if (bound.getBound() == BoundEnum.ASSIGN) { if (bound.getBound() == BoundEnum.ASSIGN) {
...@@ -210,7 +215,7 @@ public class TypeSearch { ...@@ -210,7 +215,7 @@ public class TypeSearch {
} }
} }
Set<ArgType> candidateTypes = new HashSet<>(); Set<ArgType> candidateTypes = new LinkedHashSet<>();
addCandidateTypes(bounds, candidateTypes, assigns); addCandidateTypes(bounds, candidateTypes, assigns);
addCandidateTypes(bounds, candidateTypes, uses); addCandidateTypes(bounds, candidateTypes, uses);
......
...@@ -60,6 +60,10 @@ public final class TypeUpdate { ...@@ -60,6 +60,10 @@ public final class TypeUpdate {
if (updates.isEmpty()) { if (updates.isEmpty()) {
return SAME; return SAME;
} }
if (Consts.DEBUG && LOG.isDebugEnabled()) {
LOG.debug("Applying types, init for {} -> {}", ssaVar, candidateType);
updates.forEach(updateEntry -> LOG.debug(" {} -> {}", updateEntry.getType(), updateEntry.getArg()));
}
updates.forEach(TypeUpdateEntry::apply); updates.forEach(TypeUpdateEntry::apply);
return CHANGED; return CHANGED;
} }
......
package jadx.tests.external; package jadx.tests.external;
import java.io.File; import java.io.File;
import java.util.HashMap;
import java.util.List; import java.util.List;
import java.util.Map;
import org.apache.commons.lang3.StringUtils; import org.apache.commons.lang3.StringUtils;
import org.jetbrains.annotations.NotNull; import org.jetbrains.annotations.NotNull;
...@@ -107,7 +109,8 @@ public abstract class BaseExternalTest extends IntegrationTest { ...@@ -107,7 +109,8 @@ public abstract class BaseExternalTest extends IntegrationTest {
} catch (Exception e) { } catch (Exception e) {
throw new JadxRuntimeException("Codegen failed", e); throw new JadxRuntimeException("Codegen failed", e);
} }
LOG.warn("\n Print class: {}, {}", classNode.getFullName(), classNode.dex()); LOG.info("----------------------------------------------------------------");
LOG.info("Print class: {}, {}", classNode.getFullName(), classNode.dex());
if (mthPattern != null) { if (mthPattern != null) {
printMethods(classNode, mthPattern); printMethods(classNode, mthPattern);
} else { } else {
...@@ -134,6 +137,9 @@ public abstract class BaseExternalTest extends IntegrationTest { ...@@ -134,6 +137,9 @@ public abstract class BaseExternalTest extends IntegrationTest {
if (code == null) { if (code == null) {
return; return;
} }
String dashLine = "======================================================================================";
Map<Integer, MethodNode> methodsMap = getMethodsMap(classNode);
String[] lines = code.split(CodeWriter.NL); String[] lines = code.split(CodeWriter.NL);
for (MethodNode mth : classNode.getMethods()) { for (MethodNode mth : classNode.getMethods()) {
if (isMthMatch(mth, mthPattern)) { if (isMthMatch(mth, mthPattern)) {
...@@ -142,8 +148,14 @@ public abstract class BaseExternalTest extends IntegrationTest { ...@@ -142,8 +148,14 @@ public abstract class BaseExternalTest extends IntegrationTest {
int startLine = getCommentLinesCount(lines, decompiledLine); int startLine = getCommentLinesCount(lines, decompiledLine);
int brackets = 0; int brackets = 0;
for (int i = startLine; i > 0 && i < lines.length; i++) { for (int i = startLine; i > 0 && i < lines.length; i++) {
// stop if next method started
MethodNode mthAtLine = methodsMap.get(i);
if (mthAtLine != null && !mthAtLine.equals(mth)) {
break;
}
String line = lines[i]; String line = lines[i];
mthCode.append(line).append(CodeWriter.NL); mthCode.append(line).append(CodeWriter.NL);
// also count brackets for detect method end
if (i >= decompiledLine) { if (i >= decompiledLine) {
brackets += StringUtils.countMatches(line, '{'); brackets += StringUtils.countMatches(line, '{');
brackets -= StringUtils.countMatches(line, '}'); brackets -= StringUtils.countMatches(line, '}');
...@@ -152,11 +164,23 @@ public abstract class BaseExternalTest extends IntegrationTest { ...@@ -152,11 +164,23 @@ public abstract class BaseExternalTest extends IntegrationTest {
} }
} }
} }
LOG.info("Print method: {}\n{}", mth.getMethodInfo().getShortId(), mthCode); LOG.info("Print method: {}\n{}\n{}\n{}", mth.getMethodInfo().getShortId(),
dashLine,
mthCode,
dashLine
);
} }
} }
} }
public Map<Integer, MethodNode> getMethodsMap(ClassNode classNode) {
Map<Integer, MethodNode> linesMap = new HashMap<>();
for (MethodNode method : classNode.getMethods()) {
linesMap.put(method.getDecompiledLine() - 1, method);
}
return linesMap;
}
protected int getCommentLinesCount(String[] lines, int line) { protected int getCommentLinesCount(String[] lines, int line) {
for (int i = line - 1; i > 0 && i < lines.length; i--) { for (int i = line - 1; i > 0 && i < lines.length; i--) {
String str = lines[i]; String str = lines[i];
......
package jadx.tests.integration.conditions;
import org.junit.Test;
import jadx.core.dex.nodes.ClassNode;
import jadx.tests.api.IntegrationTest;
import static jadx.tests.api.utils.JadxMatchers.containsOne;
import static org.hamcrest.Matchers.is;
import static org.junit.Assert.assertThat;
public class TestConditionInLoop extends IntegrationTest {
public static class TestCls {
private static int test(int a, int b) {
int c = a + b;
for (int i = a; i < b; i++) {
if (i == 7) {
c += 2;
} else {
c *= 2;
}
}
c--;
return c;
}
public void check() {
assertThat(test(5, 9), is(115));
assertThat(test(8, 23), is(1015807));
}
}
@Test
public void test() {
ClassNode cls = getClassNode(TestCls.class);
String code = cls.getCode().toString();
assertThat(code, containsOne("for (int i = a; i < b; i++) {"));
assertThat(code, containsOne("c += 2;"));
assertThat(code, containsOne("c *= 2;"));
}
@Test
public void testNoDebug() {
noDebugInfo();
ClassNode cls = getClassNode(TestCls.class);
String code = cls.getCode().toString();
assertThat(code, containsOne("while"));
}
}
...@@ -56,13 +56,13 @@ public class TestTryCatchFinally6 extends IntegrationTest { ...@@ -56,13 +56,13 @@ public class TestTryCatchFinally6 extends IntegrationTest {
String code = cls.getCode().toString(); String code = cls.getCode().toString();
assertThat(code, containsLines(2, assertThat(code, containsLines(2,
"InputStream inputStream = null;", "FileInputStream fileInputStream = null;",
"try {", "try {",
indent() + "call();", indent() + "call();",
indent() + "inputStream = new FileInputStream(\"1.txt\");", indent() + "fileInputStream = new FileInputStream(\"1.txt\");",
"} finally {", "} finally {",
indent() + "if (inputStream != null) {", indent() + "if (fileInputStream != null) {",
indent() + indent() + "inputStream.close();", indent() + indent() + "fileInputStream.close();",
indent() + "}", indent() + "}",
"}" "}"
)); ));
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment