Commit 0968f75e authored by Skylot's avatar Skylot

core: fix condition in loops (issue #9)

parent bc0db88a
......@@ -10,6 +10,7 @@ public final class IfInfo {
private final Set<BlockNode> mergedBlocks = new HashSet<BlockNode>();
private final BlockNode thenBlock;
private final BlockNode elseBlock;
private BlockNode outBlock;
@Deprecated
private BlockNode ifBlock;
......@@ -50,6 +51,14 @@ public final class IfInfo {
return elseBlock;
}
public BlockNode getOutBlock() {
return outBlock;
}
public void setOutBlock(BlockNode outBlock) {
this.outBlock = outBlock;
}
public BlockNode getIfBlock() {
return ifBlock;
}
......
......@@ -125,7 +125,9 @@ public final class LoopRegion extends AbstractRegion {
if (conditionBlock != null) {
all.add(conditionBlock);
}
all.add(body);
if (body != null) {
all.add(body);
}
return Collections.unmodifiableList(all);
}
......
......@@ -331,7 +331,7 @@ public class BlockMakerVisitor extends AbstractVisitor {
private static void markReturnBlocks(MethodNode mth) {
mth.getExitBlocks().clear();
for (BlockNode block : mth.getBasicBlocks()) {
if (BlockUtils.lastInsnType(block, InsnType.RETURN)) {
if (BlockUtils.checkLastInsnType(block, InsnType.RETURN)) {
block.add(AFlag.RETURN);
mth.getExitBlocks().add(block);
}
......@@ -399,7 +399,7 @@ public class BlockMakerVisitor extends AbstractVisitor {
if (loops.size() == 1) {
LoopInfo loop = loops.get(0);
List<Edge> edges = loop.getExitEdges();
if (edges.size() > 1) {
if (!edges.isEmpty()) {
boolean change = false;
for (Edge edge : edges) {
BlockNode target = edge.getTarget();
......@@ -414,10 +414,7 @@ public class BlockMakerVisitor extends AbstractVisitor {
}
}
}
if (splitReturn(mth)) {
return true;
}
return false;
return splitReturn(mth);
}
private static BlockNode insertBlockBetween(MethodNode mth, BlockNode source, BlockNode target) {
......@@ -439,7 +436,6 @@ public class BlockMakerVisitor extends AbstractVisitor {
BlockNode exitBlock = mth.getExitBlocks().get(0);
if (exitBlock.getPredecessors().size() > 1
&& exitBlock.getInstructions().size() == 1
&& !exitBlock.getInstructions().get(0).contains(AType.CATCH_BLOCK)
&& !exitBlock.contains(AFlag.SYNTHETIC)) {
InsnNode returnInsn = exitBlock.getInstructions().get(0);
List<BlockNode> preds = new ArrayList<BlockNode>(exitBlock.getPredecessors());
......
......@@ -8,29 +8,99 @@ import jadx.core.dex.instructions.args.InsnArg;
import jadx.core.dex.instructions.args.RegisterArg;
import jadx.core.dex.nodes.BlockNode;
import jadx.core.dex.nodes.InsnNode;
import jadx.core.dex.nodes.MethodNode;
import jadx.core.dex.regions.IfCondition;
import jadx.core.dex.regions.IfCondition.Mode;
import jadx.core.dex.regions.IfInfo;
import jadx.core.utils.BlockUtils;
import java.util.Collection;
import java.util.List;
import java.util.Set;
import static jadx.core.utils.BlockUtils.isPathExists;
public class IfMakerHelper {
private IfMakerHelper() {
}
static IfInfo makeIfInfo(BlockNode ifBlock) {
return makeIfInfo(ifBlock, IfCondition.fromIfBlock(ifBlock));
IfNode ifNode = (IfNode) ifBlock.getInstructions().get(0);
IfCondition condition = IfCondition.fromIfNode(ifNode);
IfInfo info = new IfInfo(condition, ifNode.getThenBlock(), ifNode.getElseBlock());
info.setIfBlock(ifBlock);
info.getMergedBlocks().add(ifBlock);
return info;
}
static IfInfo restructureIf(MethodNode mth, BlockNode block, IfInfo info) {
final BlockNode thenBlock = info.getThenBlock();
final BlockNode elseBlock = info.getElseBlock();
// select 'then', 'else' and 'exit' blocks
if (thenBlock.contains(AFlag.RETURN) && elseBlock.contains(AFlag.RETURN)) {
info.setOutBlock(null);
return info;
}
boolean badThen = !allPathsFromIf(thenBlock, info);
boolean badElse = !allPathsFromIf(elseBlock, info);
if (badThen && badElse) {
return null;
}
if (badThen || badElse) {
if (badElse && isPathExists(thenBlock, elseBlock)) {
info = new IfInfo(info.getCondition(), thenBlock, null);
info.setOutBlock(elseBlock);
} else if (badThen && isPathExists(elseBlock, thenBlock)) {
info = IfInfo.invert(info);
info = new IfInfo(info.getCondition(), info.getThenBlock(), null);
info.setOutBlock(thenBlock);
} else if (badElse) {
info = new IfInfo(info.getCondition(), thenBlock, null);
info.setOutBlock(null);
} else {
info = IfInfo.invert(info);
info = new IfInfo(info.getCondition(), info.getThenBlock(), null);
info.setOutBlock(null);
}
} else {
List<BlockNode> thenSC = thenBlock.getCleanSuccessors();
List<BlockNode> elseSC = elseBlock.getCleanSuccessors();
if (thenSC.size() == 1 && sameElements(thenSC, elseSC)) {
info.setOutBlock(thenSC.get(0));
} else if (info.getMergedBlocks().size() == 1
&& block.getDominatesOn().size() == 2) {
info.setOutBlock(BlockUtils.getPathCross(mth, thenBlock, elseBlock));
}
}
if (info.getOutBlock() == null) {
for (BlockNode d : block.getDominatesOn()) {
if (d != thenBlock && d != elseBlock
&& !info.getMergedBlocks().contains(d)
&& isPathExists(thenBlock, d)) {
info.setOutBlock(d);
break;
}
}
}
if (BlockUtils.isBackEdge(block, info.getOutBlock())) {
info.setOutBlock(null);
}
return info;
}
static IfInfo mergeNestedIfNodes(BlockNode block) {
IfInfo info = makeIfInfo(block);
return mergeNestedIfNodes(info);
private static boolean allPathsFromIf(BlockNode block, IfInfo info) {
List<BlockNode> preds = block.getPredecessors();
Set<BlockNode> ifBlocks = info.getMergedBlocks();
return ifBlocks.containsAll(preds);
}
private static IfInfo mergeNestedIfNodes(IfInfo currentIf) {
private static boolean sameElements(Collection<BlockNode> c1, Collection<BlockNode> c2) {
return c1.size() == c2.size() && c1.containsAll(c2);
}
static IfInfo mergeNestedIfNodes(IfInfo currentIf) {
BlockNode curThen = currentIf.getThenBlock();
BlockNode curElse = currentIf.getElseBlock();
if (curThen == curElse) {
......@@ -93,14 +163,6 @@ public class IfMakerHelper {
|| RegionMaker.isEqualPaths(currentIf.getThenBlock(), nextIf.getElseBlock());
}
private static IfInfo makeIfInfo(BlockNode ifBlock, IfCondition condition) {
IfNode ifnode = (IfNode) ifBlock.getInstructions().get(0);
IfInfo info = new IfInfo(condition, ifnode.getThenBlock(), ifnode.getElseBlock());
info.setIfBlock(ifBlock);
info.getMergedBlocks().add(ifBlock);
return info;
}
private static boolean checkConditionBranches(BlockNode from, BlockNode to) {
return from.getCleanSuccessors().size() == 1 && from.getCleanSuccessors().contains(to);
}
......
......@@ -51,6 +51,7 @@ public class IfRegionVisitor extends AbstractVisitor implements IRegionVisitor,
private static void processIfRegion(MethodNode mth, IfRegion ifRegion) {
simplifyIfCondition(ifRegion);
moveReturnToThenBlock(mth, ifRegion);
moveBreakToThenBlock(ifRegion);
markElseIfChains(ifRegion);
TernaryMod.makeTernaryInsn(mth, ifRegion);
......@@ -103,6 +104,13 @@ public class IfRegionVisitor extends AbstractVisitor implements IRegionVisitor,
}
}
private static void moveBreakToThenBlock(IfRegion ifRegion) {
if (ifRegion.getElseRegion() != null
&& RegionUtils.hasBreakInsn(ifRegion.getElseRegion())) {
invertIfRegion(ifRegion);
}
}
/**
* Mark if-else-if chains
*/
......@@ -124,7 +132,7 @@ public class IfRegionVisitor extends AbstractVisitor implements IRegionVisitor,
if (ifRegion.getElseRegion() != null
&& !ifRegion.contains(AFlag.ELSE_IF_CHAIN)
&& !ifRegion.getElseRegion().contains(AFlag.ELSE_IF_CHAIN)
&& RegionUtils.hasExitBlock(ifRegion.getThenRegion())
&& hasBranchTerminator(ifRegion)
&& insnsCount(ifRegion.getThenRegion()) < 2) {
IRegion parent = ifRegion.getParent();
Region newRegion = new Region(parent);
......@@ -138,6 +146,12 @@ public class IfRegionVisitor extends AbstractVisitor implements IRegionVisitor,
return false;
}
private static boolean hasBranchTerminator(IfRegion ifRegion) {
// TODO: check for exception throw
return RegionUtils.hasExitBlock(ifRegion.getThenRegion())
|| RegionUtils.hasBreakInsn(ifRegion.getThenRegion());
}
private static void invertIfRegion(IfRegion ifRegion) {
IContainer elseRegion = ifRegion.getElseRegion();
if (elseRegion != null) {
......
......@@ -9,7 +9,9 @@ import jadx.core.utils.exceptions.JadxRuntimeException;
import java.util.ArrayList;
import java.util.BitSet;
import java.util.Collections;
import java.util.HashSet;
import java.util.LinkedList;
import java.util.List;
import java.util.Set;
......@@ -101,13 +103,13 @@ public class BlockUtils {
return false;
}
public static boolean lastInsnType(BlockNode block, InsnType type) {
public static boolean checkLastInsnType(BlockNode block, InsnType expectedType) {
List<InsnNode> insns = block.getInstructions();
if (insns.isEmpty()) {
return false;
}
InsnNode insn = insns.get(insns.size() - 1);
return insn.getType() == type;
return insn.getType() == expectedType;
}
public static BlockNode getBlockByInsn(MethodNode mth, InsnNode insn) {
......@@ -288,4 +290,15 @@ public class BlockUtils {
}
}
}
public static List<BlockNode> buildSimplePath(BlockNode block) {
List<BlockNode> list = new LinkedList<BlockNode>();
while (block != null
&& block.getCleanSuccessors().size() < 2
&& block.getPredecessors().size() == 1) {
list.add(block);
block = getNextBlock(block);
}
return list.isEmpty() ? Collections.<BlockNode>emptyList() : list;
}
}
......@@ -2,6 +2,7 @@ package jadx.core.utils;
import jadx.core.dex.attributes.AFlag;
import jadx.core.dex.attributes.AType;
import jadx.core.dex.instructions.InsnType;
import jadx.core.dex.nodes.BlockNode;
import jadx.core.dex.nodes.IContainer;
import jadx.core.dex.nodes.IRegion;
......@@ -49,6 +50,18 @@ public class RegionUtils {
}
}
public static boolean hasBreakInsn(IContainer container) {
if (container instanceof BlockNode) {
return BlockUtils.checkLastInsnType((BlockNode) container, InsnType.BREAK);
} else if (container instanceof IRegion) {
List<IContainer> blocks = ((IRegion) container).getSubBlocks();
return !blocks.isEmpty()
&& hasBreakInsn(blocks.get(blocks.size() - 1));
} else {
throw new JadxRuntimeException("Unknown container type: " + container);
}
}
public static int insnsCount(IContainer container) {
if (container instanceof BlockNode) {
return ((BlockNode) container).getInstructions().size();
......
......@@ -5,8 +5,7 @@ import jadx.core.dex.nodes.ClassNode;
import org.junit.Test;
import static org.hamcrest.CoreMatchers.containsString;
import static org.junit.Assert.assertEquals;
import static jadx.tests.utils.JadxMatchers.containsOne;
import static org.junit.Assert.assertThat;
public class TestBreakInLoop extends InternalJadxTest {
......@@ -33,8 +32,12 @@ public class TestBreakInLoop extends InternalJadxTest {
String code = cls.getCode().toString();
System.out.println(code);
assertEquals(1, count(code, "this.f++;"));
assertThat(code, containsString("if (i < b) {"));
assertThat(code, containsString("break;"));
assertThat(code, containsOne("this.f++;"));
// assertThat(code, containsOne("a[i]++;"));
assertThat(code, containsOne("if (i < b) {"));
assertThat(code, containsOne("break;"));
assertThat(code, containsOne("i++;"));
// assertThat(code, countString(0, "else"));
}
}
......@@ -5,7 +5,7 @@ import jadx.core.dex.nodes.ClassNode;
import org.junit.Test;
import static org.hamcrest.CoreMatchers.containsString;
import static jadx.tests.utils.JadxMatchers.containsOne;
import static org.junit.Assert.assertThat;
public class TestLoopCondition extends InternalJadxTest {
......@@ -48,8 +48,12 @@ public class TestLoopCondition extends InternalJadxTest {
String code = cls.getCode().toString();
System.out.println(code);
assertThat(code, containsString("i < this.f.length()"));
assertThat(code, containsString("list.set(i, \"ABC\")"));
assertThat(code, containsString("list.set(i, \"DEF\")"));
assertThat(code, containsOne("i < this.f.length()"));
assertThat(code, containsOne("list.set(i, \"ABC\")"));
assertThat(code, containsOne("list.set(i, \"DEF\")"));
assertThat(code, containsOne("if (j == 2) {"));
assertThat(code, containsOne("setEnabled(true);"));
assertThat(code, containsOne("setEnabled(false);"));
}
}
package jadx.tests.internal.loops;
import jadx.api.InternalJadxTest;
import jadx.core.dex.nodes.ClassNode;
import org.junit.Test;
import static jadx.tests.utils.JadxMatchers.containsOne;
import static org.junit.Assert.assertThat;
public class TestLoopCondition3 extends InternalJadxTest {
public static class TestCls {
public static void test(int a, int b, int c) {
while (a < 12) {
if (b + a < 9 && b < 8) {
if (b >= 2 && a > -1 && b < 6) {
System.out.println("OK");
c = b + 1;
}
b = a;
}
c = b;
b++;
b = c;
a++;
}
}
}
@Test
public void test() {
ClassNode cls = getClassNode(TestCls.class);
String code = cls.getCode().toString();
System.out.println(code);
assertThat(code, containsOne("while (a < 12) {"));
assertThat(code, containsOne("if (b + a < 9 && b < 8) {"));
assertThat(code, containsOne("if (b >= 2 && a > -1 && b < 6) {"));
}
}
package jadx.tests.internal.loops;
import jadx.api.InternalJadxTest;
import jadx.core.dex.nodes.ClassNode;
import org.junit.Test;
import static jadx.tests.utils.JadxMatchers.containsOne;
import static org.junit.Assert.assertThat;
public class TestLoopCondition4 extends InternalJadxTest {
public static class TestCls {
public static void test() {
int n = -1;
while (n < 0) {
n += 12;
}
while (n > 11) {
n -= 12;
}
System.out.println(n);
}
}
@Test
public void test() {
ClassNode cls = getClassNode(TestCls.class);
String code = cls.getCode().toString();
System.out.println(code);
assertThat(code, containsOne("int n = -1;"));
assertThat(code, containsOne("while (n < 0) {"));
assertThat(code, containsOne("n += 12;"));
assertThat(code, containsOne("while (n > 11) {"));
assertThat(code, containsOne("n -= 12;"));
assertThat(code, containsOne("System.out.println(n);"));
}
}
package jadx.tests.internal.loops;
import jadx.api.InternalJadxTest;
import jadx.core.dex.nodes.ClassNode;
import org.junit.Test;
import static jadx.tests.utils.JadxMatchers.containsOne;
import static jadx.tests.utils.JadxMatchers.countString;
import static org.hamcrest.CoreMatchers.containsString;
import static org.hamcrest.CoreMatchers.not;
import static org.junit.Assert.assertThat;
public class TestSequentialLoops extends InternalJadxTest {
public static class TestCls {
public int test7(int a, int b) {
int c = b;
int z;
while (true) {
z = c + a;
if (z >= 7) {
break;
}
c = z;
}
while ((z = c + a) >= 7) {
c = z;
}
return c;
}
}
@Test
public void test() {
ClassNode cls = getClassNode(TestCls.class);
String code = cls.getCode().toString();
System.out.println(code);
assertThat(code, countString(2, "while ("));
assertThat(code, containsOne("break;"));
assertThat(code, containsOne("return c;"));
assertThat(code, not(containsString("else")));
}
}
......@@ -18,10 +18,10 @@ public class TestTryCatch2 extends InternalJadxTest {
synchronized (obj) {
obj.wait(5);
}
return true;
} catch (InterruptedException e) {
return false;
}
return true;
}
}
......@@ -34,12 +34,8 @@ public class TestTryCatch2 extends InternalJadxTest {
assertThat(code, containsString("try {"));
assertThat(code, containsString("synchronized (obj) {"));
assertThat(code, containsString("obj.wait(5);"));
assertThat(code, containsString("return true;"));
assertThat(code, containsString("} catch (InterruptedException e) {"));
// TODO
assertThat(code, containsString(" = false;"));
assertThat(code, containsString(" = true;"));
// assertThat(code, containsString("return false;"));
// assertThat(code, containsString("return true;"));
assertThat(code, containsString("return false;"));
}
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment