Commit f31c2dcd authored by Skylot's avatar Skylot

core: fix processing 'if' at loop end

parent 7699cfac
...@@ -72,7 +72,7 @@ public class AttributeStorage { ...@@ -72,7 +72,7 @@ public class AttributeStorage {
if (attrList == null) { if (attrList == null) {
return Collections.emptyList(); return Collections.emptyList();
} }
return attrList.getList(); return Collections.unmodifiableList(attrList.getList());
} }
public void remove(AFlag flag) { public void remove(AFlag flag) {
......
...@@ -2,6 +2,7 @@ package jadx.core.dex.visitors; ...@@ -2,6 +2,7 @@ package jadx.core.dex.visitors;
import jadx.core.dex.attributes.AFlag; import jadx.core.dex.attributes.AFlag;
import jadx.core.dex.instructions.InsnType; import jadx.core.dex.instructions.InsnType;
import jadx.core.dex.instructions.PhiInsn;
import jadx.core.dex.instructions.args.InsnArg; import jadx.core.dex.instructions.args.InsnArg;
import jadx.core.dex.instructions.args.InsnWrapArg; import jadx.core.dex.instructions.args.InsnWrapArg;
import jadx.core.dex.instructions.args.RegisterArg; import jadx.core.dex.instructions.args.RegisterArg;
...@@ -194,15 +195,12 @@ public class CodeShrinker extends AbstractVisitor { ...@@ -194,15 +195,12 @@ public class CodeShrinker extends AbstractVisitor {
// continue; // continue;
// } // }
SSAVar sVar = arg.getSVar(); SSAVar sVar = arg.getSVar();
if (sVar.getAssign() == null) {
continue;
}
// allow inline only one use arg or 'this' // allow inline only one use arg or 'this'
if (sVar.getVariableUseCount() != 1 && !arg.isThis()) { if (sVar.getVariableUseCount() != 1 && !arg.isThis()) {
continue; continue;
} }
InsnNode assignInsn = sVar.getAssign().getParentInsn(); InsnNode assignInsn = sVar.getAssign().getParentInsn();
if (assignInsn == null) { if (assignInsn == null || assignInsn instanceof PhiInsn) {
continue; continue;
} }
int assignPos = insnList.getIndex(assignInsn); int assignPos = insnList.getIndex(assignInsn);
......
...@@ -2,6 +2,7 @@ package jadx.core.dex.visitors.regions; ...@@ -2,6 +2,7 @@ package jadx.core.dex.visitors.regions;
import jadx.core.dex.attributes.AFlag; import jadx.core.dex.attributes.AFlag;
import jadx.core.dex.attributes.AType; import jadx.core.dex.attributes.AType;
import jadx.core.dex.attributes.nodes.LoopInfo;
import jadx.core.dex.instructions.IfNode; import jadx.core.dex.instructions.IfNode;
import jadx.core.dex.instructions.InsnType; import jadx.core.dex.instructions.InsnType;
import jadx.core.dex.instructions.args.InsnArg; import jadx.core.dex.instructions.args.InsnArg;
...@@ -53,8 +54,8 @@ public class IfMakerHelper { ...@@ -53,8 +54,8 @@ public class IfMakerHelper {
info.setOutBlock(null); info.setOutBlock(null);
return info; return info;
} }
boolean badThen = thenBlock.contains(AFlag.LOOP_START) || !allPathsFromIf(thenBlock, info); boolean badThen = isBadBranchBlock(info, thenBlock);
boolean badElse = elseBlock.contains(AFlag.LOOP_START) || !allPathsFromIf(elseBlock, info); boolean badElse = isBadBranchBlock(info, elseBlock);
if (badThen && badElse) { if (badThen && badElse) {
LOG.debug("Stop processing blocks after 'if': {}, method: {}", info.getIfBlock(), mth); LOG.debug("Stop processing blocks after 'if': {}, method: {}", info.getIfBlock(), mth);
return null; return null;
...@@ -92,6 +93,26 @@ public class IfMakerHelper { ...@@ -92,6 +93,26 @@ public class IfMakerHelper {
return info; return info;
} }
private static boolean isBadBranchBlock(IfInfo info, BlockNode block) {
// check if block at end of loop edge
if (block.contains(AFlag.LOOP_START) && block.getPredecessors().size() == 1) {
BlockNode pred = block.getPredecessors().get(0);
if (pred.contains(AFlag.LOOP_END)) {
List<LoopInfo> startLoops = block.getAll(AType.LOOP);
List<LoopInfo> endLoops = pred.getAll(AType.LOOP);
// search for same loop
for (LoopInfo startLoop : startLoops) {
for (LoopInfo endLoop : endLoops) {
if (startLoop == endLoop) {
return true;
}
}
}
}
}
return !allPathsFromIf(block, info);
}
private static boolean allPathsFromIf(BlockNode block, IfInfo info) { private static boolean allPathsFromIf(BlockNode block, IfInfo info) {
List<BlockNode> preds = block.getPredecessors(); List<BlockNode> preds = block.getPredecessors();
Set<BlockNode> ifBlocks = info.getMergedBlocks(); Set<BlockNode> ifBlocks = info.getMergedBlocks();
......
...@@ -84,7 +84,8 @@ public class LoopRegionVisitor extends AbstractVisitor implements IRegionVisitor ...@@ -84,7 +84,8 @@ public class LoopRegionVisitor extends AbstractVisitor implements IRegionVisitor
return false; return false;
} }
PhiInsn phiInsn = incrArg.getSVar().getUsedInPhi(); PhiInsn phiInsn = incrArg.getSVar().getUsedInPhi();
if (phiInsn.getArgsCount() != 2 if (phiInsn == null
|| phiInsn.getArgsCount() != 2
|| !phiInsn.getArg(1).equals(incrArg) || !phiInsn.getArg(1).equals(incrArg)
|| incrArg.getSVar().getUseCount() != 1) { || incrArg.getSVar().getUseCount() != 1) {
return false; return false;
...@@ -102,6 +103,12 @@ public class LoopRegionVisitor extends AbstractVisitor implements IRegionVisitor ...@@ -102,6 +103,12 @@ public class LoopRegionVisitor extends AbstractVisitor implements IRegionVisitor
if (!usedOnlyInLoop(mth, loopRegion, arg)) { if (!usedOnlyInLoop(mth, loopRegion, arg)) {
return false; return false;
} }
// can't make loop if argument from increment instruction is assign in loop
for (InsnArg iArg : incrInsn.getArguments()) {
if (iArg.isRegister() && assignOnlyInLoop(mth, loopRegion, (RegisterArg) iArg)) {
return false;
}
}
// all checks passed // all checks passed
initInsn.add(AFlag.SKIP); initInsn.add(AFlag.SKIP);
...@@ -188,7 +195,7 @@ public class LoopRegionVisitor extends AbstractVisitor implements IRegionVisitor ...@@ -188,7 +195,7 @@ public class LoopRegionVisitor extends AbstractVisitor implements IRegionVisitor
if (wrapArg != null) { if (wrapArg != null) {
wrapArg.getParentInsn().replaceArg(wrapArg, iterVar); wrapArg.getParentInsn().replaceArg(wrapArg, iterVar);
} else { } else {
LOG.debug(" Wrapped insn not found: {}, mth: {}", arrGetInsn, mth); LOG.debug(" checkArrayForEach: Wrapped insn not found: {}, mth: {}", arrGetInsn, mth);
} }
} }
return new ForEachLoop(iterVar, len.getArg(0)); return new ForEachLoop(iterVar, len.getArg(0));
...@@ -237,7 +244,7 @@ public class LoopRegionVisitor extends AbstractVisitor implements IRegionVisitor ...@@ -237,7 +244,7 @@ public class LoopRegionVisitor extends AbstractVisitor implements IRegionVisitor
} }
} }
} else { } else {
LOG.warn(" Wrapped insn not found: {}, mth: {}", nextCall, mth); LOG.warn(" checkIterableForEach: Wrapped insn not found: {}, mth: {}", nextCall, mth);
return false; return false;
} }
} else { } else {
...@@ -295,6 +302,25 @@ public class LoopRegionVisitor extends AbstractVisitor implements IRegionVisitor ...@@ -295,6 +302,25 @@ public class LoopRegionVisitor extends AbstractVisitor implements IRegionVisitor
return false; return false;
} }
private static boolean assignOnlyInLoop(MethodNode mth, LoopRegion loopRegion, RegisterArg arg) {
InsnNode assignInsn = arg.getAssignInsn();
if (assignInsn == null) {
return true;
}
if (!argInLoop(mth, loopRegion, assignInsn.getResult())) {
return false;
}
if (assignInsn instanceof PhiInsn) {
PhiInsn phiInsn = (PhiInsn) assignInsn;
for (InsnArg phiArg : phiInsn.getArguments()) {
if (!assignOnlyInLoop(mth, loopRegion, (RegisterArg) phiArg)) {
return false;
}
}
}
return true;
}
private static boolean usedOnlyInLoop(MethodNode mth, LoopRegion loopRegion, RegisterArg arg) { private static boolean usedOnlyInLoop(MethodNode mth, LoopRegion loopRegion, RegisterArg arg) {
List<RegisterArg> useList = arg.getSVar().getUseList(); List<RegisterArg> useList = arg.getSVar().getUseList();
for (RegisterArg useArg : useList) { for (RegisterArg useArg : useList) {
......
...@@ -2,8 +2,10 @@ package jadx.core.utils; ...@@ -2,8 +2,10 @@ package jadx.core.utils;
import jadx.core.dex.attributes.AFlag; import jadx.core.dex.attributes.AFlag;
import jadx.core.dex.attributes.AType; import jadx.core.dex.attributes.AType;
import jadx.core.dex.attributes.nodes.PhiListAttr;
import jadx.core.dex.instructions.IfNode; import jadx.core.dex.instructions.IfNode;
import jadx.core.dex.instructions.InsnType; import jadx.core.dex.instructions.InsnType;
import jadx.core.dex.instructions.PhiInsn;
import jadx.core.dex.instructions.args.InsnArg; import jadx.core.dex.instructions.args.InsnArg;
import jadx.core.dex.instructions.args.InsnWrapArg; import jadx.core.dex.instructions.args.InsnWrapArg;
import jadx.core.dex.instructions.mods.TernaryInsn; import jadx.core.dex.instructions.mods.TernaryInsn;
...@@ -119,6 +121,9 @@ public class BlockUtils { ...@@ -119,6 +121,9 @@ public class BlockUtils {
} }
public static BlockNode getBlockByInsn(MethodNode mth, InsnNode insn) { public static BlockNode getBlockByInsn(MethodNode mth, InsnNode insn) {
if (insn instanceof PhiInsn) {
return searchBlockWithPhi(mth, (PhiInsn) insn);
}
if (insn.contains(AFlag.WRAPPED)) { if (insn.contains(AFlag.WRAPPED)) {
return getBlockByWrappedInsn(mth, insn); return getBlockByWrappedInsn(mth, insn);
} }
...@@ -130,6 +135,20 @@ public class BlockUtils { ...@@ -130,6 +135,20 @@ public class BlockUtils {
return null; return null;
} }
private static BlockNode searchBlockWithPhi(MethodNode mth, PhiInsn insn) {
for (BlockNode block : mth.getBasicBlocks()) {
PhiListAttr phiListAttr = block.get(AType.PHI_LIST);
if (phiListAttr != null) {
for (PhiInsn phiInsn : phiListAttr.getList()) {
if (phiInsn == insn) {
return block;
}
}
}
}
return null;
}
private static BlockNode getBlockByWrappedInsn(MethodNode mth, InsnNode insn) { private static BlockNode getBlockByWrappedInsn(MethodNode mth, InsnNode insn) {
for (BlockNode bn : mth.getBasicBlocks()) { for (BlockNode bn : mth.getBasicBlocks()) {
for (InsnNode bi : bn.getInstructions()) { for (InsnNode bi : bn.getInstructions()) {
......
package jadx.tests.integration.loops;
import jadx.core.dex.nodes.ClassNode;
import jadx.tests.api.IntegrationTest;
import org.junit.Test;
import static org.hamcrest.CoreMatchers.containsString;
import static org.hamcrest.CoreMatchers.not;
import static org.junit.Assert.assertThat;
public class TestIfInLoop2 extends IntegrationTest {
public static class TestCls {
public static void test(String str) {
int len = str.length();
int at = 0;
while (at < len) {
char c = str.charAt(at);
int endAt = at + 1;
if (c == 'A') {
while (endAt < len) {
c = str.charAt(endAt);
if (c == 'B') {
break;
}
endAt++;
}
}
at = endAt;
}
}
}
@Test
public void test() {
ClassNode cls = getClassNode(TestCls.class);
String code = cls.getCode().toString();
assertThat(code, not(containsString("for (int at = 0; at < len; at = endAt) {")));
}
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment