Commit 1a85fa8e authored by Skylot's avatar Skylot

core: fix complex conditions with mode alternation (fix #31)

parent c7b8508c
...@@ -19,10 +19,6 @@ public final class IfInfo { ...@@ -19,10 +19,6 @@ public final class IfInfo {
this(condition, thenBlock, elseBlock, new HashSet<BlockNode>(), new HashSet<BlockNode>()); this(condition, thenBlock, elseBlock, new HashSet<BlockNode>(), new HashSet<BlockNode>());
} }
public IfInfo(IfCondition condition, IfInfo info) {
this(condition, info.getThenBlock(), info.getElseBlock(), info.getMergedBlocks(), info.getSkipBlocks());
}
public IfInfo(IfInfo info, BlockNode thenBlock, BlockNode elseBlock) { public IfInfo(IfInfo info, BlockNode thenBlock, BlockNode elseBlock) {
this(info.getCondition(), thenBlock, elseBlock, info.getMergedBlocks(), info.getSkipBlocks()); this(info.getCondition(), thenBlock, elseBlock, info.getMergedBlocks(), info.getSkipBlocks());
} }
...@@ -90,6 +86,6 @@ public final class IfInfo { ...@@ -90,6 +86,6 @@ public final class IfInfo {
@Override @Override
public String toString() { public String toString() {
return "IfInfo: " + condition + ", then: " + thenBlock + ", else: " + elseBlock; return "IfInfo: then: " + thenBlock + ", else: " + elseBlock;
} }
} }
...@@ -156,7 +156,7 @@ public final class LoopRegion extends AbstractRegion { ...@@ -156,7 +156,7 @@ public final class LoopRegion extends AbstractRegion {
@Override @Override
public String baseString() { public String baseString() {
return body.baseString(); return body == null ? "-" : body.baseString();
} }
@Override @Override
......
...@@ -30,7 +30,7 @@ public class CheckRegions extends AbstractVisitor { ...@@ -30,7 +30,7 @@ public class CheckRegions extends AbstractVisitor {
return; return;
} }
// printRegion(mth, mth.getRegion(), "|"); // printRegion(mth);
// check if all blocks included in regions // check if all blocks included in regions
final Set<BlockNode> blocksInRegions = new HashSet<BlockNode>(); final Set<BlockNode> blocksInRegions = new HashSet<BlockNode>();
...@@ -93,6 +93,11 @@ public class CheckRegions extends AbstractVisitor { ...@@ -93,6 +93,11 @@ public class CheckRegions extends AbstractVisitor {
LOG.debug(" Found block: {} in regions: {}", block, regions); LOG.debug(" Found block: {} in regions: {}", block, regions);
} }
private void printRegion(MethodNode mth) {
LOG.debug("|" + mth.toString());
printRegion(mth, mth.getRegion(), "| ");
}
private void printRegion(MethodNode mth, IRegion region, String indent) { private void printRegion(MethodNode mth, IRegion region, String indent) {
LOG.debug(indent + region); LOG.debug(indent + region);
for (IContainer container : region.getSubBlocks()) { for (IContainer container : region.getSubBlocks()) {
......
...@@ -14,6 +14,7 @@ import jadx.core.dex.regions.conditions.IfCondition; ...@@ -14,6 +14,7 @@ import jadx.core.dex.regions.conditions.IfCondition;
import jadx.core.dex.regions.conditions.IfCondition.Mode; import jadx.core.dex.regions.conditions.IfCondition.Mode;
import jadx.core.dex.regions.conditions.IfInfo; import jadx.core.dex.regions.conditions.IfInfo;
import jadx.core.utils.BlockUtils; import jadx.core.utils.BlockUtils;
import jadx.core.utils.exceptions.JadxRuntimeException;
import java.util.Collection; import java.util.Collection;
import java.util.List; import java.util.List;
...@@ -22,6 +23,8 @@ import java.util.Set; ...@@ -22,6 +23,8 @@ import java.util.Set;
import org.slf4j.Logger; import org.slf4j.Logger;
import org.slf4j.LoggerFactory; import org.slf4j.LoggerFactory;
import static jadx.core.dex.visitors.regions.RegionMaker.isEqualPaths;
import static jadx.core.dex.visitors.regions.RegionMaker.isReturnBlocks;
import static jadx.core.utils.BlockUtils.getNextBlock; import static jadx.core.utils.BlockUtils.getNextBlock;
import static jadx.core.utils.BlockUtils.isPathExists; import static jadx.core.utils.BlockUtils.isPathExists;
...@@ -117,6 +120,7 @@ public class IfMakerHelper { ...@@ -117,6 +120,7 @@ public class IfMakerHelper {
List<BlockNode> preds = block.getPredecessors(); List<BlockNode> preds = block.getPredecessors();
Set<BlockNode> ifBlocks = info.getMergedBlocks(); Set<BlockNode> ifBlocks = info.getMergedBlocks();
for (BlockNode pred : preds) { for (BlockNode pred : preds) {
pred = BlockUtils.skipSyntheticPredecessor(pred);
if (!ifBlocks.contains(pred) && !pred.contains(AFlag.LOOP_END)) { if (!ifBlocks.contains(pred) && !pred.contains(AFlag.LOOP_END)) {
return false; return false;
} }
...@@ -150,22 +154,18 @@ public class IfMakerHelper { ...@@ -150,22 +154,18 @@ public class IfMakerHelper {
// invert current node for match pattern // invert current node for match pattern
nextIf = IfInfo.invert(nextIf); nextIf = IfInfo.invert(nextIf);
} }
if (!RegionMaker.isEqualPaths(curElse, nextIf.getElseBlock()) if (!isEqualPaths(curThen, nextIf.getThenBlock())
&& !RegionMaker.isEqualPaths(curThen, nextIf.getThenBlock())) { && !isEqualPaths(curElse, nextIf.getElseBlock())) {
// complex condition, run additional checks // complex condition, run additional checks
if (checkConditionBranches(curThen, curElse) if (checkConditionBranches(curThen, curElse)
|| checkConditionBranches(curElse, curThen)) { || checkConditionBranches(curElse, curThen)) {
return null; return null;
} }
BlockNode otherBranchBlock = followThenBranch ? curElse : curThen; BlockNode otherBranchBlock = followThenBranch ? curElse : curThen;
otherBranchBlock = BlockUtils.skipSyntheticSuccessor(otherBranchBlock);
if (!isPathExists(nextIf.getIfBlock(), otherBranchBlock)) { if (!isPathExists(nextIf.getIfBlock(), otherBranchBlock)) {
return checkForTernaryInCondition(currentIf); return checkForTernaryInCondition(currentIf);
} }
if (isPathExists(nextIf.getThenBlock(), otherBranchBlock)
&& isPathExists(nextIf.getElseBlock(), otherBranchBlock)) {
// both branches paths points to one block
return null;
}
// this is nested conditions with different mode (i.e (a && b) || c), // this is nested conditions with different mode (i.e (a && b) || c),
// search next condition for merge, get null if failed // search next condition for merge, get null if failed
...@@ -175,6 +175,9 @@ public class IfMakerHelper { ...@@ -175,6 +175,9 @@ public class IfMakerHelper {
if (isInversionNeeded(currentIf, nextIf)) { if (isInversionNeeded(currentIf, nextIf)) {
nextIf = IfInfo.invert(nextIf); nextIf = IfInfo.invert(nextIf);
} }
if (!canMerge(currentIf, nextIf, followThenBranch)) {
return currentIf;
}
} else { } else {
return currentIf; return currentIf;
} }
...@@ -219,8 +222,16 @@ public class IfMakerHelper { ...@@ -219,8 +222,16 @@ public class IfMakerHelper {
} }
private static boolean isInversionNeeded(IfInfo currentIf, IfInfo nextIf) { private static boolean isInversionNeeded(IfInfo currentIf, IfInfo nextIf) {
return RegionMaker.isEqualPaths(currentIf.getElseBlock(), nextIf.getThenBlock()) return isEqualPaths(currentIf.getElseBlock(), nextIf.getThenBlock())
|| RegionMaker.isEqualPaths(currentIf.getThenBlock(), nextIf.getElseBlock()); || isEqualPaths(currentIf.getThenBlock(), nextIf.getElseBlock());
}
private static boolean canMerge(IfInfo a, IfInfo b, boolean followThenBranch) {
if (followThenBranch) {
return isEqualPaths(a.getElseBlock(), b.getElseBlock());
} else {
return isEqualPaths(a.getThenBlock(), b.getThenBlock());
}
} }
private static boolean checkConditionBranches(BlockNode from, BlockNode to) { private static boolean checkConditionBranches(BlockNode from, BlockNode to) {
...@@ -231,7 +242,17 @@ public class IfMakerHelper { ...@@ -231,7 +242,17 @@ public class IfMakerHelper {
Mode mergeOperation = followThenBranch ? Mode.AND : Mode.OR; Mode mergeOperation = followThenBranch ? Mode.AND : Mode.OR;
IfCondition condition = IfCondition.merge(mergeOperation, first.getCondition(), second.getCondition()); IfCondition condition = IfCondition.merge(mergeOperation, first.getCondition(), second.getCondition());
IfInfo result = new IfInfo(condition, second); // skip synthetic successor if both parts leads to same block
BlockNode thenBlock;
BlockNode elseBlock;
if (followThenBranch) {
thenBlock = second.getThenBlock();
elseBlock = getCrossBlock(first.getElseBlock(), second.getElseBlock());
} else {
thenBlock = getCrossBlock(first.getThenBlock(), second.getThenBlock());
elseBlock = second.getElseBlock();
}
IfInfo result = new IfInfo(condition, thenBlock, elseBlock);
result.setIfBlock(first.getIfBlock()); result.setIfBlock(first.getIfBlock());
result.merge(first, second); result.merge(first, second);
...@@ -240,6 +261,25 @@ public class IfMakerHelper { ...@@ -240,6 +261,25 @@ public class IfMakerHelper {
return result; return result;
} }
private static BlockNode getCrossBlock(BlockNode first, BlockNode second) {
if (isSameBlocks(first, second)) {
return second;
}
BlockNode firstSkip = BlockUtils.skipSyntheticSuccessor(first);
if (isSameBlocks(firstSkip, second)) {
return second;
}
BlockNode secondSkip = BlockUtils.skipSyntheticSuccessor(second);
if (isSameBlocks(firstSkip, secondSkip) || isSameBlocks(first, secondSkip)) {
return secondSkip;
}
throw new JadxRuntimeException("Unexpected merge pattern");
}
private static boolean isSameBlocks(BlockNode first, BlockNode second) {
return first == second || isReturnBlocks(first, second);
}
static void confirmMerge(IfInfo info) { static void confirmMerge(IfInfo info) {
if (info.getMergedBlocks().size() > 1) { if (info.getMergedBlocks().size() > 1) {
for (BlockNode block : info.getMergedBlocks()) { for (BlockNode block : info.getMergedBlocks()) {
......
...@@ -49,6 +49,7 @@ import static jadx.core.dex.visitors.regions.IfMakerHelper.searchNestedIf; ...@@ -49,6 +49,7 @@ import static jadx.core.dex.visitors.regions.IfMakerHelper.searchNestedIf;
import static jadx.core.utils.BlockUtils.getBlockByOffset; import static jadx.core.utils.BlockUtils.getBlockByOffset;
import static jadx.core.utils.BlockUtils.getNextBlock; import static jadx.core.utils.BlockUtils.getNextBlock;
import static jadx.core.utils.BlockUtils.isPathExists; import static jadx.core.utils.BlockUtils.isPathExists;
import static jadx.core.utils.BlockUtils.skipSyntheticSuccessor;
public class RegionMaker { public class RegionMaker {
private static final Logger LOG = LoggerFactory.getLogger(RegionMaker.class); private static final Logger LOG = LoggerFactory.getLogger(RegionMaker.class);
...@@ -811,15 +812,12 @@ public class RegionMaker { ...@@ -811,15 +812,12 @@ public class RegionMaker {
} }
private static boolean isSyntheticPath(BlockNode b1, BlockNode b2) { private static boolean isSyntheticPath(BlockNode b1, BlockNode b2) {
if (!b1.isSynthetic() || !b2.isSynthetic()) { BlockNode n1 = skipSyntheticSuccessor(b1);
return false; BlockNode n2 = skipSyntheticSuccessor(b2);
} return (n1 != b1 || n2 != b2) && isEqualPaths(n1, n2);
BlockNode n1 = getNextBlock(b1);
BlockNode n2 = getNextBlock(b2);
return isEqualPaths(n1, n2);
} }
private static boolean isReturnBlocks(BlockNode b1, BlockNode b2) { public static boolean isReturnBlocks(BlockNode b1, BlockNode b2) {
if (!b1.isReturnBlock() || !b2.isReturnBlock()) { if (!b1.isReturnBlock() || !b2.isReturnBlock()) {
return false; return false;
} }
...@@ -830,9 +828,9 @@ public class RegionMaker { ...@@ -830,9 +828,9 @@ public class RegionMaker {
} }
InsnNode i1 = b1Insns.get(0); InsnNode i1 = b1Insns.get(0);
InsnNode i2 = b2Insns.get(0); InsnNode i2 = b2Insns.get(0);
if (i1.getArgsCount() == 0 || i2.getArgsCount() == 0) { if (i1.getArgsCount() != i2.getArgsCount()) {
return false; return false;
} }
return i1.getArg(0).equals(i2.getArg(0)); return i1.getArgsCount() == 0 || i1.getArg(0).equals(i2.getArg(0));
} }
} }
...@@ -452,9 +452,19 @@ public class BlockUtils { ...@@ -452,9 +452,19 @@ public class BlockUtils {
* Return successor of synthetic block or same block otherwise. * Return successor of synthetic block or same block otherwise.
*/ */
public static BlockNode skipSyntheticSuccessor(BlockNode block) { public static BlockNode skipSyntheticSuccessor(BlockNode block) {
if (block.isSynthetic() && !block.getSuccessors().isEmpty()) { if (block.isSynthetic() && block.getSuccessors().size() == 1) {
return block.getSuccessors().get(0); return block.getSuccessors().get(0);
} }
return block; return block;
} }
/**
* Return predecessor of synthetic block or same block otherwise.
*/
public static BlockNode skipSyntheticPredecessor(BlockNode block) {
if (block.isSynthetic() && block.getPredecessors().size() == 1) {
return block.getPredecessors().get(0);
}
return block;
}
} }
package jadx.tests.integration.conditions;
import jadx.core.dex.nodes.ClassNode;
import jadx.tests.api.IntegrationTest;
import org.junit.Test;
import static jadx.tests.api.utils.JadxMatchers.containsOne;
import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertThat;
import static org.junit.Assert.assertTrue;
public class TestConditions16 extends IntegrationTest {
public static class TestCls {
private static boolean test(int a, int b) {
return a < 0 || b % 2 != 0 && a > 28 || b < 0;
}
public void check() {
assertTrue(test(-1, 1));
assertTrue(test(1, -1));
assertTrue(test(29, 3));
assertFalse(test(2, 2));
}
}
@Test
public void test() {
ClassNode cls = getClassNode(TestCls.class);
String code = cls.getCode().toString();
// assertThat(code, containsOne("return a < 0 || (b % 2 != 0 && a > 28) || b < 0;"));
assertThat(code, containsOne("return a < 0 || ((b % 2 != 0 && a > 28) || b < 0);"));
}
}
package jadx.tests.integration.loops;
import jadx.core.dex.nodes.ClassNode;
import jadx.tests.api.IntegrationTest;
import org.junit.Test;
import static jadx.tests.api.utils.JadxMatchers.containsOne;
import static org.junit.Assert.assertThat;
import static org.junit.Assert.assertTrue;
public class TestIfInLoop3 extends IntegrationTest {
public static class TestCls {
static boolean[][] occupied = new boolean[70][70];
static boolean placingStone = true;
private static boolean test(int xx, int yy) {
int[] extraArray = new int[]{10, 45, 50, 50, 20, 20};
if (extraArray != null && placingStone) {
for (int i = 0; i < extraArray.length; i += 2) {
int tX;
int tY;
if (yy % 2 == 0) {
if (extraArray[i + 1] % 2 == 0) {
tX = xx + extraArray[i];
} else {
tX = extraArray[i] + xx - 1;
}
tY = yy + extraArray[i + 1];
} else {
tX = xx + extraArray[i];
tY = yy + extraArray[i + 1];
}
if (tX < 0 || tY < 0 || tY % 2 != 0 && tX > 28 || tY > 70
|| occupied[tX][tY]) {
return false;
}
}
}
return true;
}
public void check() {
assertTrue(test(14, 2));
}
}
@Test
public void test() {
ClassNode cls = getClassNode(TestCls.class);
String code = cls.getCode().toString();
assertThat(code, containsOne("for (int i = 0; i < extraArray.length; i += 2) {"));
assertThat(code, containsOne("if (extraArray != null && placingStone) {"));
}
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment