Commit 811b0e7f authored by Skylot's avatar Skylot

core: fix 'break' insertion for switch/case blocks (fix #41)

parent 08ea61f4
......@@ -257,30 +257,17 @@ public class RegionGen extends InsnGen {
}
code.add(':');
}
makeCaseBlock(c, code);
makeRegionIndent(code, c);
}
if (sw.getDefaultCase() != null) {
code.startLine("default:");
makeCaseBlock(sw.getDefaultCase(), code);
makeRegionIndent(code, sw.getDefaultCase());
}
code.decIndent();
code.startLine('}');
return code;
}
private void makeCaseBlock(IContainer c, CodeWriter code) throws CodegenException {
boolean addBreak = true;
if (RegionUtils.notEmpty(c)) {
makeRegionIndent(code, c);
if (RegionUtils.hasExitEdge(c)) {
addBreak = false;
}
}
if (addBreak) {
code.startLine().addIndent().add("break;");
}
}
private void makeTryCatch(TryCatchRegion region, CodeWriter code) throws CodegenException {
code.startLine("try {");
makeRegionIndent(code, region.getTryRegion());
......
package jadx.core.dex.nodes;
import java.util.List;
public interface IBranchRegion extends IRegion {
/**
* Return list of branches in this region.
* NOTE: Contains 'null' elements for indicate empty branches.
*/
List<IContainer> getBranches();
}
package jadx.core.dex.regions;
import jadx.core.dex.nodes.BlockNode;
import jadx.core.dex.nodes.IBranchRegion;
import jadx.core.dex.nodes.IContainer;
import jadx.core.dex.nodes.IRegion;
......@@ -8,7 +9,7 @@ import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
public final class SwitchRegion extends AbstractRegion {
public final class SwitchRegion extends AbstractRegion implements IBranchRegion {
private final BlockNode header;
......@@ -60,6 +61,14 @@ public final class SwitchRegion extends AbstractRegion {
}
@Override
public List<IContainer> getBranches() {
List<IContainer> branches = new ArrayList<IContainer>(cases.size() + 1);
branches.addAll(cases);
branches.add(defCase);
return Collections.unmodifiableList(branches);
}
@Override
public String baseString() {
return header.baseString();
}
......
package jadx.core.dex.regions;
import jadx.core.dex.nodes.IBranchRegion;
import jadx.core.dex.nodes.IContainer;
import jadx.core.dex.nodes.IRegion;
import jadx.core.dex.trycatch.ExceptionHandler;
......@@ -12,7 +13,7 @@ import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
public final class TryCatchRegion extends AbstractRegion {
public final class TryCatchRegion extends AbstractRegion implements IBranchRegion {
private final IContainer tryRegion;
private Map<ExceptionHandler, IContainer> catchRegions = Collections.emptyMap();
......@@ -72,6 +73,11 @@ public final class TryCatchRegion extends AbstractRegion {
}
@Override
public List<IContainer> getBranches() {
return getSubBlocks();
}
@Override
public String baseString() {
return tryRegion.baseString();
}
......
package jadx.core.dex.regions.conditions;
import jadx.core.dex.nodes.BlockNode;
import jadx.core.dex.nodes.IBranchRegion;
import jadx.core.dex.nodes.IContainer;
import jadx.core.dex.nodes.IRegion;
import jadx.core.dex.regions.AbstractRegion;
......@@ -9,7 +10,7 @@ import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
public final class IfRegion extends AbstractRegion {
public final class IfRegion extends AbstractRegion implements IBranchRegion {
private final BlockNode header;
......@@ -90,6 +91,14 @@ public final class IfRegion extends AbstractRegion {
}
@Override
public List<IContainer> getBranches() {
List<IContainer> branches = new ArrayList<IContainer>(2);
branches.add(thenRegion);
branches.add(elseRegion);
return Collections.unmodifiableList(branches);
}
@Override
public boolean replaceSubBlock(IContainer oldBlock, IContainer newBlock) {
if (oldBlock == thenRegion) {
thenRegion = newBlock;
......
......@@ -6,6 +6,7 @@ import jadx.core.dex.attributes.IAttributeNode;
import jadx.core.dex.instructions.IfNode;
import jadx.core.dex.instructions.InsnType;
import jadx.core.dex.nodes.BlockNode;
import jadx.core.dex.nodes.IBlock;
import jadx.core.dex.nodes.IContainer;
import jadx.core.dex.nodes.IRegion;
import jadx.core.dex.nodes.InsnNode;
......@@ -113,7 +114,7 @@ public class DotGraphVisitor extends AbstractVisitor {
processRegion(mth, h.getHandlerRegion());
}
}
Set<BlockNode> regionsBlocks = new HashSet<BlockNode>(mth.getBasicBlocks().size());
Set<IBlock> regionsBlocks = new HashSet<IBlock>(mth.getBasicBlocks().size());
RegionUtils.getAllRegionBlocks(mth.getRegion(), regionsBlocks);
for (ExceptionHandler handler : mth.getExceptionHandlers()) {
IContainer handlerRegion = handler.getHandlerRegion();
......@@ -147,6 +148,8 @@ public class DotGraphVisitor extends AbstractVisitor {
dot.startLine('}');
} else if (region instanceof BlockNode) {
processBlock(mth, (BlockNode) region, false);
} else if (region instanceof IBlock) {
processIBlock(mth, (IBlock) region, false);
}
}
......@@ -189,6 +192,24 @@ public class DotGraphVisitor extends AbstractVisitor {
}
}
private void processIBlock(MethodNode mth, IBlock block, boolean error) {
String attrs = attributesString(block);
dot.startLine(makeName(block));
dot.add(" [shape=record,");
if (error) {
dot.add("color=red,");
}
dot.add("label=\"{");
if (attrs.length() != 0) {
dot.add(attrs);
}
String insns = insertInsns(mth, block);
if (insns.length() != 0) {
dot.add('|').add(insns);
}
dot.add("}\"];");
}
private void addEdge(BlockNode from, BlockNode to, String style) {
conn.startLine(makeName(from)).add(" -> ").add(makeName(to));
conn.add(style);
......@@ -207,13 +228,15 @@ public class DotGraphVisitor extends AbstractVisitor {
String name;
if (c instanceof BlockNode) {
name = "Node_" + ((BlockNode) c).getId();
} else if (c instanceof IBlock) {
name = "Node_" + c.getClass().getSimpleName() + "_" + c.hashCode();
} else {
name = "cluster_" + c.getClass().getSimpleName() + "_" + c.hashCode();
}
return name;
}
private String insertInsns(MethodNode mth, BlockNode block) {
private String insertInsns(MethodNode mth, IBlock block) {
if (rawInsn) {
StringBuilder str = new StringBuilder();
for (InsnNode insn : block.getInstructions()) {
......
......@@ -9,12 +9,11 @@ import jadx.core.dex.instructions.args.ArgType;
import jadx.core.dex.instructions.args.RegisterArg;
import jadx.core.dex.instructions.args.VarName;
import jadx.core.dex.nodes.IBlock;
import jadx.core.dex.nodes.IBranchRegion;
import jadx.core.dex.nodes.IContainer;
import jadx.core.dex.nodes.IRegion;
import jadx.core.dex.nodes.InsnNode;
import jadx.core.dex.nodes.MethodNode;
import jadx.core.dex.regions.SwitchRegion;
import jadx.core.dex.regions.conditions.IfRegion;
import jadx.core.dex.regions.loops.ForLoop;
import jadx.core.dex.regions.loops.LoopRegion;
import jadx.core.dex.regions.loops.LoopType;
......@@ -331,8 +330,7 @@ public class ProcessVariables extends AbstractVisitor {
return id;
}
for (IContainer c : region.getSubBlocks()) {
if (c instanceof IfRegion
|| c instanceof SwitchRegion) {
if (c instanceof IBranchRegion) {
// on branch set for all inner regions same order id
id = calculateOrder(c, regionsOrder, inc ? id + 1 : id, false);
} else {
......
package jadx.core.dex.visitors.regions;
import jadx.core.dex.instructions.InsnType;
import jadx.core.dex.nodes.IContainer;
import jadx.core.dex.nodes.IRegion;
import jadx.core.dex.nodes.InsnContainer;
import jadx.core.dex.nodes.InsnNode;
import jadx.core.dex.nodes.MethodNode;
import jadx.core.dex.regions.Region;
import jadx.core.dex.regions.SwitchRegion;
import jadx.core.dex.regions.SynchronizedRegion;
import jadx.core.dex.regions.loops.LoopRegion;
import jadx.core.dex.visitors.AbstractVisitor;
import jadx.core.utils.InstructionRemover;
import jadx.core.utils.RegionUtils;
import jadx.core.utils.exceptions.JadxException;
import java.util.ArrayList;
import java.util.List;
import org.slf4j.Logger;
......@@ -22,6 +27,8 @@ import org.slf4j.LoggerFactory;
public class RegionMakerVisitor extends AbstractVisitor {
private static final Logger LOG = LoggerFactory.getLogger(RegionMakerVisitor.class);
private static final PostRegionVisitor POST_REGION_VISITOR = new PostRegionVisitor();
@Override
public void visit(MethodNode mth) throws JadxException {
if (mth.isNoCode()) {
......@@ -44,18 +51,7 @@ public class RegionMakerVisitor extends AbstractVisitor {
// make try-catch regions
ProcessTryCatchRegions.process(mth);
// merge conditions in loops
if (mth.getLoopsCount() != 0) {
DepthRegionTraversal.traverseAll(mth, new AbstractRegionVisitor() {
@Override
public void enterRegion(MethodNode mth, IRegion region) {
if (region instanceof LoopRegion) {
LoopRegion loop = (LoopRegion) region;
loop.mergePreCondition();
}
}
});
}
DepthRegionTraversal.traverse(mth, POST_REGION_VISITOR);
CleanRegions.process(mth);
......@@ -64,6 +60,27 @@ public class RegionMakerVisitor extends AbstractVisitor {
}
}
private static final class PostRegionVisitor extends AbstractRegionVisitor {
@Override
public void enterRegion(MethodNode mth, IRegion region) {
if (region instanceof LoopRegion) {
// merge conditions in loops
LoopRegion loop = (LoopRegion) region;
loop.mergePreCondition();
} else if (region instanceof SwitchRegion) {
// insert 'break' in switch cases (run after try/catch insertion)
SwitchRegion sw = (SwitchRegion) region;
for (IContainer c : sw.getBranches()) {
if (c instanceof Region && !RegionUtils.hasExitEdge(c)) {
List<InsnNode> insns = new ArrayList<InsnNode>(1);
insns.add(new InsnNode(InsnType.BREAK, 0));
((Region) c).add(new InsnContainer(insns));
}
}
}
}
}
private static void removeSynchronized(MethodNode mth) {
Region startRegion = mth.getRegion();
List<IContainer> subBlocks = startRegion.getSubBlocks();
......
......@@ -4,13 +4,11 @@ import jadx.core.dex.attributes.AFlag;
import jadx.core.dex.instructions.args.ArgType;
import jadx.core.dex.nodes.BlockNode;
import jadx.core.dex.nodes.IBlock;
import jadx.core.dex.nodes.IBranchRegion;
import jadx.core.dex.nodes.IContainer;
import jadx.core.dex.nodes.IRegion;
import jadx.core.dex.nodes.InsnNode;
import jadx.core.dex.nodes.MethodNode;
import jadx.core.dex.regions.SwitchRegion;
import jadx.core.dex.regions.TryCatchRegion;
import jadx.core.dex.regions.conditions.IfRegion;
import jadx.core.dex.regions.loops.LoopRegion;
import jadx.core.dex.visitors.AbstractVisitor;
import jadx.core.utils.exceptions.JadxException;
......@@ -28,11 +26,11 @@ public class ReturnVisitor extends AbstractVisitor {
public void visit(MethodNode mth) throws JadxException {
// remove useless returns in void methods
if (mth.getReturnType().equals(ArgType.VOID)) {
DepthRegionTraversal.traverseAll(mth, new Process());
DepthRegionTraversal.traverse(mth, new ReturnRemoverVisitor());
}
}
private static final class Process extends TracedRegionVisitor {
private static final class ReturnRemoverVisitor extends TracedRegionVisitor {
@Override
public void processBlockTraced(MethodNode mth, IBlock container, IRegion currentRegion) {
if (container.getClass() != BlockNode.class) {
......@@ -72,9 +70,7 @@ public class ReturnVisitor extends AbstractVisitor {
IContainer curContainer = block;
for (IRegion region : regionStack) {
// ignore paths on other branches
if (region instanceof IfRegion
|| region instanceof SwitchRegion
|| region instanceof TryCatchRegion) {
if (region instanceof IBranchRegion) {
curContainer = region;
continue;
}
......@@ -100,8 +96,8 @@ public class ReturnVisitor extends AbstractVisitor {
* don't count one 'return' instruction (it will be removed later).
*/
private static boolean isEmpty(IContainer container) {
if (container instanceof BlockNode) {
BlockNode block = (BlockNode) container;
if (container instanceof IBlock) {
IBlock block = (IBlock) container;
return block.getInstructions().isEmpty() || block.contains(AFlag.RETURN);
} else if (container instanceof IRegion) {
IRegion region = (IRegion) container;
......
......@@ -11,6 +11,7 @@ import jadx.core.dex.instructions.args.InsnArg;
import jadx.core.dex.instructions.args.InsnWrapArg;
import jadx.core.dex.instructions.mods.TernaryInsn;
import jadx.core.dex.nodes.BlockNode;
import jadx.core.dex.nodes.IBlock;
import jadx.core.dex.nodes.InsnNode;
import jadx.core.dex.nodes.MethodNode;
import jadx.core.dex.regions.conditions.IfCondition;
......@@ -133,13 +134,13 @@ public class BlockUtils {
return false;
}
public static boolean checkLastInsnType(BlockNode block, InsnType expectedType) {
public static boolean checkLastInsnType(IBlock block, InsnType expectedType) {
InsnNode insn = getLastInsn(block);
return insn != null && insn.getType() == expectedType;
}
@Nullable
public static InsnNode getLastInsn(BlockNode block) {
public static InsnNode getLastInsn(IBlock block) {
List<InsnNode> insns = block.getInstructions();
if (insns.isEmpty()) {
return null;
......
......@@ -3,6 +3,8 @@ package jadx.core.utils;
import jadx.core.dex.attributes.AType;
import jadx.core.dex.instructions.InsnType;
import jadx.core.dex.nodes.BlockNode;
import jadx.core.dex.nodes.IBlock;
import jadx.core.dex.nodes.IBranchRegion;
import jadx.core.dex.nodes.IContainer;
import jadx.core.dex.nodes.IRegion;
import jadx.core.dex.nodes.InsnNode;
......@@ -24,8 +26,8 @@ public class RegionUtils {
}
public static boolean hasExitEdge(IContainer container) {
if (container instanceof BlockNode) {
InsnNode lastInsn = BlockUtils.getLastInsn((BlockNode) container);
if (container instanceof IBlock) {
InsnNode lastInsn = BlockUtils.getLastInsn((IBlock) container);
if (lastInsn == null) {
return false;
}
......@@ -34,6 +36,13 @@ public class RegionUtils {
|| type == InsnType.CONTINUE
|| type == InsnType.BREAK
|| type == InsnType.THROW;
} else if (container instanceof IBranchRegion) {
for (IContainer br : ((IBranchRegion) container).getBranches()) {
if (br == null || !hasExitEdge(br)) {
return false;
}
}
return true;
} else if (container instanceof IRegion) {
IRegion region = (IRegion) container;
List<IContainer> blocks = region.getSubBlocks();
......@@ -44,8 +53,8 @@ public class RegionUtils {
}
public static InsnNode getLastInsn(IContainer container) {
if (container instanceof BlockNode) {
BlockNode block = (BlockNode) container;
if (container instanceof IBlock) {
IBlock block = (IBlock) container;
List<InsnNode> insnList = block.getInstructions();
if (insnList.isEmpty()) {
return null;
......@@ -72,6 +81,8 @@ public class RegionUtils {
public static boolean hasExitBlock(IContainer container) {
if (container instanceof BlockNode) {
return ((BlockNode) container).getSuccessors().isEmpty();
} else if (container instanceof IBlock) {
return true;
} else if (container instanceof IRegion) {
List<IContainer> blocks = ((IRegion) container).getSubBlocks();
return !blocks.isEmpty()
......@@ -82,8 +93,8 @@ public class RegionUtils {
}
public static boolean hasBreakInsn(IContainer container) {
if (container instanceof BlockNode) {
return BlockUtils.checkLastInsnType((BlockNode) container, InsnType.BREAK);
if (container instanceof IBlock) {
return BlockUtils.checkLastInsnType((IBlock) container, InsnType.BREAK);
} else if (container instanceof IRegion) {
List<IContainer> blocks = ((IRegion) container).getSubBlocks();
return !blocks.isEmpty()
......@@ -94,8 +105,8 @@ public class RegionUtils {
}
public static int insnsCount(IContainer container) {
if (container instanceof BlockNode) {
return ((BlockNode) container).getInstructions().size();
if (container instanceof IBlock) {
return ((IBlock) container).getInstructions().size();
} else if (container instanceof IRegion) {
IRegion region = (IRegion) container;
int count = 0;
......@@ -113,8 +124,8 @@ public class RegionUtils {
}
public static boolean notEmpty(IContainer container) {
if (container instanceof BlockNode) {
return !((BlockNode) container).getInstructions().isEmpty();
if (container instanceof IBlock) {
return !((IBlock) container).getInstructions().isEmpty();
} else if (container instanceof IRegion) {
IRegion region = (IRegion) container;
for (IContainer block : region.getSubBlocks()) {
......@@ -128,9 +139,9 @@ public class RegionUtils {
}
}
public static void getAllRegionBlocks(IContainer container, Set<BlockNode> blocks) {
if (container instanceof BlockNode) {
blocks.add((BlockNode) container);
public static void getAllRegionBlocks(IContainer container, Set<IBlock> blocks) {
if (container instanceof IBlock) {
blocks.add((IBlock) container);
} else if (container instanceof IRegion) {
IRegion region = (IRegion) container;
for (IContainer block : region.getSubBlocks()) {
......@@ -142,7 +153,7 @@ public class RegionUtils {
}
public static boolean isRegionContainsBlock(IContainer container, BlockNode block) {
if (container instanceof BlockNode) {
if (container instanceof IBlock) {
return container == block;
} else if (container instanceof IRegion) {
IRegion region = (IRegion) container;
......@@ -231,6 +242,8 @@ public class RegionUtils {
if (cont instanceof BlockNode) {
BlockNode block = (BlockNode) cont;
return block.isDominator(dom);
} else if (cont instanceof IBlock) {
return false;
} else if (cont instanceof IRegion) {
IRegion region = (IRegion) cont;
for (IContainer c : region.getSubBlocks()) {
......@@ -250,6 +263,8 @@ public class RegionUtils {
}
if (cont instanceof BlockNode) {
return BlockUtils.isPathExists(block, (BlockNode) cont);
} else if (cont instanceof IBlock) {
return false;
} else if (cont instanceof IRegion) {
IRegion region = (IRegion) cont;
for (IContainer c : region.getSubBlocks()) {
......
......@@ -24,8 +24,7 @@ public class CountString extends SubstringMatcher {
@Override
public void describeMismatchSafely(String item, Description mismatchDescription) {
mismatchDescription.appendText("found ").appendValue(count(item))
.appendText(" in \"").appendText(item).appendText("\"");
mismatchDescription.appendText("found ").appendValue(count(item));
}
private int count(String string) {
......
package jadx.tests.integration.switches;
import jadx.core.dex.nodes.ClassNode;
import jadx.tests.api.IntegrationTest;
import org.junit.Test;
import static jadx.tests.api.utils.JadxMatchers.countString;
import static org.junit.Assert.assertThat;
public class TestSwitch2 extends IntegrationTest {
public static class TestCls {
boolean isLongtouchable;
boolean isMultiTouchZoom;
boolean isCanZoomIn;
boolean isCanZoomOut;
boolean isScrolling;
float multiTouchZoomOldDist;
void test(int action) {
switch (action & 255) {
case 0:
this.isLongtouchable = true;
break;
case 1:
case 6:
if (this.isMultiTouchZoom) {
this.isMultiTouchZoom = false;
}
break;
case 2:
if (this.isMultiTouchZoom) {
float dist = multiTouchZoomOldDist;
if (Math.abs(dist - this.multiTouchZoomOldDist) > 10.0f) {
float scale = dist / this.multiTouchZoomOldDist;
if ((scale > 1.0f && this.isCanZoomIn) || (scale < 1.0f && this.isCanZoomOut)) {
this.multiTouchZoomOldDist = dist;
}
}
return;
}
break;
case 5:
this.multiTouchZoomOldDist = action;
if (this.multiTouchZoomOldDist > 10.0f) {
this.isMultiTouchZoom = true;
this.isLongtouchable = false;
return;
}
break;
}
if (this.isScrolling && action == 1) {
this.isScrolling = false;
}
}
}
@Test
public void test() {
ClassNode cls = getClassNode(TestCls.class);
String code = cls.getCode().toString();
assertThat(code, countString(4, "break;"));
// TODO: remove redundant returns
// assertThat(code, countString(2, "return;"));
}
}
package jadx.tests.integration.switches;
import jadx.core.dex.nodes.ClassNode;
import jadx.tests.api.IntegrationTest;
import org.junit.Test;
import static jadx.tests.api.utils.JadxMatchers.countString;
import static org.junit.Assert.assertThat;
public class TestSwitchWithTryCatch extends IntegrationTest {
public static class TestCls {
void test(int a) {
switch (a) {
case 0:
try {
exc();
return;
} catch (Exception e) {
e.printStackTrace();
return;
}
// no break;
case 1:
try {
exc();
return;
} catch (Exception e) {
e.printStackTrace();
}
break;
case 2:
try {
exc();
} catch (Exception e) {
e.printStackTrace();
return;
}
break;
case 3:
try {
exc();
} catch (Exception e) {
e.printStackTrace();
}
break;
}
if (a == 10) {
System.out.println(a);
}
}
private void exc() throws Exception {
}
}
@Test
public void test() {
ClassNode cls = getClassNode(TestCls.class);
String code = cls.getCode().toString();
assertThat(code, countString(3, "break;"));
assertThat(code, countString(4, "return;"));
}
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment