Commit 2a3162f8 authored by Skylot's avatar Skylot

core: don't set 'skip' flag for failed nested 'if' merge (issue #18)

parent 2063fd07
...@@ -3,35 +3,47 @@ package jadx.core.dex.regions.conditions; ...@@ -3,35 +3,47 @@ package jadx.core.dex.regions.conditions;
import jadx.core.dex.nodes.BlockNode; import jadx.core.dex.nodes.BlockNode;
import java.util.HashSet; import java.util.HashSet;
import java.util.LinkedList;
import java.util.List;
import java.util.Set; import java.util.Set;
public final class IfInfo { public final class IfInfo {
private final IfCondition condition; private final IfCondition condition;
private final Set<BlockNode> mergedBlocks = new HashSet<BlockNode>(); private final Set<BlockNode> mergedBlocks;
private final BlockNode thenBlock; private final BlockNode thenBlock;
private final BlockNode elseBlock; private final BlockNode elseBlock;
private final List<BlockNode> skipBlocks;
private BlockNode outBlock; private BlockNode outBlock;
@Deprecated @Deprecated
private BlockNode ifBlock; private BlockNode ifBlock;
public IfInfo(IfCondition condition, BlockNode thenBlock, BlockNode elseBlock) { public IfInfo(IfCondition condition, BlockNode thenBlock, BlockNode elseBlock) {
this.condition = condition; this(condition, thenBlock, elseBlock, new HashSet<BlockNode>(), new LinkedList<BlockNode>());
this.thenBlock = thenBlock;
this.elseBlock = elseBlock;
} }
public IfInfo(IfCondition condition, IfInfo info) { public IfInfo(IfCondition condition, IfInfo info) {
this(condition, info.getThenBlock(), info.getElseBlock(), info.getMergedBlocks(), info.getSkipBlocks());
}
public IfInfo(IfInfo info, BlockNode thenBlock, BlockNode elseBlock) {
this(info.getCondition(), thenBlock, elseBlock, info.getMergedBlocks(), info.getSkipBlocks());
}
private IfInfo(IfCondition condition, BlockNode thenBlock, BlockNode elseBlock,
Set<BlockNode> mergedBlocks, List<BlockNode> skipBlocks) {
this.condition = condition; this.condition = condition;
this.thenBlock = info.getThenBlock(); this.thenBlock = thenBlock;
this.elseBlock = info.getElseBlock(); this.elseBlock = elseBlock;
this.mergedBlocks.addAll(info.getMergedBlocks()); this.mergedBlocks = mergedBlocks;
this.skipBlocks = skipBlocks;
} }
public static IfInfo invert(IfInfo info) { public static IfInfo invert(IfInfo info) {
IfInfo tmpIf = new IfInfo(IfCondition.invert(info.getCondition()), IfCondition invertedCondition = IfCondition.invert(info.getCondition());
info.getElseBlock(), info.getThenBlock()); IfInfo tmpIf = new IfInfo(invertedCondition,
info.getElseBlock(), info.getThenBlock(),
info.getMergedBlocks(), info.getSkipBlocks());
tmpIf.setIfBlock(info.getIfBlock()); tmpIf.setIfBlock(info.getIfBlock());
tmpIf.getMergedBlocks().addAll(info.getMergedBlocks());
return tmpIf; return tmpIf;
} }
...@@ -59,6 +71,10 @@ public final class IfInfo { ...@@ -59,6 +71,10 @@ public final class IfInfo {
this.outBlock = outBlock; this.outBlock = outBlock;
} }
public List<BlockNode> getSkipBlocks() {
return skipBlocks;
}
public BlockNode getIfBlock() { public BlockNode getIfBlock() {
return ifBlock; return ifBlock;
} }
......
...@@ -21,6 +21,7 @@ import java.util.Set; ...@@ -21,6 +21,7 @@ import java.util.Set;
import org.slf4j.Logger; import org.slf4j.Logger;
import org.slf4j.LoggerFactory; import org.slf4j.LoggerFactory;
import static jadx.core.utils.BlockUtils.getNextBlock;
import static jadx.core.utils.BlockUtils.isPathExists; import static jadx.core.utils.BlockUtils.isPathExists;
public class IfMakerHelper { public class IfMakerHelper {
...@@ -38,6 +39,11 @@ public class IfMakerHelper { ...@@ -38,6 +39,11 @@ public class IfMakerHelper {
return info; return info;
} }
static IfInfo searchNestedIf(IfInfo info) {
IfInfo tmp = mergeNestedIfNodes(info);
return tmp != null ? tmp : info;
}
static IfInfo restructureIf(MethodNode mth, BlockNode block, IfInfo info) { static IfInfo restructureIf(MethodNode mth, BlockNode block, IfInfo info) {
final BlockNode thenBlock = info.getThenBlock(); final BlockNode thenBlock = info.getThenBlock();
final BlockNode elseBlock = info.getElseBlock(); final BlockNode elseBlock = info.getElseBlock();
...@@ -54,11 +60,11 @@ public class IfMakerHelper { ...@@ -54,11 +60,11 @@ public class IfMakerHelper {
return null; return null;
} }
if (badElse) { if (badElse) {
info = new IfInfo(info.getCondition(), thenBlock, null); info = new IfInfo(info, thenBlock, null);
info.setOutBlock(elseBlock); info.setOutBlock(elseBlock);
} else if (badThen) { } else if (badThen) {
info = IfInfo.invert(info); info = IfInfo.invert(info);
info = new IfInfo(info.getCondition(), elseBlock, null); info = new IfInfo(info, elseBlock, null);
info.setOutBlock(thenBlock); info.setOutBlock(thenBlock);
} else { } else {
List<BlockNode> thenSC = thenBlock.getCleanSuccessors(); List<BlockNode> thenSC = thenBlock.getCleanSuccessors();
...@@ -101,11 +107,6 @@ public class IfMakerHelper { ...@@ -101,11 +107,6 @@ public class IfMakerHelper {
return c1.size() == c2.size() && c1.containsAll(c2); return c1.size() == c2.size() && c1.containsAll(c2);
} }
static IfInfo searchNestedIf(IfInfo info) {
IfInfo tmp = mergeNestedIfNodes(info);
return tmp != null ? tmp : info;
}
static IfInfo mergeNestedIfNodes(IfInfo currentIf) { static IfInfo mergeNestedIfNodes(IfInfo currentIf) {
BlockNode curThen = currentIf.getThenBlock(); BlockNode curThen = currentIf.getThenBlock();
BlockNode curElse = currentIf.getElseBlock(); BlockNode curElse = currentIf.getElseBlock();
...@@ -181,7 +182,6 @@ public class IfMakerHelper { ...@@ -181,7 +182,6 @@ public class IfMakerHelper {
nextElse = IfInfo.invert(nextElse); nextElse = IfInfo.invert(nextElse);
return mergeTernaryConditions(currentIf, nextThen, nextElse); return mergeTernaryConditions(currentIf, nextThen, nextElse);
} }
return null; return null;
} }
...@@ -193,9 +193,10 @@ public class IfMakerHelper { ...@@ -193,9 +193,10 @@ public class IfMakerHelper {
result.getMergedBlocks().addAll(currentIf.getMergedBlocks()); result.getMergedBlocks().addAll(currentIf.getMergedBlocks());
result.getMergedBlocks().addAll(nextThen.getMergedBlocks()); result.getMergedBlocks().addAll(nextThen.getMergedBlocks());
result.getMergedBlocks().addAll(nextElse.getMergedBlocks()); result.getMergedBlocks().addAll(nextElse.getMergedBlocks());
for (BlockNode blockNode : result.getMergedBlocks()) { result.getSkipBlocks().addAll(currentIf.getSkipBlocks());
blockNode.add(AFlag.SKIP); result.getSkipBlocks().addAll(nextThen.getSkipBlocks());
} result.getSkipBlocks().addAll(nextElse.getSkipBlocks());
confirmMerge(result);
return result; return result;
} }
...@@ -210,19 +211,30 @@ public class IfMakerHelper { ...@@ -210,19 +211,30 @@ public class IfMakerHelper {
private static IfInfo mergeIfInfo(IfInfo first, IfInfo second, boolean followThenBranch) { private static IfInfo mergeIfInfo(IfInfo first, IfInfo second, boolean followThenBranch) {
Mode mergeOperation = followThenBranch ? Mode.AND : Mode.OR; Mode mergeOperation = followThenBranch ? Mode.AND : Mode.OR;
BlockNode otherPathBlock = followThenBranch ? first.getElseBlock() : first.getThenBlock();
RegionMaker.skipSimplePath(otherPathBlock);
first.getIfBlock().add(AFlag.SKIP);
second.getIfBlock().add(AFlag.SKIP);
IfCondition condition = IfCondition.merge(mergeOperation, first.getCondition(), second.getCondition()); IfCondition condition = IfCondition.merge(mergeOperation, first.getCondition(), second.getCondition());
IfInfo result = new IfInfo(condition, second); IfInfo result = new IfInfo(condition, second);
result.setIfBlock(first.getIfBlock()); result.setIfBlock(first.getIfBlock());
result.getMergedBlocks().addAll(first.getMergedBlocks()); result.getMergedBlocks().addAll(first.getMergedBlocks());
result.getMergedBlocks().addAll(second.getMergedBlocks()); result.getMergedBlocks().addAll(second.getMergedBlocks());
result.getSkipBlocks().addAll(first.getSkipBlocks());
result.getSkipBlocks().addAll(second.getSkipBlocks());
BlockNode otherPathBlock = followThenBranch ? first.getElseBlock() : first.getThenBlock();
skipSimplePath(otherPathBlock, result.getSkipBlocks());
return result; return result;
} }
static void confirmMerge(IfInfo info) {
for (BlockNode block : info.getMergedBlocks()) {
block.add(AFlag.SKIP);
}
for (BlockNode block : info.getSkipBlocks()) {
block.add(AFlag.SKIP);
}
info.getSkipBlocks().clear();
}
private static IfInfo getNextIf(IfInfo info, BlockNode block) { private static IfInfo getNextIf(IfInfo info, BlockNode block) {
if (!canSelectNext(info, block)) { if (!canSelectNext(info, block)) {
return null; return null;
...@@ -290,4 +302,13 @@ public class IfMakerHelper { ...@@ -290,4 +302,13 @@ public class IfMakerHelper {
} }
return null; return null;
} }
private static void skipSimplePath(BlockNode block, List<BlockNode> skipped) {
while (block != null
&& block.getCleanSuccessors().size() < 2
&& block.getPredecessors().size() == 1) {
skipped.add(block);
block = getNextBlock(block);
}
}
} }
...@@ -40,8 +40,10 @@ import java.util.Set; ...@@ -40,8 +40,10 @@ import java.util.Set;
import org.slf4j.Logger; import org.slf4j.Logger;
import org.slf4j.LoggerFactory; import org.slf4j.LoggerFactory;
import static jadx.core.dex.visitors.regions.IfMakerHelper.confirmMerge;
import static jadx.core.dex.visitors.regions.IfMakerHelper.makeIfInfo; import static jadx.core.dex.visitors.regions.IfMakerHelper.makeIfInfo;
import static jadx.core.dex.visitors.regions.IfMakerHelper.mergeNestedIfNodes; import static jadx.core.dex.visitors.regions.IfMakerHelper.mergeNestedIfNodes;
import static jadx.core.dex.visitors.regions.IfMakerHelper.searchNestedIf;
import static jadx.core.utils.BlockUtils.getBlockByOffset; import static jadx.core.utils.BlockUtils.getBlockByOffset;
import static jadx.core.utils.BlockUtils.getNextBlock; import static jadx.core.utils.BlockUtils.getNextBlock;
import static jadx.core.utils.BlockUtils.isPathExists; import static jadx.core.utils.BlockUtils.isPathExists;
...@@ -169,11 +171,9 @@ public class RegionMaker { ...@@ -169,11 +171,9 @@ public class RegionMaker {
IRegion outerRegion = stack.peekRegion(); IRegion outerRegion = stack.peekRegion();
stack.push(loopRegion); stack.push(loopRegion);
IfInfo info = makeIfInfo(loopRegion.getHeader()); IfInfo condInfo = makeIfInfo(loopRegion.getHeader());
IfInfo condInfo = mergeNestedIfNodes(info); condInfo = searchNestedIf(condInfo);
if (condInfo == null) { confirmMerge(condInfo);
condInfo = info;
}
if (!loop.getLoopBlocks().contains(condInfo.getThenBlock())) { if (!loop.getLoopBlocks().contains(condInfo.getThenBlock())) {
// invert loop condition if 'then' points to exit // invert loop condition if 'then' points to exit
condInfo = IfInfo.invert(condInfo); condInfo = IfInfo.invert(condInfo);
...@@ -450,6 +450,7 @@ public class RegionMaker { ...@@ -450,6 +450,7 @@ public class RegionMaker {
return null; return null;
} }
} }
confirmMerge(currentIf);
IfRegion ifRegion = new IfRegion(currentRegion, block); IfRegion ifRegion = new IfRegion(currentRegion, block);
ifRegion.setCondition(currentIf.getCondition()); ifRegion.setCondition(currentIf.getCondition());
...@@ -626,15 +627,6 @@ public class RegionMaker { ...@@ -626,15 +627,6 @@ public class RegionMaker {
handler.getHandlerRegion().addAttr(excHandlerAttr); handler.getHandlerRegion().addAttr(excHandlerAttr);
} }
static void skipSimplePath(BlockNode block) {
while (block != null
&& block.getCleanSuccessors().size() < 2
&& block.getPredecessors().size() == 1) {
block.add(AFlag.SKIP);
block = getNextBlock(block);
}
}
static boolean isEqualPaths(BlockNode b1, BlockNode b2) { static boolean isEqualPaths(BlockNode b1, BlockNode b2) {
if (b1 == b2) { if (b1 == b2) {
return true; return true;
......
package jadx.tests.internal.conditions;
import jadx.api.InternalJadxTest;
import jadx.core.dex.nodes.ClassNode;
import org.junit.Test;
import static jadx.tests.utils.JadxMatchers.containsOne;
import static jadx.tests.utils.JadxMatchers.countString;
import static org.junit.Assert.assertThat;
public class TestNestedIf extends InternalJadxTest {
public static class TestCls {
private boolean a0 = false;
private int a1 = 1;
private int a2 = 2;
private int a3 = 1;
private int a4 = 2;
public boolean test1() {
if (a0) {
if (a1 == 0 || a2 == 0) {
return false;
}
} else if (a3 == 0 || a4 == 0) {
return false;
}
test1();
return true;
}
}
@Test
public void test() {
ClassNode cls = getClassNode(TestCls.class);
String code = cls.getCode().toString();
System.out.println(code);
assertThat(code, containsOne("if (this.a0) {"));
assertThat(code, containsOne("if (this.a1 == 0 || this.a2 == 0) {"));
assertThat(code, containsOne("} else if (this.a3 == 0 || this.a4 == 0) {"));
assertThat(code, countString(2, "return false;"));
assertThat(code, containsOne("test1();"));
assertThat(code, containsOne("return true;"));
}
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment