Commit 1d5368f5 authored by Skylot's avatar Skylot

core: improve out block detection in switch (issue #38)

parent 90fb95e7
......@@ -272,7 +272,7 @@ public class RegionGen extends InsnGen {
boolean addBreak = true;
if (RegionUtils.notEmpty(c)) {
makeRegionIndent(code, c);
if (!RegionUtils.hasExitEdge(c)) {
if (RegionUtils.hasExitEdge(c)) {
addBreak = false;
}
}
......
......@@ -653,38 +653,44 @@ public class RegionMaker {
assert c != null;
blocksMap.put(c, entry.getValue());
}
BitSet succ = BlockUtils.blocksToBitSet(mth, block.getSuccessors());
BitSet domsOn = BlockUtils.blocksToBitSet(mth, block.getDominatesOn());
domsOn.xor(succ); // filter 'out' block
BlockNode defCase = getBlockByOffset(insn.getDefaultCaseOffset(), block.getSuccessors());
if (defCase != null) {
blocksMap.remove(defCase);
}
LoopInfo loop = mth.getLoopForBlock(block);
BitSet outs = new BitSet(mth.getBasicBlocks().size());
outs.or(block.getDomFrontier());
for (BlockNode s : block.getCleanSuccessors()) {
outs.or(s.getDomFrontier());
}
stack.push(sw);
stack.addExits(BlockUtils.bitSetToBlocks(mth, outs));
int outCount = domsOn.cardinality();
if (outCount > 1) {
// filter 'out' block
if (outs.cardinality() > 1) {
// remove exception handlers
BlockUtils.cleanBitSet(mth, domsOn);
outCount = domsOn.cardinality();
BlockUtils.cleanBitSet(mth, outs);
}
if (outCount > 1) {
// filter successors of other blocks
if (outs.cardinality() > 1) {
// filter loop start and successors of other blocks
List<BlockNode> blocks = mth.getBasicBlocks();
for (int i = domsOn.nextSetBit(0); i >= 0; i = domsOn.nextSetBit(i + 1)) {
for (int i = outs.nextSetBit(0); i >= 0; i = outs.nextSetBit(i + 1)) {
BlockNode b = blocks.get(i);
for (BlockNode s : b.getCleanSuccessors()) {
domsOn.clear(s.getId());
if (b.contains(AFlag.LOOP_START)) {
outs.clear(b.getId());
} else {
for (BlockNode s : b.getCleanSuccessors()) {
outs.clear(s.getId());
}
}
}
outCount = domsOn.cardinality();
}
BlockNode out = null;
if (outCount == 1) {
out = mth.getBasicBlocks().get(domsOn.nextSetBit(0));
} else if (outCount == 0) {
if (loop != null && outs.cardinality() > 1) {
outs.clear(loop.getEnd().getId());
}
if (outs.cardinality() == 0) {
// one or several case blocks are empty,
// run expensive algorithm for find 'out' block
for (BlockNode maybeOut : block.getSuccessors()) {
......@@ -696,18 +702,24 @@ public class RegionMaker {
}
}
if (allReached) {
out = maybeOut;
outs.set(maybeOut.getId());
break;
}
}
}
stack.push(sw);
if (out != null) {
BlockNode out = null;
if (outs.cardinality() == 1) {
out = mth.getBasicBlocks().get(outs.nextSetBit(0));
stack.addExit(out);
} else {
LOG.warn("Can't detect out node for switch block: {} in {}",
block.toString(), mth.toString());
} else if (loop == null && outs.cardinality() > 1) {
LOG.warn("Can't detect out node for switch block: {} in {}", block, mth);
}
if (loop != null) {
// check if 'continue' must be inserted
BlockNode end = loop.getEnd();
if (out != end && out != null) {
insertContinueInSwitch(block, out, end);
}
}
if (!stack.containsExit(defCase)) {
......@@ -727,6 +739,24 @@ public class RegionMaker {
return out;
}
private static void insertContinueInSwitch(BlockNode block, BlockNode out, BlockNode end) {
int endId = end.getId();
for (BlockNode s : block.getCleanSuccessors()) {
if (s.getDomFrontier().get(endId) && s != out) {
// search predecessor of loop end on path from this successor
List<BlockNode> list = BlockUtils.collectBlocksDominatedBy(s, s);
for (BlockNode p : end.getPredecessors()) {
if (list.contains(p)) {
if (p.isSynthetic()) {
p.getInstructions().add(new InsnNode(InsnType.CONTINUE, 0));
}
break;
}
}
}
}
}
public void processTryCatchBlocks(MethodNode mth) {
Set<TryCatchBlock> tcs = new HashSet<TryCatchBlock>();
for (ExceptionHandler handler : mth.getExceptionHandlers()) {
......
package jadx.core.utils;
import jadx.core.dex.attributes.AFlag;
import jadx.core.dex.attributes.AType;
import jadx.core.dex.instructions.InsnType;
import jadx.core.dex.nodes.BlockNode;
......@@ -26,9 +25,15 @@ public class RegionUtils {
public static boolean hasExitEdge(IContainer container) {
if (container instanceof BlockNode) {
BlockNode block = (BlockNode) container;
return !block.getSuccessors().isEmpty()
&& !block.contains(AFlag.RETURN);
InsnNode lastInsn = BlockUtils.getLastInsn((BlockNode) container);
if (lastInsn == null) {
return false;
}
InsnType type = lastInsn.getType();
return type == InsnType.RETURN
|| type == InsnType.CONTINUE
|| type == InsnType.BREAK
|| type == InsnType.THROW;
} else if (container instanceof IRegion) {
IRegion region = (IRegion) container;
List<IContainer> blocks = region.getSubBlocks();
......
package jadx.tests.integration.switches;
import jadx.core.dex.nodes.ClassNode;
import jadx.tests.api.IntegrationTest;
import org.junit.Test;
import static jadx.tests.api.utils.JadxMatchers.containsOne;
import static org.hamcrest.CoreMatchers.containsString;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertThat;
public class TestSwitchBreak extends IntegrationTest {
public static class TestCls {
public String test(int a) {
String s = "";
loop:
while (a > 0) {
switch (a % 4) {
case 1:
s += "1";
break;
case 3:
case 4:
s += "4";
break;
case 5:
s += "+";
break loop;
}
s += "-";
a--;
}
return s;
}
}
@Test
public void test() {
ClassNode cls = getClassNode(TestCls.class);
String code = cls.getCode().toString();
assertThat(code, containsString("switch (a % 4) {"));
assertEquals(4, count(code, "case "));
assertEquals(3, count(code, "break;"));
// TODO finish break with label from switch
assertThat(code, containsOne("return s + \"+\";"));
}
}
package jadx.tests.integration.switches;
import jadx.core.dex.nodes.ClassNode;
import jadx.tests.api.IntegrationTest;
import org.junit.Test;
import static jadx.tests.api.utils.JadxMatchers.containsOne;
import static org.hamcrest.CoreMatchers.containsString;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertThat;
public class TestSwitchContinue extends IntegrationTest {
public static class TestCls {
public String test(int a) {
String s = "";
while (a > 0) {
switch (a % 4) {
case 1:
s += "1";
break;
case 3:
case 4:
s += "4";
break;
case 5:
a -= 2;
continue;
}
s += "-";
a--;
}
return s;
}
}
@Test
public void test() {
ClassNode cls = getClassNode(TestCls.class);
String code = cls.getCode().toString();
assertThat(code, containsString("switch (a % 4) {"));
assertEquals(4, count(code, "case "));
assertEquals(2, count(code, "break;"));
assertThat(code, containsOne("a -= 2;"));
assertThat(code, containsOne("continue;"));
}
}
package jadx.tests.integration.switches;
import jadx.core.dex.nodes.ClassNode;
import jadx.tests.api.IntegrationTest;
import org.junit.Test;
import static jadx.tests.api.utils.JadxMatchers.containsOne;
import static org.hamcrest.CoreMatchers.containsString;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertThat;
public class TestSwitchReturnFromCase extends IntegrationTest {
public static class TestCls {
public void test(int a) {
String s = null;
if (a > 1000) {
return;
}
switch (a % 4) {
case 1:
s = "1";
break;
case 2:
s = "2";
break;
case 3:
case 4:
s = "4";
break;
case 5:
return;
}
s = "5";
}
}
@Test
public void test() {
ClassNode cls = getClassNode(TestCls.class);
String code = cls.getCode().toString();
assertThat(code, containsString("switch (a % 4) {"));
assertEquals(5, count(code, "case "));
assertEquals(3, count(code, "break;"));
assertThat(code, containsOne("s = \"1\";"));
assertThat(code, containsOne("s = \"2\";"));
assertThat(code, containsOne("s = \"4\";"));
assertThat(code, containsOne("s = \"5\";"));
}
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment