Commit 02f9c25f authored by Skylot's avatar Skylot

core: support fall through cases in switch

parent 7fb39881
...@@ -29,5 +29,7 @@ public enum AFlag { ...@@ -29,5 +29,7 @@ public enum AFlag {
WRAPPED, WRAPPED,
ARITH_ONEARG, ARITH_ONEARG,
FALL_THROUGH,
INCONSISTENT_CODE, // warning about incorrect decompilation INCONSISTENT_CODE, // warning about incorrect decompilation
} }
...@@ -32,6 +32,8 @@ import jadx.core.utils.exceptions.JadxOverflowException; ...@@ -32,6 +32,8 @@ import jadx.core.utils.exceptions.JadxOverflowException;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.BitSet; import java.util.BitSet;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashSet; import java.util.HashSet;
import java.util.LinkedHashMap; import java.util.LinkedHashMap;
import java.util.List; import java.util.List;
...@@ -659,14 +661,47 @@ public class RegionMaker { ...@@ -659,14 +661,47 @@ public class RegionMaker {
} }
LoopInfo loop = mth.getLoopForBlock(block); LoopInfo loop = mth.getLoopForBlock(block);
Map<BlockNode, BlockNode> fallThroughCases = new LinkedHashMap<BlockNode, BlockNode>();
BitSet outs = new BitSet(mth.getBasicBlocks().size()); BitSet outs = new BitSet(mth.getBasicBlocks().size());
outs.or(block.getDomFrontier()); outs.or(block.getDomFrontier());
for (BlockNode s : block.getCleanSuccessors()) { for (BlockNode s : block.getCleanSuccessors()) {
outs.or(s.getDomFrontier()); BitSet df = s.getDomFrontier();
// fall through case block
if (df.cardinality() > 1) {
if (df.cardinality() > 2) {
LOG.debug("Unexpected case pattern, block: {}, mth: {}", s, mth);
} else {
BlockNode first = mth.getBasicBlocks().get(df.nextSetBit(0));
BlockNode second = mth.getBasicBlocks().get(df.nextSetBit(first.getId() + 1));
if (second.getDomFrontier().get(first.getId())) {
fallThroughCases.put(s, second);
df = new BitSet(df.size());
df.set(first.getId());
} else if (first.getDomFrontier().get(second.getId())) {
fallThroughCases.put(s, first);
df = new BitSet(df.size());
df.set(second.getId());
}
}
}
outs.or(df);
} }
stack.push(sw); stack.push(sw);
stack.addExits(BlockUtils.bitSetToBlocks(mth, outs)); stack.addExits(BlockUtils.bitSetToBlocks(mth, outs));
// check cases order if fall through case exists
if (!fallThroughCases.isEmpty()) {
if (isBadCasesOrder(blocksMap, fallThroughCases)) {
LOG.debug("Fixing incorrect switch cases order");
blocksMap = reOrderSwitchCases(blocksMap, fallThroughCases);
if (isBadCasesOrder(blocksMap, fallThroughCases)) {
LOG.error("Can't fix incorrect switch cases order, method: {}", mth);
mth.add(AFlag.INCONSISTENT_CODE);
}
}
}
// filter 'out' block // filter 'out' block
if (outs.cardinality() > 1) { if (outs.cardinality() > 1) {
// remove exception handlers // remove exception handlers
...@@ -677,6 +712,7 @@ public class RegionMaker { ...@@ -677,6 +712,7 @@ public class RegionMaker {
List<BlockNode> blocks = mth.getBasicBlocks(); List<BlockNode> blocks = mth.getBasicBlocks();
for (int i = outs.nextSetBit(0); i >= 0; i = outs.nextSetBit(i + 1)) { for (int i = outs.nextSetBit(0); i >= 0; i = outs.nextSetBit(i + 1)) {
BlockNode b = blocks.get(i); BlockNode b = blocks.get(i);
outs.andNot(b.getDomFrontier());
if (b.contains(AFlag.LOOP_START)) { if (b.contains(AFlag.LOOP_START)) {
outs.clear(b.getId()); outs.clear(b.getId());
} else { } else {
...@@ -726,12 +762,21 @@ public class RegionMaker { ...@@ -726,12 +762,21 @@ public class RegionMaker {
sw.setDefaultCase(makeRegion(defCase, stack)); sw.setDefaultCase(makeRegion(defCase, stack));
} }
for (Entry<BlockNode, List<Object>> entry : blocksMap.entrySet()) { for (Entry<BlockNode, List<Object>> entry : blocksMap.entrySet()) {
BlockNode c = entry.getKey(); BlockNode caseBlock = entry.getKey();
if (stack.containsExit(c)) { if (stack.containsExit(caseBlock)) {
// empty case block // empty case block
sw.addCase(entry.getValue(), new Region(stack.peekRegion())); sw.addCase(entry.getValue(), new Region(stack.peekRegion()));
} else { } else {
sw.addCase(entry.getValue(), makeRegion(c, stack)); BlockNode next = fallThroughCases.get(caseBlock);
stack.addExit(next);
Region caseRegion = makeRegion(caseBlock, stack);
stack.removeExit(next);
if (next != null) {
next.add(AFlag.FALL_THROUGH);
caseRegion.add(AFlag.FALL_THROUGH);
}
sw.addCase(entry.getValue(), caseRegion);
// 'break' instruction will be inserted in RegionMakerVisitor.PostRegionVisitor
} }
} }
...@@ -739,6 +784,44 @@ public class RegionMaker { ...@@ -739,6 +784,44 @@ public class RegionMaker {
return out; return out;
} }
private boolean isBadCasesOrder(final Map<BlockNode, List<Object>> blocksMap,
final Map<BlockNode, BlockNode> fallThroughCases) {
BlockNode nextCaseBlock = null;
for (BlockNode caseBlock : blocksMap.keySet()) {
if (nextCaseBlock != null && !caseBlock.equals(nextCaseBlock)) {
return true;
}
nextCaseBlock = fallThroughCases.get(caseBlock);
}
return nextCaseBlock != null;
}
private Map<BlockNode, List<Object>> reOrderSwitchCases(Map<BlockNode, List<Object>> blocksMap,
final Map<BlockNode, BlockNode> fallThroughCases) {
List<BlockNode> list = new ArrayList<BlockNode>(blocksMap.size());
list.addAll(blocksMap.keySet());
Collections.sort(list, new Comparator<BlockNode>() {
@Override
public int compare(BlockNode a, BlockNode b) {
BlockNode nextA = fallThroughCases.get(a);
if (nextA != null) {
if (b.equals(nextA)) {
return -1;
}
} else if (a.equals(fallThroughCases.get(b))) {
return 1;
}
return 0;
}
});
Map<BlockNode, List<Object>> newBlocksMap = new LinkedHashMap<BlockNode, List<Object>>(blocksMap.size());
for (BlockNode key : list) {
newBlocksMap.put(key, blocksMap.get(key));
}
return newBlocksMap;
}
private static void insertContinueInSwitch(BlockNode block, BlockNode out, BlockNode end) { private static void insertContinueInSwitch(BlockNode block, BlockNode out, BlockNode end) {
int endId = end.getId(); int endId = end.getId();
for (BlockNode s : block.getCleanSuccessors()) { for (BlockNode s : block.getCleanSuccessors()) {
......
package jadx.core.dex.visitors.regions; package jadx.core.dex.visitors.regions;
import jadx.core.dex.attributes.AFlag;
import jadx.core.dex.instructions.InsnType; import jadx.core.dex.instructions.InsnType;
import jadx.core.dex.nodes.BlockNode;
import jadx.core.dex.nodes.IBlock;
import jadx.core.dex.nodes.IContainer; import jadx.core.dex.nodes.IContainer;
import jadx.core.dex.nodes.IRegion; import jadx.core.dex.nodes.IRegion;
import jadx.core.dex.nodes.InsnContainer; import jadx.core.dex.nodes.InsnContainer;
...@@ -16,7 +19,9 @@ import jadx.core.utils.RegionUtils; ...@@ -16,7 +19,9 @@ import jadx.core.utils.RegionUtils;
import jadx.core.utils.exceptions.JadxException; import jadx.core.utils.exceptions.JadxException;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.HashSet;
import java.util.List; import java.util.List;
import java.util.Set;
import org.slf4j.Logger; import org.slf4j.Logger;
import org.slf4j.LoggerFactory; import org.slf4j.LoggerFactory;
...@@ -62,25 +67,66 @@ public class RegionMakerVisitor extends AbstractVisitor { ...@@ -62,25 +67,66 @@ public class RegionMakerVisitor extends AbstractVisitor {
private static final class PostRegionVisitor extends AbstractRegionVisitor { private static final class PostRegionVisitor extends AbstractRegionVisitor {
@Override @Override
public void enterRegion(MethodNode mth, IRegion region) { public void leaveRegion(MethodNode mth, IRegion region) {
if (region instanceof LoopRegion) { if (region instanceof LoopRegion) {
// merge conditions in loops // merge conditions in loops
LoopRegion loop = (LoopRegion) region; LoopRegion loop = (LoopRegion) region;
loop.mergePreCondition(); loop.mergePreCondition();
} else if (region instanceof SwitchRegion) { } else if (region instanceof SwitchRegion) {
// insert 'break' in switch cases (run after try/catch insertion) // insert 'break' in switch cases (run after try/catch insertion)
SwitchRegion sw = (SwitchRegion) region; processSwitch(mth, (SwitchRegion) region);
}
}
}
private static void processSwitch(MethodNode mth, SwitchRegion sw) {
for (IContainer c : sw.getBranches()) { for (IContainer c : sw.getBranches()) {
if (c instanceof Region && !RegionUtils.hasExitEdge(c)) { if (!(c instanceof Region)) {
List<InsnNode> insns = new ArrayList<InsnNode>(1); continue;
insns.add(new InsnNode(InsnType.BREAK, 0)); }
((Region) c).add(new InsnContainer(insns)); Set<IBlock> blocks = new HashSet<IBlock>();
RegionUtils.getAllRegionBlocks(c, blocks);
if (blocks.isEmpty()) {
addBreakToContainer((Region) c);
continue;
}
for (IBlock block : blocks) {
if (!(block instanceof BlockNode)) {
continue;
}
BlockNode bn = (BlockNode) block;
for (BlockNode s : bn.getCleanSuccessors()) {
if (!blocks.contains(s)
&& !bn.contains(AFlag.SKIP)
&& !s.contains(AFlag.FALL_THROUGH)) {
addBreak(mth, c, bn);
break;
} }
} }
} }
} }
} }
private static void addBreak(MethodNode mth, IContainer c, BlockNode bn) {
IContainer blockContainer = RegionUtils.getBlockContainer(c, bn);
if (blockContainer instanceof Region) {
addBreakToContainer((Region) blockContainer);
} else if (c instanceof Region) {
addBreakToContainer((Region) c);
} else {
LOG.warn("Can't insert break, container: {}, block: {}, mth: {}", blockContainer, bn, mth);
}
}
private static void addBreakToContainer(Region c) {
if (RegionUtils.hasExitEdge(c)) {
return;
}
List<InsnNode> insns = new ArrayList<InsnNode>(1);
insns.add(new InsnNode(InsnType.BREAK, 0));
c.add(new InsnContainer(insns));
}
private static void removeSynchronized(MethodNode mth) { private static void removeSynchronized(MethodNode mth) {
Region startRegion = mth.getRegion(); Region startRegion = mth.getRegion();
List<IContainer> subBlocks = startRegion.getSubBlocks(); List<IContainer> subBlocks = startRegion.getSubBlocks();
......
...@@ -95,6 +95,12 @@ final class RegionStack { ...@@ -95,6 +95,12 @@ final class RegionStack {
} }
} }
public void removeExit(BlockNode exit) {
if (exit != null) {
curState.exits.remove(exit);
}
}
public boolean containsExit(BlockNode exit) { public boolean containsExit(BlockNode exit) {
return curState.exits.contains(exit); return curState.exits.contains(exit);
} }
......
...@@ -8,8 +8,6 @@ import jadx.core.dex.nodes.IBranchRegion; ...@@ -8,8 +8,6 @@ import jadx.core.dex.nodes.IBranchRegion;
import jadx.core.dex.nodes.IContainer; import jadx.core.dex.nodes.IContainer;
import jadx.core.dex.nodes.IRegion; import jadx.core.dex.nodes.IRegion;
import jadx.core.dex.nodes.InsnNode; import jadx.core.dex.nodes.InsnNode;
import jadx.core.dex.regions.SwitchRegion;
import jadx.core.dex.regions.conditions.IfRegion;
import jadx.core.dex.trycatch.CatchAttr; import jadx.core.dex.trycatch.CatchAttr;
import jadx.core.dex.trycatch.ExceptionHandler; import jadx.core.dex.trycatch.ExceptionHandler;
import jadx.core.dex.trycatch.TryCatchBlock; import jadx.core.dex.trycatch.TryCatchBlock;
...@@ -60,8 +58,7 @@ public class RegionUtils { ...@@ -60,8 +58,7 @@ public class RegionUtils {
return null; return null;
} }
return insnList.get(insnList.size() - 1); return insnList.get(insnList.size() - 1);
} else if (container instanceof IfRegion } else if (container instanceof IBranchRegion) {
|| container instanceof SwitchRegion) {
return null; return null;
} else if (container instanceof IRegion) { } else if (container instanceof IRegion) {
IRegion region = (IRegion) container; IRegion region = (IRegion) container;
...@@ -235,6 +232,23 @@ public class RegionUtils { ...@@ -235,6 +232,23 @@ public class RegionUtils {
return true; return true;
} }
public static IContainer getBlockContainer(IContainer container, BlockNode block) {
if (container instanceof IBlock) {
return container == block ? container : null;
} else if (container instanceof IRegion) {
IRegion region = (IRegion) container;
for (IContainer c : region.getSubBlocks()) {
IContainer res = getBlockContainer(c, block);
if (res != null) {
return res instanceof IBlock ? region : res;
}
}
return null;
} else {
throw new JadxRuntimeException("Unknown container type: " + container.getClass());
}
}
public static boolean isDominatedBy(BlockNode dom, IContainer cont) { public static boolean isDominatedBy(BlockNode dom, IContainer cont) {
if (dom == cont) { if (dom == cont) {
return true; return true;
......
...@@ -60,9 +60,11 @@ public class TestSwitch2 extends IntegrationTest { ...@@ -60,9 +60,11 @@ public class TestSwitch2 extends IntegrationTest {
ClassNode cls = getClassNode(TestCls.class); ClassNode cls = getClassNode(TestCls.class);
String code = cls.getCode().toString(); String code = cls.getCode().toString();
assertThat(code, countString(4, "break;")); // assertThat(code, countString(4, "break;"));
// assertThat(code, countString(2, "return;"));
// TODO: remove redundant returns // TODO: remove redundant break and returns
// assertThat(code, countString(2, "return;")); assertThat(code, countString(5, "break;"));
assertThat(code, countString(4, "return;"));
} }
} }
package jadx.tests.integration.switches;
import jadx.core.dex.nodes.ClassNode;
import jadx.tests.api.IntegrationTest;
import org.junit.Test;
import static jadx.tests.api.utils.JadxMatchers.containsOne;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertThat;
public class TestSwitchWithFallThroughCase extends IntegrationTest {
public static class TestCls {
public String test(int a, boolean b, boolean c) {
String str = "";
switch (a % 4) {
case 1:
str += ">";
if (a == 5 && b) {
if (c) {
str += "1";
} else {
str += "!c";
}
break;
}
case 2:
if (b) {
str += "2";
}
break;
case 3:
break;
default:
str += "default";
break;
}
str += ";";
return str;
}
public void check() {
assertEquals(">1;", test(5, true, true));
assertEquals(">2;", test(1, true, true));
assertEquals(";", test(3, true, true));
assertEquals("default;", test(0, true, true));
}
}
@Test
public void test() {
ClassNode cls = getClassNode(TestCls.class);
String code = cls.getCode().toString();
assertThat(code, containsOne("switch (a % 4) {"));
assertThat(code, containsOne("if (a == 5 && b) {"));
assertThat(code, containsOne("if (b) {"));
}
}
package jadx.tests.integration.switches;
import jadx.core.dex.nodes.ClassNode;
import jadx.tests.api.IntegrationTest;
import org.junit.Test;
import static jadx.tests.api.utils.JadxMatchers.containsOne;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertThat;
public class TestSwitchWithFallThroughCase2 extends IntegrationTest {
public static class TestCls {
public String test(int a, boolean b, boolean c) {
String str = "";
if (a > 0) {
switch (a % 4) {
case 1:
str += ">";
if (a == 5 && b) {
if (c) {
str += "1";
} else {
str += "!c";
}
break;
}
case 2:
if (b) {
str += "2";
}
break;
case 3:
break;
default:
str += "default";
break;
}
str += "+";
}
if (b && c) {
str += "-";
}
return str;
}
public void check() {
assertEquals(">1+-", test(5, true, true));
assertEquals(">2+-", test(1, true, true));
assertEquals("+-", test(3, true, true));
assertEquals("default+-", test(16, true, true));
assertEquals("-", test(-1, true, true));
}
}
@Test
public void test() {
setOutputCFG();
ClassNode cls = getClassNode(TestCls.class);
String code = cls.getCode().toString();
assertThat(code, containsOne("switch (a % 4) {"));
assertThat(code, containsOne("if (a == 5 && b) {"));
assertThat(code, containsOne("if (b) {"));
}
}
...@@ -62,7 +62,10 @@ public class TestSwitchWithTryCatch extends IntegrationTest { ...@@ -62,7 +62,10 @@ public class TestSwitchWithTryCatch extends IntegrationTest {
ClassNode cls = getClassNode(TestCls.class); ClassNode cls = getClassNode(TestCls.class);
String code = cls.getCode().toString(); String code = cls.getCode().toString();
assertThat(code, countString(3, "break;")); // assertThat(code, countString(3, "break;"));
assertThat(code, countString(4, "return;")); assertThat(code, countString(4, "return;"));
// TODO: remove redundant break
assertThat(code, countString(4, "break;"));
} }
} }
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment