Commit 8a4ec47b authored by Skylot's avatar Skylot

core: support break with label for simple cases

parent d2811263
......@@ -3,6 +3,7 @@ package jadx.core.codegen;
import jadx.core.dex.attributes.AFlag;
import jadx.core.dex.attributes.AType;
import jadx.core.dex.attributes.nodes.FieldReplaceAttr;
import jadx.core.dex.attributes.nodes.LoopLabelAttr;
import jadx.core.dex.attributes.nodes.MethodInlineAttr;
import jadx.core.dex.info.ClassInfo;
import jadx.core.dex.info.FieldInfo;
......@@ -289,6 +290,10 @@ public class InsnGen {
case BREAK:
code.add("break");
LoopLabelAttr labelAttr = insn.get(AType.LOOP_LABEL);
if (labelAttr != null) {
code.add(' ').add(mgen.getNameGen().getLoopLabel(labelAttr));
}
break;
case CONTINUE:
......
......@@ -2,6 +2,7 @@ package jadx.core.codegen;
import jadx.core.Consts;
import jadx.core.deobf.NameMapper;
import jadx.core.dex.attributes.nodes.LoopLabelAttr;
import jadx.core.dex.info.ClassInfo;
import jadx.core.dex.instructions.InvokeNode;
import jadx.core.dex.instructions.args.ArgType;
......@@ -76,6 +77,13 @@ public class NameGen {
return name;
}
// TODO: avoid name collision with variables names
public String getLoopLabel(LoopLabelAttr attr) {
String name = "loop" + attr.getLoop().getId();
varNames.add(name);
return name;
}
private String getUniqueVarName(String name) {
String r = name;
int i = 2;
......
......@@ -4,6 +4,7 @@ import jadx.core.dex.attributes.AFlag;
import jadx.core.dex.attributes.AType;
import jadx.core.dex.attributes.nodes.DeclareVariablesAttr;
import jadx.core.dex.attributes.nodes.ForceReturnAttr;
import jadx.core.dex.attributes.nodes.LoopLabelAttr;
import jadx.core.dex.info.FieldInfo;
import jadx.core.dex.instructions.IndexInsnNode;
import jadx.core.dex.instructions.SwitchNode;
......@@ -164,6 +165,10 @@ public class RegionGen extends InsnGen {
}
}
}
LoopLabelAttr labelAttr = region.getInfo().getStart().get(AType.LOOP_LABEL);
if (labelAttr != null) {
code.startLine(mgen.getNameGen().getLoopLabel(labelAttr)).add(':');
}
IfCondition condition = region.getCondition();
if (condition == null) {
......
......@@ -10,6 +10,7 @@ import jadx.core.dex.attributes.nodes.ForceReturnAttr;
import jadx.core.dex.attributes.nodes.JadxErrorAttr;
import jadx.core.dex.attributes.nodes.JumpInfo;
import jadx.core.dex.attributes.nodes.LoopInfo;
import jadx.core.dex.attributes.nodes.LoopLabelAttr;
import jadx.core.dex.attributes.nodes.MethodInlineAttr;
import jadx.core.dex.attributes.nodes.PhiListAttr;
import jadx.core.dex.attributes.nodes.SourceFileAttr;
......@@ -29,6 +30,8 @@ public class AType<T extends IAttribute> {
private AType() {
}
public static final int FIELDS_COUNT = 18;
public static final AType<AttrList<JumpInfo>> JUMP = new AType<AttrList<JumpInfo>>();
public static final AType<AttrList<LoopInfo>> LOOP = new AType<AttrList<LoopInfo>>();
......@@ -47,4 +50,5 @@ public class AType<T extends IAttribute> {
public static final AType<PhiListAttr> PHI_LIST = new AType<PhiListAttr>();
public static final AType<SourceFileAttr> SOURCE_FILE = new AType<SourceFileAttr>();
public static final AType<DeclareVariablesAttr> DECLARE_VARIABLES = new AType<DeclareVariablesAttr>();
public static final AType<LoopLabelAttr> LOOP_LABEL = new AType<LoopLabelAttr>();
}
......@@ -7,7 +7,7 @@ import jadx.core.utils.Utils;
import java.util.ArrayList;
import java.util.Collections;
import java.util.EnumSet;
import java.util.HashMap;
import java.util.IdentityHashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;
......@@ -24,7 +24,7 @@ public class AttributeStorage {
public AttributeStorage() {
flags = EnumSet.noneOf(AFlag.class);
attributes = new HashMap<AType<?>, IAttribute>(2);
attributes = new IdentityHashMap<AType<?>, IAttribute>(AType.FIELDS_COUNT);
}
public void add(AFlag flag) {
......
......@@ -17,6 +17,9 @@ public class LoopInfo {
private final BlockNode end;
private final Set<BlockNode> loopBlocks;
private int id;
private LoopInfo parentLoop;
public LoopInfo(BlockNode start, BlockNode end) {
this.start = start;
this.end = end;
......@@ -69,8 +72,24 @@ public class LoopInfo {
return edges;
}
public int getId() {
return id;
}
public void setId(int id) {
this.id = id;
}
public LoopInfo getParentLoop() {
return parentLoop;
}
public void setParentLoop(LoopInfo parentLoop) {
this.parentLoop = parentLoop;
}
@Override
public String toString() {
return "LOOP: " + start + "->" + end;
return "LOOP:" + id + ": " + start + "->" + end;
}
}
package jadx.core.dex.attributes.nodes;
import jadx.core.dex.attributes.AType;
import jadx.core.dex.attributes.IAttribute;
public class LoopLabelAttr implements IAttribute {
private final LoopInfo loop;
public LoopLabelAttr(LoopInfo loop) {
this.loop = loop;
}
public LoopInfo getLoop() {
return loop;
}
@Override
public AType<LoopLabelAttr> getType() {
return AType.LOOP_LABEL;
}
@Override
public String toString() {
return "LOOP_LABEL: " + loop;
}
}
......@@ -406,10 +406,14 @@ public class MethodNode extends LineAttrNode implements ILoadable {
if (loops.isEmpty()) {
loops = new ArrayList<LoopInfo>(5);
}
loop.setId(loops.size());
loops.add(loop);
}
public LoopInfo getLoopForBlock(BlockNode block) {
if (loops.isEmpty()) {
return null;
}
for (LoopInfo loop : loops) {
if (loop.getLoopBlocks().contains(block)) {
return loop;
......@@ -418,10 +422,27 @@ public class MethodNode extends LineAttrNode implements ILoadable {
return null;
}
public List<LoopInfo> getAllLoopsForBlock(BlockNode block) {
if (loops.isEmpty()) {
return Collections.emptyList();
}
List<LoopInfo> list = new ArrayList<LoopInfo>(loops.size());
for (LoopInfo loop : loops) {
if (loop.getLoopBlocks().contains(block)) {
list.add(loop);
}
}
return list;
}
public int getLoopsCount() {
return loops.size();
}
public Iterable<LoopInfo> getLoops() {
return loops;
}
public ExceptionHandler addExceptionHandler(ExceptionHandler handler) {
if (exceptionHandlers.isEmpty()) {
exceptionHandlers = new ArrayList<ExceptionHandler>(2);
......
package jadx.core.dex.regions.loops;
import jadx.core.dex.attributes.nodes.LoopInfo;
import jadx.core.dex.instructions.IfNode;
import jadx.core.dex.instructions.args.RegisterArg;
import jadx.core.dex.nodes.BlockNode;
......@@ -15,6 +16,7 @@ import java.util.List;
public final class LoopRegion extends AbstractRegion {
private final LoopInfo info;
// loop header contains one 'if' insn, equals null for infinite loop
private IfCondition condition;
private final BlockNode conditionBlock;
......@@ -25,13 +27,18 @@ public final class LoopRegion extends AbstractRegion {
private LoopType type;
public LoopRegion(IRegion parent, BlockNode header, boolean reversed) {
public LoopRegion(IRegion parent, LoopInfo info, BlockNode header, boolean reversed) {
super(parent);
this.info = info;
this.conditionBlock = header;
this.condition = IfCondition.fromIfBlock(header);
this.conditionAtEnd = reversed;
}
public LoopInfo getInfo() {
return info;
}
public IfCondition getCondition() {
return condition;
}
......
......@@ -190,6 +190,7 @@ public class BlockMakerVisitor extends AbstractVisitor {
}
computeDominanceFrontier(mth);
registerLoops(mth);
processNestedLoops(mth);
}
private static BlockNode getBlock(int offset, Map<Integer, BlockNode> blocksMap) {
......@@ -357,15 +358,40 @@ public class BlockMakerVisitor extends AbstractVisitor {
private static void registerLoops(MethodNode mth) {
for (BlockNode block : mth.getBasicBlocks()) {
List<LoopInfo> loops = block.getAll(AType.LOOP);
if (block.contains(AFlag.LOOP_START)) {
for (LoopInfo loop : loops) {
for (LoopInfo loop : block.getAll(AType.LOOP)) {
mth.registerLoop(loop);
}
}
}
}
private static void processNestedLoops(MethodNode mth) {
if (mth.getLoopsCount() == 0) {
return;
}
for (LoopInfo outLoop : mth.getLoops()) {
for (LoopInfo innerLoop : mth.getLoops()) {
if (outLoop == innerLoop) {
continue;
}
if (outLoop.getLoopBlocks().containsAll(innerLoop.getLoopBlocks())) {
LoopInfo parentLoop = innerLoop.getParentLoop();
if (parentLoop != null) {
if (parentLoop.getLoopBlocks().containsAll(outLoop.getLoopBlocks())) {
outLoop.setParentLoop(parentLoop);
innerLoop.setParentLoop(outLoop);
} else {
parentLoop.setParentLoop(outLoop);
}
} else {
innerLoop.setParentLoop(outLoop);
}
}
}
}
}
private static boolean modifyBlocksTree(MethodNode mth) {
for (BlockNode block : mth.getBasicBlocks()) {
if (block.getPredecessors().isEmpty() && block != mth.getEnterBlock()) {
......
......@@ -4,6 +4,7 @@ import jadx.core.Consts;
import jadx.core.dex.attributes.AFlag;
import jadx.core.dex.attributes.AType;
import jadx.core.dex.attributes.nodes.LoopInfo;
import jadx.core.dex.attributes.nodes.LoopLabelAttr;
import jadx.core.dex.instructions.IfNode;
import jadx.core.dex.instructions.InsnType;
import jadx.core.dex.instructions.SwitchNode;
......@@ -13,12 +14,12 @@ import jadx.core.dex.nodes.Edge;
import jadx.core.dex.nodes.IRegion;
import jadx.core.dex.nodes.InsnNode;
import jadx.core.dex.nodes.MethodNode;
import jadx.core.dex.regions.conditions.IfInfo;
import jadx.core.dex.regions.conditions.IfRegion;
import jadx.core.dex.regions.loops.LoopRegion;
import jadx.core.dex.regions.Region;
import jadx.core.dex.regions.SwitchRegion;
import jadx.core.dex.regions.SynchronizedRegion;
import jadx.core.dex.regions.conditions.IfInfo;
import jadx.core.dex.regions.conditions.IfRegion;
import jadx.core.dex.regions.loops.LoopRegion;
import jadx.core.dex.trycatch.ExcHandlerAttr;
import jadx.core.dex.trycatch.ExceptionHandler;
import jadx.core.dex.trycatch.TryCatchBlock;
......@@ -184,7 +185,7 @@ public class RegionMaker {
if (exitBlocks.size() > 0) {
BlockNode loopExit = condInfo.getElseBlock();
if (loopExit != null) {
// add 'break' instruction before path cross between main loop exit and subexit
// add 'break' instruction before path cross between main loop exit and sub-exit
for (Edge exitEdge : loop.getExitEdges()) {
if (!exitBlocks.contains(exitEdge.getSource())) {
continue;
......@@ -245,7 +246,12 @@ public class RegionMaker {
|| block.getInstructions().get(0).getType() != InsnType.IF) {
continue;
}
LoopRegion loopRegion = new LoopRegion(curRegion, block, block == loop.getEnd());
List<LoopInfo> loops = block.getAll(AType.LOOP);
if (!loops.isEmpty() && loops.get(0) != loop) {
// skip nested loop condition
continue;
}
LoopRegion loopRegion = new LoopRegion(curRegion, loop, block, block == loop.getEnd());
boolean found;
if (block == loop.getStart() || block == loop.getEnd()
|| BlockUtils.isEmptySimplePath(loop.getStart(), block)) {
......@@ -266,7 +272,7 @@ public class RegionMaker {
}
private BlockNode makeEndlessLoop(IRegion curRegion, RegionStack stack, LoopInfo loop, BlockNode loopStart) {
LoopRegion loopRegion = new LoopRegion(curRegion, null, false);
LoopRegion loopRegion = new LoopRegion(curRegion, loop, null, false);
curRegion.getSubBlocks().add(loopRegion);
loopStart.remove(AType.LOOP);
......@@ -332,8 +338,22 @@ public class RegionMaker {
if (prev != null && isPathExists(loopExit, exit)) {
// found cross
if (canInsertBreak(exit)) {
prev.getInstructions().add(new InsnNode(InsnType.BREAK, 0));
InsnNode breakInsn = new InsnNode(InsnType.BREAK, 0);
prev.getInstructions().add(breakInsn);
stack.addExit(exit);
// add label to 'break' if needed
List<LoopInfo> loops = mth.getAllLoopsForBlock(exitEdge.getSource());
if (loops.size() >= 2) {
// find parent loop
for (LoopInfo loop : loops) {
if (loop.getParentLoop() == null) {
LoopLabelAttr labelAttr = new LoopLabelAttr(loop);
breakInsn.addAttr(labelAttr);
loop.getStart().addAttr(labelAttr);
break;
}
}
}
}
return;
}
......
package jadx.tests.integration.loops;
import jadx.core.dex.nodes.ClassNode;
import jadx.tests.api.IntegrationTest;
import java.lang.reflect.Method;
import org.junit.Test;
import static jadx.tests.api.utils.JadxMatchers.containsOne;
import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertThat;
import static org.junit.Assert.assertTrue;
public class TestBreakWithLabel extends IntegrationTest {
public static class TestCls {
public boolean test(int[][] arr, int b) {
boolean found = false;
loop0:
for (int i = 0; i < arr.length; i++) {
for (int j = 0; j < arr[i].length; j++) {
if (arr[i][j] == b) {
found = true;
break loop0;
}
}
}
System.out.println("found: " + found);
return found;
}
}
@Test
public void test() {
ClassNode cls = getClassNode(TestCls.class);
String code = cls.getCode().toString();
System.out.println(code);
assertThat(code, containsOne("loop0:"));
assertThat(code, containsOne("break loop0;"));
Method test = getReflectMethod("test", int[][].class, int.class);
int[][] testArray = {{1, 2}, {3, 4}};
assertTrue((Boolean) invoke(test, testArray, 3));
assertFalse((Boolean) invoke(test, testArray, 5));
}
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment