Commit 0968f75e authored by Skylot's avatar Skylot

core: fix condition in loops (issue #9)

parent bc0db88a
...@@ -10,6 +10,7 @@ public final class IfInfo { ...@@ -10,6 +10,7 @@ public final class IfInfo {
private final Set<BlockNode> mergedBlocks = new HashSet<BlockNode>(); private final Set<BlockNode> mergedBlocks = new HashSet<BlockNode>();
private final BlockNode thenBlock; private final BlockNode thenBlock;
private final BlockNode elseBlock; private final BlockNode elseBlock;
private BlockNode outBlock;
@Deprecated @Deprecated
private BlockNode ifBlock; private BlockNode ifBlock;
...@@ -50,6 +51,14 @@ public final class IfInfo { ...@@ -50,6 +51,14 @@ public final class IfInfo {
return elseBlock; return elseBlock;
} }
public BlockNode getOutBlock() {
return outBlock;
}
public void setOutBlock(BlockNode outBlock) {
this.outBlock = outBlock;
}
public BlockNode getIfBlock() { public BlockNode getIfBlock() {
return ifBlock; return ifBlock;
} }
......
...@@ -125,7 +125,9 @@ public final class LoopRegion extends AbstractRegion { ...@@ -125,7 +125,9 @@ public final class LoopRegion extends AbstractRegion {
if (conditionBlock != null) { if (conditionBlock != null) {
all.add(conditionBlock); all.add(conditionBlock);
} }
if (body != null) {
all.add(body); all.add(body);
}
return Collections.unmodifiableList(all); return Collections.unmodifiableList(all);
} }
......
...@@ -331,7 +331,7 @@ public class BlockMakerVisitor extends AbstractVisitor { ...@@ -331,7 +331,7 @@ public class BlockMakerVisitor extends AbstractVisitor {
private static void markReturnBlocks(MethodNode mth) { private static void markReturnBlocks(MethodNode mth) {
mth.getExitBlocks().clear(); mth.getExitBlocks().clear();
for (BlockNode block : mth.getBasicBlocks()) { for (BlockNode block : mth.getBasicBlocks()) {
if (BlockUtils.lastInsnType(block, InsnType.RETURN)) { if (BlockUtils.checkLastInsnType(block, InsnType.RETURN)) {
block.add(AFlag.RETURN); block.add(AFlag.RETURN);
mth.getExitBlocks().add(block); mth.getExitBlocks().add(block);
} }
...@@ -399,7 +399,7 @@ public class BlockMakerVisitor extends AbstractVisitor { ...@@ -399,7 +399,7 @@ public class BlockMakerVisitor extends AbstractVisitor {
if (loops.size() == 1) { if (loops.size() == 1) {
LoopInfo loop = loops.get(0); LoopInfo loop = loops.get(0);
List<Edge> edges = loop.getExitEdges(); List<Edge> edges = loop.getExitEdges();
if (edges.size() > 1) { if (!edges.isEmpty()) {
boolean change = false; boolean change = false;
for (Edge edge : edges) { for (Edge edge : edges) {
BlockNode target = edge.getTarget(); BlockNode target = edge.getTarget();
...@@ -414,10 +414,7 @@ public class BlockMakerVisitor extends AbstractVisitor { ...@@ -414,10 +414,7 @@ public class BlockMakerVisitor extends AbstractVisitor {
} }
} }
} }
if (splitReturn(mth)) { return splitReturn(mth);
return true;
}
return false;
} }
private static BlockNode insertBlockBetween(MethodNode mth, BlockNode source, BlockNode target) { private static BlockNode insertBlockBetween(MethodNode mth, BlockNode source, BlockNode target) {
...@@ -439,7 +436,6 @@ public class BlockMakerVisitor extends AbstractVisitor { ...@@ -439,7 +436,6 @@ public class BlockMakerVisitor extends AbstractVisitor {
BlockNode exitBlock = mth.getExitBlocks().get(0); BlockNode exitBlock = mth.getExitBlocks().get(0);
if (exitBlock.getPredecessors().size() > 1 if (exitBlock.getPredecessors().size() > 1
&& exitBlock.getInstructions().size() == 1 && exitBlock.getInstructions().size() == 1
&& !exitBlock.getInstructions().get(0).contains(AType.CATCH_BLOCK)
&& !exitBlock.contains(AFlag.SYNTHETIC)) { && !exitBlock.contains(AFlag.SYNTHETIC)) {
InsnNode returnInsn = exitBlock.getInstructions().get(0); InsnNode returnInsn = exitBlock.getInstructions().get(0);
List<BlockNode> preds = new ArrayList<BlockNode>(exitBlock.getPredecessors()); List<BlockNode> preds = new ArrayList<BlockNode>(exitBlock.getPredecessors());
......
...@@ -8,29 +8,99 @@ import jadx.core.dex.instructions.args.InsnArg; ...@@ -8,29 +8,99 @@ import jadx.core.dex.instructions.args.InsnArg;
import jadx.core.dex.instructions.args.RegisterArg; import jadx.core.dex.instructions.args.RegisterArg;
import jadx.core.dex.nodes.BlockNode; import jadx.core.dex.nodes.BlockNode;
import jadx.core.dex.nodes.InsnNode; import jadx.core.dex.nodes.InsnNode;
import jadx.core.dex.nodes.MethodNode;
import jadx.core.dex.regions.IfCondition; import jadx.core.dex.regions.IfCondition;
import jadx.core.dex.regions.IfCondition.Mode; import jadx.core.dex.regions.IfCondition.Mode;
import jadx.core.dex.regions.IfInfo; import jadx.core.dex.regions.IfInfo;
import jadx.core.utils.BlockUtils; import jadx.core.utils.BlockUtils;
import java.util.Collection;
import java.util.List; import java.util.List;
import java.util.Set;
import static jadx.core.utils.BlockUtils.isPathExists; import static jadx.core.utils.BlockUtils.isPathExists;
public class IfMakerHelper { public class IfMakerHelper {
private IfMakerHelper() { private IfMakerHelper() {
} }
static IfInfo makeIfInfo(BlockNode ifBlock) { static IfInfo makeIfInfo(BlockNode ifBlock) {
return makeIfInfo(ifBlock, IfCondition.fromIfBlock(ifBlock)); IfNode ifNode = (IfNode) ifBlock.getInstructions().get(0);
IfCondition condition = IfCondition.fromIfNode(ifNode);
IfInfo info = new IfInfo(condition, ifNode.getThenBlock(), ifNode.getElseBlock());
info.setIfBlock(ifBlock);
info.getMergedBlocks().add(ifBlock);
return info;
} }
static IfInfo mergeNestedIfNodes(BlockNode block) { static IfInfo restructureIf(MethodNode mth, BlockNode block, IfInfo info) {
IfInfo info = makeIfInfo(block); final BlockNode thenBlock = info.getThenBlock();
return mergeNestedIfNodes(info); final BlockNode elseBlock = info.getElseBlock();
// select 'then', 'else' and 'exit' blocks
if (thenBlock.contains(AFlag.RETURN) && elseBlock.contains(AFlag.RETURN)) {
info.setOutBlock(null);
return info;
}
boolean badThen = !allPathsFromIf(thenBlock, info);
boolean badElse = !allPathsFromIf(elseBlock, info);
if (badThen && badElse) {
return null;
}
if (badThen || badElse) {
if (badElse && isPathExists(thenBlock, elseBlock)) {
info = new IfInfo(info.getCondition(), thenBlock, null);
info.setOutBlock(elseBlock);
} else if (badThen && isPathExists(elseBlock, thenBlock)) {
info = IfInfo.invert(info);
info = new IfInfo(info.getCondition(), info.getThenBlock(), null);
info.setOutBlock(thenBlock);
} else if (badElse) {
info = new IfInfo(info.getCondition(), thenBlock, null);
info.setOutBlock(null);
} else {
info = IfInfo.invert(info);
info = new IfInfo(info.getCondition(), info.getThenBlock(), null);
info.setOutBlock(null);
}
} else {
List<BlockNode> thenSC = thenBlock.getCleanSuccessors();
List<BlockNode> elseSC = elseBlock.getCleanSuccessors();
if (thenSC.size() == 1 && sameElements(thenSC, elseSC)) {
info.setOutBlock(thenSC.get(0));
} else if (info.getMergedBlocks().size() == 1
&& block.getDominatesOn().size() == 2) {
info.setOutBlock(BlockUtils.getPathCross(mth, thenBlock, elseBlock));
}
}
if (info.getOutBlock() == null) {
for (BlockNode d : block.getDominatesOn()) {
if (d != thenBlock && d != elseBlock
&& !info.getMergedBlocks().contains(d)
&& isPathExists(thenBlock, d)) {
info.setOutBlock(d);
break;
}
}
}
if (BlockUtils.isBackEdge(block, info.getOutBlock())) {
info.setOutBlock(null);
}
return info;
} }
private static IfInfo mergeNestedIfNodes(IfInfo currentIf) { private static boolean allPathsFromIf(BlockNode block, IfInfo info) {
List<BlockNode> preds = block.getPredecessors();
Set<BlockNode> ifBlocks = info.getMergedBlocks();
return ifBlocks.containsAll(preds);
}
private static boolean sameElements(Collection<BlockNode> c1, Collection<BlockNode> c2) {
return c1.size() == c2.size() && c1.containsAll(c2);
}
static IfInfo mergeNestedIfNodes(IfInfo currentIf) {
BlockNode curThen = currentIf.getThenBlock(); BlockNode curThen = currentIf.getThenBlock();
BlockNode curElse = currentIf.getElseBlock(); BlockNode curElse = currentIf.getElseBlock();
if (curThen == curElse) { if (curThen == curElse) {
...@@ -93,14 +163,6 @@ public class IfMakerHelper { ...@@ -93,14 +163,6 @@ public class IfMakerHelper {
|| RegionMaker.isEqualPaths(currentIf.getThenBlock(), nextIf.getElseBlock()); || RegionMaker.isEqualPaths(currentIf.getThenBlock(), nextIf.getElseBlock());
} }
private static IfInfo makeIfInfo(BlockNode ifBlock, IfCondition condition) {
IfNode ifnode = (IfNode) ifBlock.getInstructions().get(0);
IfInfo info = new IfInfo(condition, ifnode.getThenBlock(), ifnode.getElseBlock());
info.setIfBlock(ifBlock);
info.getMergedBlocks().add(ifBlock);
return info;
}
private static boolean checkConditionBranches(BlockNode from, BlockNode to) { private static boolean checkConditionBranches(BlockNode from, BlockNode to) {
return from.getCleanSuccessors().size() == 1 && from.getCleanSuccessors().contains(to); return from.getCleanSuccessors().size() == 1 && from.getCleanSuccessors().contains(to);
} }
......
...@@ -51,6 +51,7 @@ public class IfRegionVisitor extends AbstractVisitor implements IRegionVisitor, ...@@ -51,6 +51,7 @@ public class IfRegionVisitor extends AbstractVisitor implements IRegionVisitor,
private static void processIfRegion(MethodNode mth, IfRegion ifRegion) { private static void processIfRegion(MethodNode mth, IfRegion ifRegion) {
simplifyIfCondition(ifRegion); simplifyIfCondition(ifRegion);
moveReturnToThenBlock(mth, ifRegion); moveReturnToThenBlock(mth, ifRegion);
moveBreakToThenBlock(ifRegion);
markElseIfChains(ifRegion); markElseIfChains(ifRegion);
TernaryMod.makeTernaryInsn(mth, ifRegion); TernaryMod.makeTernaryInsn(mth, ifRegion);
...@@ -103,6 +104,13 @@ public class IfRegionVisitor extends AbstractVisitor implements IRegionVisitor, ...@@ -103,6 +104,13 @@ public class IfRegionVisitor extends AbstractVisitor implements IRegionVisitor,
} }
} }
private static void moveBreakToThenBlock(IfRegion ifRegion) {
if (ifRegion.getElseRegion() != null
&& RegionUtils.hasBreakInsn(ifRegion.getElseRegion())) {
invertIfRegion(ifRegion);
}
}
/** /**
* Mark if-else-if chains * Mark if-else-if chains
*/ */
...@@ -124,7 +132,7 @@ public class IfRegionVisitor extends AbstractVisitor implements IRegionVisitor, ...@@ -124,7 +132,7 @@ public class IfRegionVisitor extends AbstractVisitor implements IRegionVisitor,
if (ifRegion.getElseRegion() != null if (ifRegion.getElseRegion() != null
&& !ifRegion.contains(AFlag.ELSE_IF_CHAIN) && !ifRegion.contains(AFlag.ELSE_IF_CHAIN)
&& !ifRegion.getElseRegion().contains(AFlag.ELSE_IF_CHAIN) && !ifRegion.getElseRegion().contains(AFlag.ELSE_IF_CHAIN)
&& RegionUtils.hasExitBlock(ifRegion.getThenRegion()) && hasBranchTerminator(ifRegion)
&& insnsCount(ifRegion.getThenRegion()) < 2) { && insnsCount(ifRegion.getThenRegion()) < 2) {
IRegion parent = ifRegion.getParent(); IRegion parent = ifRegion.getParent();
Region newRegion = new Region(parent); Region newRegion = new Region(parent);
...@@ -138,6 +146,12 @@ public class IfRegionVisitor extends AbstractVisitor implements IRegionVisitor, ...@@ -138,6 +146,12 @@ public class IfRegionVisitor extends AbstractVisitor implements IRegionVisitor,
return false; return false;
} }
private static boolean hasBranchTerminator(IfRegion ifRegion) {
// TODO: check for exception throw
return RegionUtils.hasExitBlock(ifRegion.getThenRegion())
|| RegionUtils.hasBreakInsn(ifRegion.getThenRegion());
}
private static void invertIfRegion(IfRegion ifRegion) { private static void invertIfRegion(IfRegion ifRegion) {
IContainer elseRegion = ifRegion.getElseRegion(); IContainer elseRegion = ifRegion.getElseRegion();
if (elseRegion != null) { if (elseRegion != null) {
......
...@@ -9,7 +9,9 @@ import jadx.core.utils.exceptions.JadxRuntimeException; ...@@ -9,7 +9,9 @@ import jadx.core.utils.exceptions.JadxRuntimeException;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.BitSet; import java.util.BitSet;
import java.util.Collections;
import java.util.HashSet; import java.util.HashSet;
import java.util.LinkedList;
import java.util.List; import java.util.List;
import java.util.Set; import java.util.Set;
...@@ -101,13 +103,13 @@ public class BlockUtils { ...@@ -101,13 +103,13 @@ public class BlockUtils {
return false; return false;
} }
public static boolean lastInsnType(BlockNode block, InsnType type) { public static boolean checkLastInsnType(BlockNode block, InsnType expectedType) {
List<InsnNode> insns = block.getInstructions(); List<InsnNode> insns = block.getInstructions();
if (insns.isEmpty()) { if (insns.isEmpty()) {
return false; return false;
} }
InsnNode insn = insns.get(insns.size() - 1); InsnNode insn = insns.get(insns.size() - 1);
return insn.getType() == type; return insn.getType() == expectedType;
} }
public static BlockNode getBlockByInsn(MethodNode mth, InsnNode insn) { public static BlockNode getBlockByInsn(MethodNode mth, InsnNode insn) {
...@@ -288,4 +290,15 @@ public class BlockUtils { ...@@ -288,4 +290,15 @@ public class BlockUtils {
} }
} }
} }
public static List<BlockNode> buildSimplePath(BlockNode block) {
List<BlockNode> list = new LinkedList<BlockNode>();
while (block != null
&& block.getCleanSuccessors().size() < 2
&& block.getPredecessors().size() == 1) {
list.add(block);
block = getNextBlock(block);
}
return list.isEmpty() ? Collections.<BlockNode>emptyList() : list;
}
} }
...@@ -2,6 +2,7 @@ package jadx.core.utils; ...@@ -2,6 +2,7 @@ package jadx.core.utils;
import jadx.core.dex.attributes.AFlag; import jadx.core.dex.attributes.AFlag;
import jadx.core.dex.attributes.AType; import jadx.core.dex.attributes.AType;
import jadx.core.dex.instructions.InsnType;
import jadx.core.dex.nodes.BlockNode; import jadx.core.dex.nodes.BlockNode;
import jadx.core.dex.nodes.IContainer; import jadx.core.dex.nodes.IContainer;
import jadx.core.dex.nodes.IRegion; import jadx.core.dex.nodes.IRegion;
...@@ -49,6 +50,18 @@ public class RegionUtils { ...@@ -49,6 +50,18 @@ public class RegionUtils {
} }
} }
public static boolean hasBreakInsn(IContainer container) {
if (container instanceof BlockNode) {
return BlockUtils.checkLastInsnType((BlockNode) container, InsnType.BREAK);
} else if (container instanceof IRegion) {
List<IContainer> blocks = ((IRegion) container).getSubBlocks();
return !blocks.isEmpty()
&& hasBreakInsn(blocks.get(blocks.size() - 1));
} else {
throw new JadxRuntimeException("Unknown container type: " + container);
}
}
public static int insnsCount(IContainer container) { public static int insnsCount(IContainer container) {
if (container instanceof BlockNode) { if (container instanceof BlockNode) {
return ((BlockNode) container).getInstructions().size(); return ((BlockNode) container).getInstructions().size();
......
...@@ -5,8 +5,7 @@ import jadx.core.dex.nodes.ClassNode; ...@@ -5,8 +5,7 @@ import jadx.core.dex.nodes.ClassNode;
import org.junit.Test; import org.junit.Test;
import static org.hamcrest.CoreMatchers.containsString; import static jadx.tests.utils.JadxMatchers.containsOne;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertThat; import static org.junit.Assert.assertThat;
public class TestBreakInLoop extends InternalJadxTest { public class TestBreakInLoop extends InternalJadxTest {
...@@ -33,8 +32,12 @@ public class TestBreakInLoop extends InternalJadxTest { ...@@ -33,8 +32,12 @@ public class TestBreakInLoop extends InternalJadxTest {
String code = cls.getCode().toString(); String code = cls.getCode().toString();
System.out.println(code); System.out.println(code);
assertEquals(1, count(code, "this.f++;")); assertThat(code, containsOne("this.f++;"));
assertThat(code, containsString("if (i < b) {")); // assertThat(code, containsOne("a[i]++;"));
assertThat(code, containsString("break;")); assertThat(code, containsOne("if (i < b) {"));
assertThat(code, containsOne("break;"));
assertThat(code, containsOne("i++;"));
// assertThat(code, countString(0, "else"));
} }
} }
...@@ -5,7 +5,7 @@ import jadx.core.dex.nodes.ClassNode; ...@@ -5,7 +5,7 @@ import jadx.core.dex.nodes.ClassNode;
import org.junit.Test; import org.junit.Test;
import static org.hamcrest.CoreMatchers.containsString; import static jadx.tests.utils.JadxMatchers.containsOne;
import static org.junit.Assert.assertThat; import static org.junit.Assert.assertThat;
public class TestLoopCondition extends InternalJadxTest { public class TestLoopCondition extends InternalJadxTest {
...@@ -48,8 +48,12 @@ public class TestLoopCondition extends InternalJadxTest { ...@@ -48,8 +48,12 @@ public class TestLoopCondition extends InternalJadxTest {
String code = cls.getCode().toString(); String code = cls.getCode().toString();
System.out.println(code); System.out.println(code);
assertThat(code, containsString("i < this.f.length()")); assertThat(code, containsOne("i < this.f.length()"));
assertThat(code, containsString("list.set(i, \"ABC\")")); assertThat(code, containsOne("list.set(i, \"ABC\")"));
assertThat(code, containsString("list.set(i, \"DEF\")")); assertThat(code, containsOne("list.set(i, \"DEF\")"));
assertThat(code, containsOne("if (j == 2) {"));
assertThat(code, containsOne("setEnabled(true);"));
assertThat(code, containsOne("setEnabled(false);"));
} }
} }
package jadx.tests.internal.loops;
import jadx.api.InternalJadxTest;
import jadx.core.dex.nodes.ClassNode;
import org.junit.Test;
import static jadx.tests.utils.JadxMatchers.containsOne;
import static org.junit.Assert.assertThat;
public class TestLoopCondition3 extends InternalJadxTest {
public static class TestCls {
public static void test(int a, int b, int c) {
while (a < 12) {
if (b + a < 9 && b < 8) {
if (b >= 2 && a > -1 && b < 6) {
System.out.println("OK");
c = b + 1;
}
b = a;
}
c = b;
b++;
b = c;
a++;
}
}
}
@Test
public void test() {
ClassNode cls = getClassNode(TestCls.class);
String code = cls.getCode().toString();
System.out.println(code);
assertThat(code, containsOne("while (a < 12) {"));
assertThat(code, containsOne("if (b + a < 9 && b < 8) {"));
assertThat(code, containsOne("if (b >= 2 && a > -1 && b < 6) {"));
}
}
package jadx.tests.internal.loops;
import jadx.api.InternalJadxTest;
import jadx.core.dex.nodes.ClassNode;
import org.junit.Test;
import static jadx.tests.utils.JadxMatchers.containsOne;
import static org.junit.Assert.assertThat;
public class TestLoopCondition4 extends InternalJadxTest {
public static class TestCls {
public static void test() {
int n = -1;
while (n < 0) {
n += 12;
}
while (n > 11) {
n -= 12;
}
System.out.println(n);
}
}
@Test
public void test() {
ClassNode cls = getClassNode(TestCls.class);
String code = cls.getCode().toString();
System.out.println(code);
assertThat(code, containsOne("int n = -1;"));
assertThat(code, containsOne("while (n < 0) {"));
assertThat(code, containsOne("n += 12;"));
assertThat(code, containsOne("while (n > 11) {"));
assertThat(code, containsOne("n -= 12;"));
assertThat(code, containsOne("System.out.println(n);"));
}
}
package jadx.tests.internal.loops;
import jadx.api.InternalJadxTest;
import jadx.core.dex.nodes.ClassNode;
import org.junit.Test;
import static jadx.tests.utils.JadxMatchers.containsOne;
import static jadx.tests.utils.JadxMatchers.countString;
import static org.hamcrest.CoreMatchers.containsString;
import static org.hamcrest.CoreMatchers.not;
import static org.junit.Assert.assertThat;
public class TestSequentialLoops extends InternalJadxTest {
public static class TestCls {
public int test7(int a, int b) {
int c = b;
int z;
while (true) {
z = c + a;
if (z >= 7) {
break;
}
c = z;
}
while ((z = c + a) >= 7) {
c = z;
}
return c;
}
}
@Test
public void test() {
ClassNode cls = getClassNode(TestCls.class);
String code = cls.getCode().toString();
System.out.println(code);
assertThat(code, countString(2, "while ("));
assertThat(code, containsOne("break;"));
assertThat(code, containsOne("return c;"));
assertThat(code, not(containsString("else")));
}
}
...@@ -18,10 +18,10 @@ public class TestTryCatch2 extends InternalJadxTest { ...@@ -18,10 +18,10 @@ public class TestTryCatch2 extends InternalJadxTest {
synchronized (obj) { synchronized (obj) {
obj.wait(5); obj.wait(5);
} }
return true;
} catch (InterruptedException e) { } catch (InterruptedException e) {
return false; return false;
} }
return true;
} }
} }
...@@ -34,12 +34,8 @@ public class TestTryCatch2 extends InternalJadxTest { ...@@ -34,12 +34,8 @@ public class TestTryCatch2 extends InternalJadxTest {
assertThat(code, containsString("try {")); assertThat(code, containsString("try {"));
assertThat(code, containsString("synchronized (obj) {")); assertThat(code, containsString("synchronized (obj) {"));
assertThat(code, containsString("obj.wait(5);")); assertThat(code, containsString("obj.wait(5);"));
assertThat(code, containsString("return true;"));
assertThat(code, containsString("} catch (InterruptedException e) {")); assertThat(code, containsString("} catch (InterruptedException e) {"));
assertThat(code, containsString("return false;"));
// TODO
assertThat(code, containsString(" = false;"));
assertThat(code, containsString(" = true;"));
// assertThat(code, containsString("return false;"));
// assertThat(code, containsString("return true;"));
} }
} }
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment