Commit b689efcc authored by Skylot's avatar Skylot

fix: forbid to change types for methods arguments

parent 89563b62
...@@ -6,7 +6,7 @@ import java.util.List; ...@@ -6,7 +6,7 @@ import java.util.List;
public class CodeVar { public class CodeVar {
private String name; private String name;
private ArgType type; private ArgType type; // nullable before type inference, set only for immutable types
private List<SSAVar> ssaVars = new ArrayList<>(3); private List<SSAVar> ssaVars = new ArrayList<>(3);
private boolean isFinal; private boolean isFinal;
......
...@@ -139,6 +139,9 @@ public class SSAVar extends AttrNode { ...@@ -139,6 +139,9 @@ public class SSAVar extends AttrNode {
public void setCodeVar(@NotNull CodeVar codeVar) { public void setCodeVar(@NotNull CodeVar codeVar) {
this.codeVar = codeVar; this.codeVar = codeVar;
if (codeVar.getType() != null && !typeInfo.getType().equals(codeVar.getType())) {
throw new JadxRuntimeException("Unmached types for SSA and Code variables: " + this + " and " + codeVar);
}
codeVar.addSsaVar(this); codeVar.addSsaVar(this);
} }
......
package jadx.core.dex.visitors; package jadx.core.dex.visitors;
import java.util.HashSet; import java.util.LinkedHashSet;
import java.util.List;
import java.util.Set; import java.util.Set;
import java.util.stream.Collectors;
import jadx.core.dex.attributes.AFlag; import jadx.core.dex.attributes.AFlag;
import jadx.core.dex.instructions.PhiInsn; import jadx.core.dex.instructions.PhiInsn;
import jadx.core.dex.instructions.args.ArgType;
import jadx.core.dex.instructions.args.CodeVar; import jadx.core.dex.instructions.args.CodeVar;
import jadx.core.dex.instructions.args.RegisterArg; import jadx.core.dex.instructions.args.RegisterArg;
import jadx.core.dex.instructions.args.SSAVar; import jadx.core.dex.instructions.args.SSAVar;
import jadx.core.dex.nodes.MethodNode; import jadx.core.dex.nodes.MethodNode;
import jadx.core.dex.visitors.ssa.SSATransform; import jadx.core.dex.visitors.ssa.SSATransform;
import jadx.core.utils.exceptions.JadxException; import jadx.core.utils.exceptions.JadxException;
import jadx.core.utils.exceptions.JadxRuntimeException;
@JadxVisitor( @JadxVisitor(
name = "InitCodeVariables", name = "InitCodeVariables",
...@@ -41,7 +45,6 @@ public class InitCodeVariables extends AbstractVisitor { ...@@ -41,7 +45,6 @@ public class InitCodeVariables extends AbstractVisitor {
return; return;
} }
CodeVar codeVar = new CodeVar(); CodeVar codeVar = new CodeVar();
codeVar.setType(ssaVar.getTypeInfo().getType());
RegisterArg assignArg = ssaVar.getAssign(); RegisterArg assignArg = ssaVar.getAssign();
if (assignArg.contains(AFlag.THIS)) { if (assignArg.contains(AFlag.THIS)) {
codeVar.setName(RegisterArg.THIS_ARG_NAME); codeVar.setName(RegisterArg.THIS_ARG_NAME);
...@@ -55,17 +58,36 @@ public class InitCodeVariables extends AbstractVisitor { ...@@ -55,17 +58,36 @@ public class InitCodeVariables extends AbstractVisitor {
} }
private static void setCodeVar(SSAVar ssaVar, CodeVar codeVar) { private static void setCodeVar(SSAVar ssaVar, CodeVar codeVar) {
ssaVar.setCodeVar(codeVar);
PhiInsn usedInPhi = ssaVar.getUsedInPhi(); PhiInsn usedInPhi = ssaVar.getUsedInPhi();
if (usedInPhi != null) { if (usedInPhi != null) {
Set<SSAVar> vars = new HashSet<>(); Set<SSAVar> vars = new LinkedHashSet<>();
vars.add(ssaVar);
collectConnectedVars(usedInPhi, vars); collectConnectedVars(usedInPhi, vars);
setCodeVarType(codeVar, vars);
vars.forEach(var -> { vars.forEach(var -> {
if (var.isCodeVarSet()) { if (var.isCodeVarSet()) {
codeVar.mergeFlagsFrom(var.getCodeVar()); codeVar.mergeFlagsFrom(var.getCodeVar());
} }
var.setCodeVar(codeVar); var.setCodeVar(codeVar);
}); });
} else {
ssaVar.setCodeVar(codeVar);
}
}
private static void setCodeVarType(CodeVar codeVar, Set<SSAVar> vars) {
if (vars.size() > 1) {
List<ArgType> imTypes = vars.stream()
.filter(var -> var.contains(AFlag.METHOD_ARGUMENT))
.map(var -> var.getTypeInfo().getType())
.distinct()
.collect(Collectors.toList());
int imCount = imTypes.size();
if (imCount == 1) {
codeVar.setType(imTypes.get(0));
} else if (imCount > 1) {
throw new JadxRuntimeException("Several immutable types in one variable: " + imTypes + ", vars: " + vars);
}
} }
} }
......
...@@ -102,11 +102,22 @@ public final class TypeInferenceVisitor extends AbstractVisitor { ...@@ -102,11 +102,22 @@ public final class TypeInferenceVisitor extends AbstractVisitor {
private boolean setBestType(SSAVar ssaVar) { private boolean setBestType(SSAVar ssaVar) {
try { try {
ArgType codeVarType = ssaVar.getCodeVar().getType();
if (codeVarType != null) {
return applyImmutableType(ssaVar, codeVarType);
}
RegisterArg assignArg = ssaVar.getAssign(); RegisterArg assignArg = ssaVar.getAssign();
if (!assignArg.isTypeImmutable()) { if (assignArg.isTypeImmutable()) {
return applyImmutableType(ssaVar, assignArg.getInitType());
}
return calculateFromBounds(ssaVar); return calculateFromBounds(ssaVar);
} catch (Exception e) {
LOG.error("Failed to calculate best type for var: {}", ssaVar);
return false;
}
} }
ArgType initType = assignArg.getInitType();
private boolean applyImmutableType(SSAVar ssaVar, ArgType initType) {
TypeUpdateResult result = typeUpdate.apply(ssaVar, initType); TypeUpdateResult result = typeUpdate.apply(ssaVar, initType);
if (result == TypeUpdateResult.REJECT) { if (result == TypeUpdateResult.REJECT) {
if (Consts.DEBUG) { if (Consts.DEBUG) {
...@@ -114,18 +125,22 @@ public final class TypeInferenceVisitor extends AbstractVisitor { ...@@ -114,18 +125,22 @@ public final class TypeInferenceVisitor extends AbstractVisitor {
} }
return false; return false;
} }
return true; return result == TypeUpdateResult.CHANGED;
} catch (Exception e) {
LOG.error("Failed to calculate best type for var: {}", ssaVar);
return false;
}
} }
private boolean calculateFromBounds(SSAVar ssaVar) { private boolean calculateFromBounds(SSAVar ssaVar) {
TypeInfo typeInfo = ssaVar.getTypeInfo(); TypeInfo typeInfo = ssaVar.getTypeInfo();
Set<ITypeBound> bounds = typeInfo.getBounds(); Set<ITypeBound> bounds = typeInfo.getBounds();
Optional<ArgType> bestTypeOpt = selectBestTypeFromBounds(bounds); Optional<ArgType> bestTypeOpt = selectBestTypeFromBounds(bounds);
if (bestTypeOpt.isPresent()) { if (!bestTypeOpt.isPresent()) {
if (Consts.DEBUG) {
LOG.warn("Failed to select best type from bounds, count={} : ", bounds.size());
for (ITypeBound bound : bounds) {
LOG.warn(" {}", bound);
}
}
return false;
}
ArgType candidateType = bestTypeOpt.get(); ArgType candidateType = bestTypeOpt.get();
TypeUpdateResult result = typeUpdate.apply(ssaVar, candidateType); TypeUpdateResult result = typeUpdate.apply(ssaVar, candidateType);
if (result == TypeUpdateResult.REJECT) { if (result == TypeUpdateResult.REJECT) {
...@@ -140,14 +155,6 @@ public final class TypeInferenceVisitor extends AbstractVisitor { ...@@ -140,14 +155,6 @@ public final class TypeInferenceVisitor extends AbstractVisitor {
} }
return result == TypeUpdateResult.CHANGED; return result == TypeUpdateResult.CHANGED;
} }
if (Consts.DEBUG) {
LOG.warn("Failed to select best type from bounds, count={} : ", bounds.size());
for (ITypeBound bound : bounds) {
LOG.warn(" {}", bound);
}
}
return false;
}
private Optional<ArgType> selectBestTypeFromBounds(Set<ITypeBound> bounds) { private Optional<ArgType> selectBestTypeFromBounds(Set<ITypeBound> bounds) {
return bounds.stream() return bounds.stream()
......
package jadx.core.dex.visitors.typeinference; package jadx.core.dex.visitors.typeinference;
import java.util.HashSet; import java.util.LinkedHashSet;
import java.util.Set; import java.util.Set;
import org.jetbrains.annotations.NotNull; import org.jetbrains.annotations.NotNull;
...@@ -10,7 +10,7 @@ import jadx.core.dex.instructions.args.ArgType; ...@@ -10,7 +10,7 @@ import jadx.core.dex.instructions.args.ArgType;
public class TypeInfo { public class TypeInfo {
private ArgType type = ArgType.UNKNOWN; private ArgType type = ArgType.UNKNOWN;
private final Set<ITypeBound> bounds = new HashSet<>(); private final Set<ITypeBound> bounds = new LinkedHashSet<>();
@NotNull @NotNull
public ArgType getType() { public ArgType getType() {
......
...@@ -48,7 +48,7 @@ public final class TypeUpdate { ...@@ -48,7 +48,7 @@ public final class TypeUpdate {
if (candidateType == null) { if (candidateType == null) {
return REJECT; return REJECT;
} }
if (!candidateType.isTypeKnown() && ssaVar.getTypeInfo().getType().isTypeKnown()) { if (!candidateType.isTypeKnown()/* && ssaVar.getTypeInfo().getType().isTypeKnown()*/) {
return REJECT; return REJECT;
} }
...@@ -86,14 +86,14 @@ public final class TypeUpdate { ...@@ -86,14 +86,14 @@ public final class TypeUpdate {
return SAME; return SAME;
} }
TypeCompareEnum compareResult = comparator.compareTypes(candidateType, currentType); TypeCompareEnum compareResult = comparator.compareTypes(candidateType, currentType);
if (arg.isTypeImmutable() && currentType != ArgType.UNKNOWN) {
// don't changed type
if (compareResult == TypeCompareEnum.CONFLICT) { if (compareResult == TypeCompareEnum.CONFLICT) {
if (Consts.DEBUG) { if (Consts.DEBUG) {
LOG.debug("Type rejected for {} due to conflict: candidate={}, current={}", arg, candidateType, currentType); LOG.debug("Type rejected for {} due to conflict: candidate={}, current={}", arg, candidateType, currentType);
} }
return REJECT; return REJECT;
} }
if (arg.isTypeImmutable() && currentType != ArgType.UNKNOWN) {
// don't changed type, conflict already rejected
return SAME; return SAME;
} }
if (compareResult.isWider()) { if (compareResult.isWider()) {
......
...@@ -56,13 +56,13 @@ public class TestTryCatchFinally6 extends IntegrationTest { ...@@ -56,13 +56,13 @@ public class TestTryCatchFinally6 extends IntegrationTest {
String code = cls.getCode().toString(); String code = cls.getCode().toString();
assertThat(code, containsLines(2, assertThat(code, containsLines(2,
"FileInputStream fileInputStream = null;", "InputStream inputStream = null;",
"try {", "try {",
indent() + "call();", indent() + "call();",
indent() + "fileInputStream = new FileInputStream(\"1.txt\");", indent() + "inputStream = new FileInputStream(\"1.txt\");",
"} finally {", "} finally {",
indent() + "if (fileInputStream != null) {", indent() + "if (inputStream != null) {",
indent() + indent() + "fileInputStream.close();", indent() + indent() + "inputStream.close();",
indent() + "}", indent() + "}",
"}" "}"
)); ));
......
...@@ -24,7 +24,7 @@ public class TestVariablesGeneric extends SmaliTest { ...@@ -24,7 +24,7 @@ public class TestVariablesGeneric extends SmaliTest {
@Test @Test
public void test() { public void test() {
disableCompilation(); disableCompilation();
ClassNode cls = getClassNodeFromSmaliWithPath("variables", "TestVariablesGeneric"); ClassNode cls = getClassNodeFromSmaliWithPkg("variables", "TestVariablesGeneric");
String code = cls.getCode().toString(); String code = cls.getCode().toString();
assertThat(code, not(containsString("iVar2"))); assertThat(code, not(containsString("iVar2")));
......
.class public LTestVariablesGeneric; .class public Lvariables/TestVariablesGeneric;
.super Ljava/lang/Object; .super Ljava/lang/Object;
.source "SourceFile" .source "SourceFile"
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment