Commit 533b686e authored by Skylot's avatar Skylot

fix: comment out instructions also before other constructor call (#685)

parent c6c54f90
...@@ -131,12 +131,20 @@ public class RegionGen extends InsnGen { ...@@ -131,12 +131,20 @@ public class RegionGen extends InsnGen {
} }
} }
} }
boolean comment = region.contains(AFlag.COMMENT_OUT);
if (comment) {
code.add("// ");
}
code.add("if ("); code.add("if (");
new ConditionGen(this).add(code, region.getCondition()); new ConditionGen(this).add(code, region.getCondition());
code.add(") {"); code.add(") {");
makeRegionIndent(code, region.getThenRegion()); makeRegionIndent(code, region.getThenRegion());
code.startLine('}'); if (comment) {
code.startLine("// }");
} else {
code.startLine('}');
}
IContainer els = region.getElseRegion(); IContainer els = region.getElseRegion();
if (RegionUtils.notEmpty(els)) { if (RegionUtils.notEmpty(els)) {
...@@ -146,7 +154,11 @@ public class RegionGen extends InsnGen { ...@@ -146,7 +154,11 @@ public class RegionGen extends InsnGen {
} }
code.add('{'); code.add('{');
makeRegionIndent(code, els); makeRegionIndent(code, els);
code.startLine('}'); if (comment) {
code.startLine("// }");
} else {
code.startLine('}');
}
} }
} }
......
...@@ -2,13 +2,11 @@ package jadx.core.dex.visitors; ...@@ -2,13 +2,11 @@ package jadx.core.dex.visitors;
import java.util.Iterator; import java.util.Iterator;
import java.util.List; import java.util.List;
import java.util.Objects;
import jadx.core.dex.attributes.AFlag; import jadx.core.dex.attributes.AFlag;
import jadx.core.dex.instructions.ArithNode; import jadx.core.dex.instructions.ArithNode;
import jadx.core.dex.instructions.ArithOp; import jadx.core.dex.instructions.ArithOp;
import jadx.core.dex.instructions.InsnType; import jadx.core.dex.instructions.InsnType;
import jadx.core.dex.instructions.args.ArgType;
import jadx.core.dex.instructions.args.InsnArg; import jadx.core.dex.instructions.args.InsnArg;
import jadx.core.dex.instructions.args.InsnWrapArg; import jadx.core.dex.instructions.args.InsnWrapArg;
import jadx.core.dex.instructions.args.RegisterArg; import jadx.core.dex.instructions.args.RegisterArg;
...@@ -16,6 +14,7 @@ import jadx.core.dex.instructions.mods.ConstructorInsn; ...@@ -16,6 +14,7 @@ import jadx.core.dex.instructions.mods.ConstructorInsn;
import jadx.core.dex.instructions.mods.TernaryInsn; import jadx.core.dex.instructions.mods.TernaryInsn;
import jadx.core.dex.nodes.BlockNode; import jadx.core.dex.nodes.BlockNode;
import jadx.core.dex.nodes.IBlock; import jadx.core.dex.nodes.IBlock;
import jadx.core.dex.nodes.IRegion;
import jadx.core.dex.nodes.InsnNode; import jadx.core.dex.nodes.InsnNode;
import jadx.core.dex.nodes.MethodNode; import jadx.core.dex.nodes.MethodNode;
import jadx.core.dex.regions.conditions.IfCondition; import jadx.core.dex.regions.conditions.IfCondition;
...@@ -38,8 +37,6 @@ import jadx.core.utils.exceptions.JadxException; ...@@ -38,8 +37,6 @@ import jadx.core.utils.exceptions.JadxException;
) )
public class PrepareForCodeGen extends AbstractVisitor { public class PrepareForCodeGen extends AbstractVisitor {
public static final SuperCallRegionVisitor SUPER_CALL_REGION_VISITOR = new SuperCallRegionVisitor();
@Override @Override
public void visit(MethodNode mth) throws JadxException { public void visit(MethodNode mth) throws JadxException {
List<BlockNode> blocks = mth.getBasicBlocks(); List<BlockNode> blocks = mth.getBasicBlocks();
...@@ -55,7 +52,7 @@ public class PrepareForCodeGen extends AbstractVisitor { ...@@ -55,7 +52,7 @@ public class PrepareForCodeGen extends AbstractVisitor {
removeParenthesis(block); removeParenthesis(block);
modifyArith(block); modifyArith(block);
} }
commentOutInsnsBeforeSuper(mth); commentOutInsnsInConstructor(mth);
} }
private static void removeInstructions(BlockNode block) { private static void removeInstructions(BlockNode block) {
...@@ -180,31 +177,94 @@ public class PrepareForCodeGen extends AbstractVisitor { ...@@ -180,31 +177,94 @@ public class PrepareForCodeGen extends AbstractVisitor {
} }
} }
private void commentOutInsnsBeforeSuper(MethodNode mth) { private void commentOutInsnsInConstructor(MethodNode mth) {
if (mth.isConstructor() && !Objects.equals(mth.getParentClass().getSuperClass(), ArgType.OBJECT)) { if (mth.isConstructor()) {
DepthRegionTraversal.traverse(mth, SUPER_CALL_REGION_VISITOR); ConstructorInsn constrInsn = searchConstructorCall(mth);
if (constrInsn != null && !constrInsn.contains(AFlag.DONT_GENERATE)) {
DepthRegionTraversal.traverse(mth, new ConstructorRegionVisitor(constrInsn));
}
} }
} }
private static final class SuperCallRegionVisitor extends AbstractRegionVisitor { private ConstructorInsn searchConstructorCall(MethodNode mth) {
@Override for (BlockNode block : mth.getBasicBlocks()) {
public void processBlock(MethodNode mth, IBlock container) { for (InsnNode insn : block.getInstructions()) {
for (InsnNode insn : container.getInstructions()) {
InsnType insnType = insn.getType(); InsnType insnType = insn.getType();
if ((insnType == InsnType.CONSTRUCTOR) && ((ConstructorInsn) insn).isSuper()) { if (insnType == InsnType.CONSTRUCTOR) {
// found super call ConstructorInsn constrInsn = (ConstructorInsn) insn;
commentOutInsns(container, insn); if (constrInsn.isSuper() || constrInsn.isThis()) {
// TODO: process all previous regions (in case of branching before super call) return constrInsn;
}
} }
} }
} }
return null;
}
private static final class ConstructorRegionVisitor extends AbstractRegionVisitor {
private final ConstructorInsn constrInsn;
private int regionDepth;
private boolean found;
private boolean brokenCode;
private int commentedCount;
public ConstructorRegionVisitor(ConstructorInsn constrInsn) {
this.constrInsn = constrInsn;
}
@Override
public boolean enterRegion(MethodNode mth, IRegion region) {
if (found) {
return false;
}
regionDepth++;
return true;
}
private static void commentOutInsns(IBlock container, InsnNode superInsn) { @Override
public void leaveRegion(MethodNode mth, IRegion region) {
if (!found) {
regionDepth--;
region.add(AFlag.COMMENT_OUT);
commentedCount++;
}
}
@Override
public void processBlock(MethodNode mth, IBlock container) {
if (found) {
return;
}
for (InsnNode insn : container.getInstructions()) { for (InsnNode insn : container.getInstructions()) {
if (insn == superInsn) { if (insn == constrInsn) {
found = true;
addMethodMsg(mth);
break; break;
} }
insn.add(AFlag.COMMENT_OUT); insn.add(AFlag.COMMENT_OUT);
commentedCount++;
if (!brokenCode) {
RegisterArg resArg = insn.getResult();
if (resArg != null) {
for (RegisterArg arg : resArg.getSVar().getUseList()) {
if (arg.getParentInsn() == constrInsn) {
brokenCode = true;
break;
}
}
}
}
}
}
private void addMethodMsg(MethodNode mth) {
if (commentedCount > 0) {
String msg = "JADX WARN: Illegal instructions before constructor call commented (this can break semantics)";
if (brokenCode || regionDepth > 1) {
mth.addWarn(msg);
} else {
mth.addComment(msg);
}
} }
} }
} }
......
...@@ -81,6 +81,8 @@ public abstract class IntegrationTest extends TestUtils { ...@@ -81,6 +81,8 @@ public abstract class IntegrationTest extends TestUtils {
protected boolean useEclipseCompiler; protected boolean useEclipseCompiler;
protected Map<Integer, String> resMap = Collections.emptyMap(); protected Map<Integer, String> resMap = Collections.emptyMap();
private boolean allowWarnInCode;
private DynamicCompiler dynamicCompiler; private DynamicCompiler dynamicCompiler;
@BeforeEach @BeforeEach
...@@ -170,7 +172,7 @@ public abstract class IntegrationTest extends TestUtils { ...@@ -170,7 +172,7 @@ public abstract class IntegrationTest extends TestUtils {
} }
System.out.println("-----------------------------------------------------------"); System.out.println("-----------------------------------------------------------");
clsList.forEach(IntegrationTest::checkCode); clsList.forEach(this::checkCode);
compile(clsList); compile(clsList);
clsList.forEach(this::runAutoCheck); clsList.forEach(this::runAutoCheck);
} }
...@@ -212,7 +214,7 @@ public abstract class IntegrationTest extends TestUtils { ...@@ -212,7 +214,7 @@ public abstract class IntegrationTest extends TestUtils {
} }
} }
protected static void checkCode(ClassNode cls) { protected void checkCode(ClassNode cls) {
assertFalse(hasErrors(cls), "Inconsistent cls: " + cls); assertFalse(hasErrors(cls), "Inconsistent cls: " + cls);
for (MethodNode mthNode : cls.getMethods()) { for (MethodNode mthNode : cls.getMethods()) {
assertFalse(hasErrors(mthNode), "Method with problems: " + mthNode); assertFalse(hasErrors(mthNode), "Method with problems: " + mthNode);
...@@ -220,17 +222,19 @@ public abstract class IntegrationTest extends TestUtils { ...@@ -220,17 +222,19 @@ public abstract class IntegrationTest extends TestUtils {
assertThat(cls.getCode().toString(), not(containsString("inconsistent"))); assertThat(cls.getCode().toString(), not(containsString("inconsistent")));
} }
private static boolean hasErrors(IAttributeNode node) { private boolean hasErrors(IAttributeNode node) {
if (node.contains(AFlag.INCONSISTENT_CODE) if (node.contains(AFlag.INCONSISTENT_CODE)
|| node.contains(AType.JADX_ERROR) || node.contains(AType.JADX_ERROR)
|| node.contains(AType.JADX_WARN)) { || (node.contains(AType.JADX_WARN) && !allowWarnInCode)) {
return true; return true;
} }
AttrList<String> commentsAttr = node.get(AType.COMMENTS); if (!allowWarnInCode) {
if (commentsAttr != null) { AttrList<String> commentsAttr = node.get(AType.COMMENTS);
for (String comment : commentsAttr.getList()) { if (commentsAttr != null) {
if (comment.contains("JADX WARN")) { for (String comment : commentsAttr.getList()) {
return true; if (comment.contains("JADX WARN")) {
return true;
}
} }
} }
} }
...@@ -444,6 +448,10 @@ public abstract class IntegrationTest extends TestUtils { ...@@ -444,6 +448,10 @@ public abstract class IntegrationTest extends TestUtils {
args.setDeobfuscationMaxLength(64); args.setDeobfuscationMaxLength(64);
} }
protected void allowWarnInCode() {
allowWarnInCode = true;
}
// Use only for debug purpose // Use only for debug purpose
@Deprecated @Deprecated
protected void outputCFG() { protected void outputCFG() {
......
...@@ -33,6 +33,7 @@ public class TestInsnsBeforeSuper extends SmaliTest { ...@@ -33,6 +33,7 @@ public class TestInsnsBeforeSuper extends SmaliTest {
@Test @Test
public void test() { public void test() {
allowWarnInCode();
ClassNode cls = getClassNodeFromSmaliFiles("B"); ClassNode cls = getClassNodeFromSmaliFiles("B");
String code = cls.getCode().toString(); String code = cls.getCode().toString();
......
package jadx.tests.integration.others;
import org.junit.jupiter.api.Test;
import jadx.core.dex.nodes.ClassNode;
import jadx.tests.api.SmaliTest;
import static jadx.tests.api.utils.JadxMatchers.containsOne;
import static org.hamcrest.MatcherAssert.assertThat;
public class TestInsnsBeforeThis extends SmaliTest {
// @formatter:off
/*
public class A {
public A(String str) {
checkNull(str);
this(str.length());
}
public A(int i) {
}
public void checkNull(Object o) {
if (o == null) {
throw new NullPointerException();
}
}
}
*/
// @formatter:on
@Test
public void test() {
allowWarnInCode();
ClassNode cls = getClassNodeFromSmali();
String code = cls.getCode().toString();
assertThat(code, containsOne("// checkNull(str);"));
}
}
.class public Lothers/TestInsnsBeforeThis;
.super Ljava/lang/Object;
.method public constructor <init>(Ljava/lang/String;)V
.registers 3
.prologue
invoke-static {p1}, Lothers/TestInsnsBeforeThis;->checkNull(Ljava/lang/Object;)V
invoke-direct {p1}, Ljava/lang/String;->length()I
move-result v0
invoke-direct {p0, v0}, Lothers/TestInsnsBeforeThis;-><init>(I)V
return-void
.end method
.method public constructor <init>(I)V
.registers 3
.prologue
invoke-direct {p0}, Ljava/lang/Object;-><init>()V
return-void
.end method
.method public static checkNull(Ljava/lang/Object;)V
.registers 3
.prologue
if-nez p0, :cond_8
new-instance v0, Ljava/lang/NullPointerException;
invoke-direct {v0}, Ljava/lang/NullPointerException;-><init>()V
throw v0
:cond_8
return-void
.end method
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment