Commit 02f9c25f authored by Skylot's avatar Skylot

core: support fall through cases in switch

parent 7fb39881
......@@ -29,5 +29,7 @@ public enum AFlag {
WRAPPED,
ARITH_ONEARG,
FALL_THROUGH,
INCONSISTENT_CODE, // warning about incorrect decompilation
}
......@@ -32,6 +32,8 @@ import jadx.core.utils.exceptions.JadxOverflowException;
import java.util.ArrayList;
import java.util.BitSet;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashSet;
import java.util.LinkedHashMap;
import java.util.List;
......@@ -659,14 +661,47 @@ public class RegionMaker {
}
LoopInfo loop = mth.getLoopForBlock(block);
Map<BlockNode, BlockNode> fallThroughCases = new LinkedHashMap<BlockNode, BlockNode>();
BitSet outs = new BitSet(mth.getBasicBlocks().size());
outs.or(block.getDomFrontier());
for (BlockNode s : block.getCleanSuccessors()) {
outs.or(s.getDomFrontier());
BitSet df = s.getDomFrontier();
// fall through case block
if (df.cardinality() > 1) {
if (df.cardinality() > 2) {
LOG.debug("Unexpected case pattern, block: {}, mth: {}", s, mth);
} else {
BlockNode first = mth.getBasicBlocks().get(df.nextSetBit(0));
BlockNode second = mth.getBasicBlocks().get(df.nextSetBit(first.getId() + 1));
if (second.getDomFrontier().get(first.getId())) {
fallThroughCases.put(s, second);
df = new BitSet(df.size());
df.set(first.getId());
} else if (first.getDomFrontier().get(second.getId())) {
fallThroughCases.put(s, first);
df = new BitSet(df.size());
df.set(second.getId());
}
}
}
outs.or(df);
}
stack.push(sw);
stack.addExits(BlockUtils.bitSetToBlocks(mth, outs));
// check cases order if fall through case exists
if (!fallThroughCases.isEmpty()) {
if (isBadCasesOrder(blocksMap, fallThroughCases)) {
LOG.debug("Fixing incorrect switch cases order");
blocksMap = reOrderSwitchCases(blocksMap, fallThroughCases);
if (isBadCasesOrder(blocksMap, fallThroughCases)) {
LOG.error("Can't fix incorrect switch cases order, method: {}", mth);
mth.add(AFlag.INCONSISTENT_CODE);
}
}
}
// filter 'out' block
if (outs.cardinality() > 1) {
// remove exception handlers
......@@ -677,6 +712,7 @@ public class RegionMaker {
List<BlockNode> blocks = mth.getBasicBlocks();
for (int i = outs.nextSetBit(0); i >= 0; i = outs.nextSetBit(i + 1)) {
BlockNode b = blocks.get(i);
outs.andNot(b.getDomFrontier());
if (b.contains(AFlag.LOOP_START)) {
outs.clear(b.getId());
} else {
......@@ -726,12 +762,21 @@ public class RegionMaker {
sw.setDefaultCase(makeRegion(defCase, stack));
}
for (Entry<BlockNode, List<Object>> entry : blocksMap.entrySet()) {
BlockNode c = entry.getKey();
if (stack.containsExit(c)) {
BlockNode caseBlock = entry.getKey();
if (stack.containsExit(caseBlock)) {
// empty case block
sw.addCase(entry.getValue(), new Region(stack.peekRegion()));
} else {
sw.addCase(entry.getValue(), makeRegion(c, stack));
BlockNode next = fallThroughCases.get(caseBlock);
stack.addExit(next);
Region caseRegion = makeRegion(caseBlock, stack);
stack.removeExit(next);
if (next != null) {
next.add(AFlag.FALL_THROUGH);
caseRegion.add(AFlag.FALL_THROUGH);
}
sw.addCase(entry.getValue(), caseRegion);
// 'break' instruction will be inserted in RegionMakerVisitor.PostRegionVisitor
}
}
......@@ -739,6 +784,44 @@ public class RegionMaker {
return out;
}
private boolean isBadCasesOrder(final Map<BlockNode, List<Object>> blocksMap,
final Map<BlockNode, BlockNode> fallThroughCases) {
BlockNode nextCaseBlock = null;
for (BlockNode caseBlock : blocksMap.keySet()) {
if (nextCaseBlock != null && !caseBlock.equals(nextCaseBlock)) {
return true;
}
nextCaseBlock = fallThroughCases.get(caseBlock);
}
return nextCaseBlock != null;
}
private Map<BlockNode, List<Object>> reOrderSwitchCases(Map<BlockNode, List<Object>> blocksMap,
final Map<BlockNode, BlockNode> fallThroughCases) {
List<BlockNode> list = new ArrayList<BlockNode>(blocksMap.size());
list.addAll(blocksMap.keySet());
Collections.sort(list, new Comparator<BlockNode>() {
@Override
public int compare(BlockNode a, BlockNode b) {
BlockNode nextA = fallThroughCases.get(a);
if (nextA != null) {
if (b.equals(nextA)) {
return -1;
}
} else if (a.equals(fallThroughCases.get(b))) {
return 1;
}
return 0;
}
});
Map<BlockNode, List<Object>> newBlocksMap = new LinkedHashMap<BlockNode, List<Object>>(blocksMap.size());
for (BlockNode key : list) {
newBlocksMap.put(key, blocksMap.get(key));
}
return newBlocksMap;
}
private static void insertContinueInSwitch(BlockNode block, BlockNode out, BlockNode end) {
int endId = end.getId();
for (BlockNode s : block.getCleanSuccessors()) {
......
package jadx.core.dex.visitors.regions;
import jadx.core.dex.attributes.AFlag;
import jadx.core.dex.instructions.InsnType;
import jadx.core.dex.nodes.BlockNode;
import jadx.core.dex.nodes.IBlock;
import jadx.core.dex.nodes.IContainer;
import jadx.core.dex.nodes.IRegion;
import jadx.core.dex.nodes.InsnContainer;
......@@ -16,7 +19,9 @@ import jadx.core.utils.RegionUtils;
import jadx.core.utils.exceptions.JadxException;
import java.util.ArrayList;
import java.util.HashSet;
import java.util.List;
import java.util.Set;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
......@@ -62,25 +67,66 @@ public class RegionMakerVisitor extends AbstractVisitor {
private static final class PostRegionVisitor extends AbstractRegionVisitor {
@Override
public void enterRegion(MethodNode mth, IRegion region) {
public void leaveRegion(MethodNode mth, IRegion region) {
if (region instanceof LoopRegion) {
// merge conditions in loops
LoopRegion loop = (LoopRegion) region;
loop.mergePreCondition();
} else if (region instanceof SwitchRegion) {
// insert 'break' in switch cases (run after try/catch insertion)
SwitchRegion sw = (SwitchRegion) region;
processSwitch(mth, (SwitchRegion) region);
}
}
}
private static void processSwitch(MethodNode mth, SwitchRegion sw) {
for (IContainer c : sw.getBranches()) {
if (c instanceof Region && !RegionUtils.hasExitEdge(c)) {
List<InsnNode> insns = new ArrayList<InsnNode>(1);
insns.add(new InsnNode(InsnType.BREAK, 0));
((Region) c).add(new InsnContainer(insns));
if (!(c instanceof Region)) {
continue;
}
Set<IBlock> blocks = new HashSet<IBlock>();
RegionUtils.getAllRegionBlocks(c, blocks);
if (blocks.isEmpty()) {
addBreakToContainer((Region) c);
continue;
}
for (IBlock block : blocks) {
if (!(block instanceof BlockNode)) {
continue;
}
BlockNode bn = (BlockNode) block;
for (BlockNode s : bn.getCleanSuccessors()) {
if (!blocks.contains(s)
&& !bn.contains(AFlag.SKIP)
&& !s.contains(AFlag.FALL_THROUGH)) {
addBreak(mth, c, bn);
break;
}
}
}
}
}
private static void addBreak(MethodNode mth, IContainer c, BlockNode bn) {
IContainer blockContainer = RegionUtils.getBlockContainer(c, bn);
if (blockContainer instanceof Region) {
addBreakToContainer((Region) blockContainer);
} else if (c instanceof Region) {
addBreakToContainer((Region) c);
} else {
LOG.warn("Can't insert break, container: {}, block: {}, mth: {}", blockContainer, bn, mth);
}
}
private static void addBreakToContainer(Region c) {
if (RegionUtils.hasExitEdge(c)) {
return;
}
List<InsnNode> insns = new ArrayList<InsnNode>(1);
insns.add(new InsnNode(InsnType.BREAK, 0));
c.add(new InsnContainer(insns));
}
private static void removeSynchronized(MethodNode mth) {
Region startRegion = mth.getRegion();
List<IContainer> subBlocks = startRegion.getSubBlocks();
......
......@@ -95,6 +95,12 @@ final class RegionStack {
}
}
public void removeExit(BlockNode exit) {
if (exit != null) {
curState.exits.remove(exit);
}
}
public boolean containsExit(BlockNode exit) {
return curState.exits.contains(exit);
}
......
......@@ -8,8 +8,6 @@ import jadx.core.dex.nodes.IBranchRegion;
import jadx.core.dex.nodes.IContainer;
import jadx.core.dex.nodes.IRegion;
import jadx.core.dex.nodes.InsnNode;
import jadx.core.dex.regions.SwitchRegion;
import jadx.core.dex.regions.conditions.IfRegion;
import jadx.core.dex.trycatch.CatchAttr;
import jadx.core.dex.trycatch.ExceptionHandler;
import jadx.core.dex.trycatch.TryCatchBlock;
......@@ -60,8 +58,7 @@ public class RegionUtils {
return null;
}
return insnList.get(insnList.size() - 1);
} else if (container instanceof IfRegion
|| container instanceof SwitchRegion) {
} else if (container instanceof IBranchRegion) {
return null;
} else if (container instanceof IRegion) {
IRegion region = (IRegion) container;
......@@ -235,6 +232,23 @@ public class RegionUtils {
return true;
}
public static IContainer getBlockContainer(IContainer container, BlockNode block) {
if (container instanceof IBlock) {
return container == block ? container : null;
} else if (container instanceof IRegion) {
IRegion region = (IRegion) container;
for (IContainer c : region.getSubBlocks()) {
IContainer res = getBlockContainer(c, block);
if (res != null) {
return res instanceof IBlock ? region : res;
}
}
return null;
} else {
throw new JadxRuntimeException("Unknown container type: " + container.getClass());
}
}
public static boolean isDominatedBy(BlockNode dom, IContainer cont) {
if (dom == cont) {
return true;
......
......@@ -60,9 +60,11 @@ public class TestSwitch2 extends IntegrationTest {
ClassNode cls = getClassNode(TestCls.class);
String code = cls.getCode().toString();
assertThat(code, countString(4, "break;"));
// assertThat(code, countString(4, "break;"));
// assertThat(code, countString(2, "return;"));
// TODO: remove redundant returns
// assertThat(code, countString(2, "return;"));
// TODO: remove redundant break and returns
assertThat(code, countString(5, "break;"));
assertThat(code, countString(4, "return;"));
}
}
package jadx.tests.integration.switches;
import jadx.core.dex.nodes.ClassNode;
import jadx.tests.api.IntegrationTest;
import org.junit.Test;
import static jadx.tests.api.utils.JadxMatchers.containsOne;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertThat;
public class TestSwitchWithFallThroughCase extends IntegrationTest {
public static class TestCls {
public String test(int a, boolean b, boolean c) {
String str = "";
switch (a % 4) {
case 1:
str += ">";
if (a == 5 && b) {
if (c) {
str += "1";
} else {
str += "!c";
}
break;
}
case 2:
if (b) {
str += "2";
}
break;
case 3:
break;
default:
str += "default";
break;
}
str += ";";
return str;
}
public void check() {
assertEquals(">1;", test(5, true, true));
assertEquals(">2;", test(1, true, true));
assertEquals(";", test(3, true, true));
assertEquals("default;", test(0, true, true));
}
}
@Test
public void test() {
ClassNode cls = getClassNode(TestCls.class);
String code = cls.getCode().toString();
assertThat(code, containsOne("switch (a % 4) {"));
assertThat(code, containsOne("if (a == 5 && b) {"));
assertThat(code, containsOne("if (b) {"));
}
}
package jadx.tests.integration.switches;
import jadx.core.dex.nodes.ClassNode;
import jadx.tests.api.IntegrationTest;
import org.junit.Test;
import static jadx.tests.api.utils.JadxMatchers.containsOne;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertThat;
public class TestSwitchWithFallThroughCase2 extends IntegrationTest {
public static class TestCls {
public String test(int a, boolean b, boolean c) {
String str = "";
if (a > 0) {
switch (a % 4) {
case 1:
str += ">";
if (a == 5 && b) {
if (c) {
str += "1";
} else {
str += "!c";
}
break;
}
case 2:
if (b) {
str += "2";
}
break;
case 3:
break;
default:
str += "default";
break;
}
str += "+";
}
if (b && c) {
str += "-";
}
return str;
}
public void check() {
assertEquals(">1+-", test(5, true, true));
assertEquals(">2+-", test(1, true, true));
assertEquals("+-", test(3, true, true));
assertEquals("default+-", test(16, true, true));
assertEquals("-", test(-1, true, true));
}
}
@Test
public void test() {
setOutputCFG();
ClassNode cls = getClassNode(TestCls.class);
String code = cls.getCode().toString();
assertThat(code, containsOne("switch (a % 4) {"));
assertThat(code, containsOne("if (a == 5 && b) {"));
assertThat(code, containsOne("if (b) {"));
}
}
......@@ -62,7 +62,10 @@ public class TestSwitchWithTryCatch extends IntegrationTest {
ClassNode cls = getClassNode(TestCls.class);
String code = cls.getCode().toString();
assertThat(code, countString(3, "break;"));
// assertThat(code, countString(3, "break;"));
assertThat(code, countString(4, "return;"));
// TODO: remove redundant break
assertThat(code, countString(4, "break;"));
}
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment