Commit 6e66dc25 authored by Skylot's avatar Skylot

fix: additional checks for loop exit edges and 'for' conversion (#483)

parent 999793c9
...@@ -140,7 +140,7 @@ public class IfRegionVisitor extends AbstractVisitor { ...@@ -140,7 +140,7 @@ public class IfRegionVisitor extends AbstractVisitor {
|| ifRegion.getElseRegion().contains(AFlag.ELSE_IF_CHAIN)) { || ifRegion.getElseRegion().contains(AFlag.ELSE_IF_CHAIN)) {
return false; return false;
} }
if (!hasBranchTerminator(ifRegion.getThenRegion())) { if (!RegionUtils.hasExitBlock(ifRegion.getThenRegion())) {
return false; return false;
} }
// code style check: // code style check:
...@@ -162,12 +162,6 @@ public class IfRegionVisitor extends AbstractVisitor { ...@@ -162,12 +162,6 @@ public class IfRegionVisitor extends AbstractVisitor {
return false; return false;
} }
private static boolean hasBranchTerminator(IContainer region) {
// TODO: check for exception throw
return RegionUtils.hasExitBlock(region)
|| RegionUtils.hasBreakInsn(region);
}
private static void invertIfRegion(IfRegion ifRegion) { private static void invertIfRegion(IfRegion ifRegion) {
IContainer elseRegion = ifRegion.getElseRegion(); IContainer elseRegion = ifRegion.getElseRegion();
if (elseRegion != null) { if (elseRegion != null) {
......
...@@ -7,6 +7,7 @@ import java.util.LinkedHashMap; ...@@ -7,6 +7,7 @@ import java.util.LinkedHashMap;
import java.util.List; import java.util.List;
import java.util.Map; import java.util.Map;
import java.util.Map.Entry; import java.util.Map.Entry;
import java.util.Optional;
import java.util.Set; import java.util.Set;
import org.slf4j.Logger; import org.slf4j.Logger;
...@@ -192,7 +193,7 @@ public class RegionMaker { ...@@ -192,7 +193,7 @@ public class RegionMaker {
// add 'break' instruction before path cross between main loop exit and sub-exit // add 'break' instruction before path cross between main loop exit and sub-exit
for (Edge exitEdge : loop.getExitEdges()) { for (Edge exitEdge : loop.getExitEdges()) {
if (exitBlocks.contains(exitEdge.getSource())) { if (exitBlocks.contains(exitEdge.getSource())) {
insertBreak(stack, loopExit, exitEdge); insertLoopBreak(stack, loop, loopExit, exitEdge);
} }
} }
} }
...@@ -288,6 +289,9 @@ public class RegionMaker { ...@@ -288,6 +289,9 @@ public class RegionMaker {
} }
} }
} }
if (found && !checkLoopExits(loop, block)) {
found = false;
}
if (found) { if (found) {
return loopRegion; return loopRegion;
} }
...@@ -296,6 +300,32 @@ public class RegionMaker { ...@@ -296,6 +300,32 @@ public class RegionMaker {
return null; return null;
} }
private boolean checkLoopExits(LoopInfo loop, BlockNode mainExitBlock) {
List<Edge> exitEdges = loop.getExitEdges();
if (exitEdges.size() < 2) {
return true;
}
Optional<Edge> mainEdgeOpt = exitEdges.stream().filter(edge -> edge.getSource() == mainExitBlock).findFirst();
if (!mainEdgeOpt.isPresent()) {
throw new JadxRuntimeException("Not found exit edge by exit block: " + mainExitBlock);
}
Edge mainExitEdge = mainEdgeOpt.get();
BlockNode mainOutBlock = skipSyntheticSuccessor(mainExitEdge.getTarget());
for (Edge exitEdge : exitEdges) {
if (exitEdge != mainExitEdge) {
BlockNode outBlock = skipSyntheticSuccessor(exitEdge.getTarget());
// all exit paths must be same or don't cross (will be inside loop)
if (!isEqualPaths(mainOutBlock, outBlock)) {
BlockNode crossBlock = BlockUtils.getPathCross(mth, mainOutBlock, outBlock);
if (crossBlock != null) {
return false;
}
}
}
}
return true;
}
private BlockNode makeEndlessLoop(IRegion curRegion, RegionStack stack, LoopInfo loop, BlockNode loopStart) { private BlockNode makeEndlessLoop(IRegion curRegion, RegionStack stack, LoopInfo loop, BlockNode loopStart) {
LoopRegion loopRegion = new LoopRegion(curRegion, loop, null, false); LoopRegion loopRegion = new LoopRegion(curRegion, loop, null, false);
curRegion.getSubBlocks().add(loopRegion); curRegion.getSubBlocks().add(loopRegion);
...@@ -310,7 +340,7 @@ public class RegionMaker { ...@@ -310,7 +340,7 @@ public class RegionMaker {
if (exitEdges.size() == 1) { if (exitEdges.size() == 1) {
Edge exitEdge = exitEdges.get(0); Edge exitEdge = exitEdges.get(0);
BlockNode exit = exitEdge.getTarget(); BlockNode exit = exitEdge.getTarget();
if (insertBreak(stack, exit, exitEdge)) { if (insertLoopBreak(stack, loop, exit, exitEdge)) {
BlockNode nextBlock = getNextBlock(exit); BlockNode nextBlock = getNextBlock(exit);
if (nextBlock != null) { if (nextBlock != null) {
stack.addExit(nextBlock); stack.addExit(nextBlock);
...@@ -324,10 +354,10 @@ public class RegionMaker { ...@@ -324,10 +354,10 @@ public class RegionMaker {
for (BlockNode block : blocks) { for (BlockNode block : blocks) {
if (BlockUtils.isPathExists(exit, block)) { if (BlockUtils.isPathExists(exit, block)) {
stack.addExit(block); stack.addExit(block);
insertBreak(stack, block, exitEdge); insertLoopBreak(stack, loop, block, exitEdge);
out = block; out = block;
} else { } else {
insertBreak(stack, exit, exitEdge); insertLoopBreak(stack, loop, exit, exitEdge);
} }
} }
} }
...@@ -386,7 +416,7 @@ public class RegionMaker { ...@@ -386,7 +416,7 @@ public class RegionMaker {
return true; return true;
} }
private boolean insertBreak(RegionStack stack, BlockNode loopExit, Edge exitEdge) { private boolean insertLoopBreak(RegionStack stack, LoopInfo loop, BlockNode loopExit, Edge exitEdge) {
BlockNode exit = exitEdge.getTarget(); BlockNode exit = exitEdge.getTarget();
BlockNode insertBlock = null; BlockNode insertBlock = null;
boolean confirm = false; boolean confirm = false;
...@@ -425,6 +455,7 @@ public class RegionMaker { ...@@ -425,6 +455,7 @@ public class RegionMaker {
return false; return false;
} }
InsnNode breakInsn = new InsnNode(InsnType.BREAK, 0); InsnNode breakInsn = new InsnNode(InsnType.BREAK, 0);
breakInsn.addAttr(AType.LOOP, loop);
EdgeInsnAttr.addEdgeInsn(insertBlock, insertBlock.getSuccessors().get(0), breakInsn); EdgeInsnAttr.addEdgeInsn(insertBlock, insertBlock.getSuccessors().get(0), breakInsn);
stack.addExit(exit); stack.addExit(exit);
// add label to 'break' if needed // add label to 'break' if needed
......
...@@ -104,6 +104,13 @@ public class RegionMakerVisitor extends AbstractVisitor { ...@@ -104,6 +104,13 @@ public class RegionMakerVisitor extends AbstractVisitor {
if (!insnAttr.getStart().equals(last)) { if (!insnAttr.getStart().equals(last)) {
return; return;
} }
if (last instanceof BlockNode) {
BlockNode block = (BlockNode) last;
if (block.getInstructions().isEmpty()) {
block.getInstructions().add(insnAttr.getInsn());
return;
}
}
List<InsnNode> insns = Collections.singletonList(insnAttr.getInsn()); List<InsnNode> insns = Collections.singletonList(insnAttr.getInsn());
region.add(new InsnContainer(insns)); region.add(new InsnContainer(insns));
} }
......
...@@ -8,6 +8,9 @@ import java.util.Set; ...@@ -8,6 +8,9 @@ import java.util.Set;
import org.jetbrains.annotations.Nullable; import org.jetbrains.annotations.Nullable;
import jadx.core.dex.attributes.AType; import jadx.core.dex.attributes.AType;
import jadx.core.dex.attributes.AttrList;
import jadx.core.dex.attributes.nodes.LoopInfo;
import jadx.core.dex.attributes.nodes.LoopLabelAttr;
import jadx.core.dex.instructions.InsnType; import jadx.core.dex.instructions.InsnType;
import jadx.core.dex.nodes.BlockNode; import jadx.core.dex.nodes.BlockNode;
import jadx.core.dex.nodes.IBlock; import jadx.core.dex.nodes.IBlock;
...@@ -91,22 +94,71 @@ public class RegionUtils { ...@@ -91,22 +94,71 @@ public class RegionUtils {
} }
/** /**
* Return true if last block in region has no successors * Return true if last block in region has no successors or jump out insn (return or break)
*/ */
public static boolean hasExitBlock(IContainer container) { public static boolean hasExitBlock(IContainer container) {
return hasExitBlock(container, container);
}
private static boolean hasExitBlock(IContainer rootContainer, IContainer container) {
if (container instanceof BlockNode) { if (container instanceof BlockNode) {
return ((BlockNode) container).getSuccessors().isEmpty(); BlockNode blockNode = (BlockNode) container;
} else if (container instanceof IBlock) { if (blockNode.getSuccessors().isEmpty()) {
return true; return true;
}
return isInsnExitContainer(rootContainer, (IBlock) container);
} else if (container instanceof IBranchRegion) {
return false;
} else if (container instanceof IBlock) {
return isInsnExitContainer(rootContainer, (IBlock) container);
} else if (container instanceof IRegion) { } else if (container instanceof IRegion) {
List<IContainer> blocks = ((IRegion) container).getSubBlocks(); List<IContainer> blocks = ((IRegion) container).getSubBlocks();
return !blocks.isEmpty() return !blocks.isEmpty()
&& hasExitBlock(blocks.get(blocks.size() - 1)); && hasExitBlock(rootContainer, blocks.get(blocks.size() - 1));
} else { } else {
throw new JadxRuntimeException(unknownContainerType(container)); throw new JadxRuntimeException(unknownContainerType(container));
} }
} }
private static boolean isInsnExitContainer(IContainer rootContainer, IBlock block) {
InsnNode lastInsn = BlockUtils.getLastInsn(block);
if (lastInsn == null) {
return false;
}
InsnType insnType = lastInsn.getType();
if (insnType == InsnType.RETURN) {
return true;
}
if (insnType == InsnType.THROW) {
// check if after throw execution can continue in current container
CatchAttr catchAttr = lastInsn.get(AType.CATCH_BLOCK);
if (catchAttr != null) {
for (ExceptionHandler handler : catchAttr.getTryBlock().getHandlers()) {
if (RegionUtils.isRegionContainsBlock(rootContainer, handler.getHandlerBlock())) {
return false;
}
}
}
return true;
}
if (insnType == InsnType.BREAK) {
AttrList<LoopInfo> loopInfoAttrList = lastInsn.get(AType.LOOP);
if (loopInfoAttrList != null) {
for (LoopInfo loopInfo : loopInfoAttrList.getList()) {
if (!RegionUtils.isRegionContainsBlock(rootContainer, loopInfo.getStart())) {
return true;
}
}
}
LoopLabelAttr loopLabelAttr = lastInsn.get(AType.LOOP_LABEL);
if (loopLabelAttr != null
&& !RegionUtils.isRegionContainsBlock(rootContainer, loopLabelAttr.getLoop().getStart())) {
return true;
}
}
return false;
}
public static boolean hasBreakInsn(IContainer container) { public static boolean hasBreakInsn(IContainer container) {
if (container instanceof IBlock) { if (container instanceof IBlock) {
return BlockUtils.checkLastInsnType((IBlock) container, InsnType.BREAK); return BlockUtils.checkLastInsnType((IBlock) container, InsnType.BREAK);
......
package jadx.tests.integration.loops; package jadx.tests.integration.loops;
import static jadx.tests.api.utils.JadxMatchers.containsOne;
import static org.hamcrest.MatcherAssert.assertThat;
import static org.hamcrest.Matchers.containsString;
import static org.hamcrest.Matchers.is;
import static org.hamcrest.Matchers.not;
import static org.hamcrest.Matchers.nullValue;
import java.io.File; import java.io.File;
import org.junit.Test; import org.junit.jupiter.api.Test;
import jadx.core.dex.nodes.ClassNode; import jadx.core.dex.nodes.ClassNode;
import jadx.tests.api.IntegrationTest; import jadx.tests.api.IntegrationTest;
import static jadx.tests.api.utils.JadxMatchers.containsOne;
import static org.hamcrest.MatcherAssert.assertThat;
import static org.hamcrest.Matchers.containsString;
import static org.hamcrest.Matchers.is;
import static org.hamcrest.Matchers.not;
import static org.hamcrest.Matchers.nullValue;
public class TestIndexedLoop extends IntegrationTest { public class TestIndexedLoop extends IntegrationTest {
......
package jadx.tests.integration.loops;
import java.io.File;
import org.junit.jupiter.api.Test;
import jadx.core.dex.nodes.ClassNode;
import jadx.tests.api.IntegrationTest;
import static jadx.tests.api.utils.JadxMatchers.containsOne;
import static org.hamcrest.MatcherAssert.assertThat;
import static org.hamcrest.Matchers.containsString;
import static org.hamcrest.Matchers.is;
import static org.hamcrest.Matchers.not;
import static org.hamcrest.Matchers.nullValue;
public class TestNotIndexedLoop extends IntegrationTest {
public static class TestCls {
public File test(File[] files) {
File file;
if (files != null) {
int length = files.length;
if (length == 0) {
file = null;
} else {
int i = 0;
while (true) {
if (i >= length) {
file = null;
break;
}
file = files[i];
if (file.getName().equals("f")) {
break;
}
i++;
}
}
} else {
file = null;
}
if (file != null) {
file.deleteOnExit();
}
return file;
}
public void check() {
assertThat(test(null), nullValue());
assertThat(test(new File[]{}), nullValue());
File file = new File("f");
assertThat(test(new File[]{new File("a"), file}), is(file));
}
}
@Test
public void test() {
ClassNode cls = getClassNode(TestCls.class);
String code = cls.getCode().toString();
assertThat(code, not(containsString("for (")));
assertThat(code, containsOne("while (true) {"));
}
@Test
public void testNoDebug() {
noDebugInfo();
ClassNode cls = getClassNode(TestCls.class);
String code = cls.getCode().toString();
assertThat(code, not(containsString("for (")));
assertThat(code, containsOne("while (true) {"));
}
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment