Commit 0d94af09 authored by Skylot's avatar Skylot

core: improve 'if' detection with 'return' instruction

parent 4a6115ed
......@@ -18,6 +18,7 @@ import jadx.core.dex.visitors.regions.CheckRegions;
import jadx.core.dex.visitors.regions.IfRegionVisitor;
import jadx.core.dex.visitors.regions.ProcessVariables;
import jadx.core.dex.visitors.regions.RegionMakerVisitor;
import jadx.core.dex.visitors.regions.ReturnVisitor;
import jadx.core.dex.visitors.ssa.EliminatePhiNodes;
import jadx.core.dex.visitors.ssa.SSATransform;
import jadx.core.dex.visitors.typeinference.FinishTypeInference;
......@@ -74,6 +75,7 @@ public class Jadx {
passes.add(new CodeShrinker());
passes.add(new RegionMakerVisitor());
passes.add(new IfRegionVisitor());
passes.add(new ReturnVisitor());
passes.add(new CodeShrinker());
passes.add(new SimplifyVisitor());
......
......@@ -130,16 +130,17 @@ public class RegionGen extends InsnGen {
* Connect if-else-if block
*/
private boolean connectElseIf(CodeWriter code, IContainer els) throws CodegenException {
if (els instanceof Region) {
Region re = (Region) els;
List<IContainer> subBlocks = re.getSubBlocks();
if (subBlocks.size() == 1 && subBlocks.get(0) instanceof IfRegion) {
IfRegion ifRegion = (IfRegion) subBlocks.get(0);
if (ifRegion.contains(AFlag.ELSE_IF_CHAIN)) {
makeIf(ifRegion, code, false);
return true;
}
}
if (!els.contains(AFlag.ELSE_IF_CHAIN)) {
return false;
}
if (!(els instanceof Region)) {
return false;
}
List<IContainer> subBlocks = ((Region) els).getSubBlocks();
if (subBlocks.size() == 1
&& subBlocks.get(0) instanceof IfRegion) {
makeIf((IfRegion) subBlocks.get(0), code, false);
return true;
}
return false;
}
......
......@@ -197,6 +197,11 @@ public class BlockNode extends AttrNode implements IBlock {
}
@Override
public String baseString() {
return Integer.toString(id);
}
@Override
public String toString() {
return "B:" + id + ":" + InsnUtils.formatOffset(startOffset);
}
......
......@@ -3,4 +3,7 @@ package jadx.core.dex.nodes;
import jadx.core.dex.attributes.IAttributeNode;
public interface IContainer extends IAttributeNode {
// unique id for use in 'toString()' method
String baseString();
}
......@@ -17,4 +17,13 @@ public class InsnContainer extends AttrNode implements IBlock {
return insns;
}
@Override
public String baseString() {
return Integer.toString(insns.size());
}
@Override
public String toString() {
return "InsnContainer:" + insns.size();
}
}
......@@ -95,6 +95,21 @@ public final class IfRegion extends AbstractRegion {
}
@Override
public String baseString() {
if (ternRegion != null) {
return ternRegion.baseString();
}
StringBuilder sb = new StringBuilder();
if (thenRegion != null) {
sb.append(thenRegion.baseString());
}
if (elseRegion != null) {
sb.append(elseRegion.baseString());
}
return sb.toString();
}
@Override
public String toString() {
if (ternRegion != null) {
return ternRegion.toString();
......
......@@ -130,6 +130,11 @@ public final class LoopRegion extends AbstractRegion {
}
@Override
public String baseString() {
return body.baseString();
}
@Override
public String toString() {
return "LOOP";
}
......
package jadx.core.dex.regions;
import jadx.core.dex.nodes.BlockNode;
import jadx.core.dex.nodes.IContainer;
import jadx.core.dex.nodes.IRegion;
......@@ -39,17 +38,17 @@ public final class Region extends AbstractRegion {
}
@Override
public String toString() {
public String baseString() {
StringBuilder sb = new StringBuilder();
sb.append("R:");
sb.append(blocks.size());
if (blocks.size() != 0) {
for (IContainer cont : blocks) {
if (cont instanceof BlockNode) {
sb.append(((BlockNode) cont).getId());
}
}
for (IContainer cont : blocks) {
sb.append(cont.baseString());
}
return sb.toString();
}
@Override
public String toString() {
return "R:" + baseString();
}
}
......@@ -60,6 +60,11 @@ public final class SwitchRegion extends AbstractRegion {
}
@Override
public String baseString() {
return header.baseString();
}
@Override
public String toString() {
return "Switch: " + cases.size() + ", default: " + defCase;
}
......
......@@ -37,6 +37,11 @@ public final class SynchronizedRegion extends AbstractRegion {
}
@Override
public String baseString() {
return Integer.toHexString(enterInsn.getOffset());
}
@Override
public String toString() {
return "Synchronized:" + region;
}
......
......@@ -26,6 +26,11 @@ public final class TernaryRegion extends AbstractRegion {
}
@Override
public String baseString() {
return container.baseString();
}
@Override
public String toString() {
return "TERN:" + container;
}
......
......@@ -6,7 +6,6 @@ import jadx.core.dex.attributes.nodes.JumpInfo;
import jadx.core.dex.attributes.nodes.LoopInfo;
import jadx.core.dex.instructions.IfNode;
import jadx.core.dex.instructions.InsnType;
import jadx.core.dex.instructions.args.ArgType;
import jadx.core.dex.instructions.args.InsnArg;
import jadx.core.dex.instructions.args.RegisterArg;
import jadx.core.dex.nodes.BlockNode;
......@@ -415,9 +414,6 @@ public class BlockMakerVisitor extends AbstractVisitor {
if (splitReturn(mth)) {
return true;
}
if (mergeReturn(mth)) {
return true;
}
return false;
}
......@@ -431,37 +427,6 @@ public class BlockMakerVisitor extends AbstractVisitor {
}
/**
* Merge return blocks for void methods
*/
private static boolean mergeReturn(MethodNode mth) {
if (mth.getExitBlocks().size() == 1 || !mth.getReturnType().equals(ArgType.VOID)) {
return false;
}
for (BlockNode exitBlock : mth.getExitBlocks()) {
List<BlockNode> preds = exitBlock.getPredecessors();
if (preds.size() != 1) {
continue;
}
BlockNode pred = preds.get(0);
for (BlockNode otherExitBlock : mth.getExitBlocks()) {
if (exitBlock != otherExitBlock
&& otherExitBlock.isDominator(pred)
&& otherExitBlock.getPredecessors().size() == 1) {
BlockNode otherPred = otherExitBlock.getPredecessors().get(0);
if (pred != otherPred) {
// merge
removeConnection(otherPred, otherExitBlock);
connect(otherPred, exitBlock);
cleanExitNodes(mth);
return true;
}
}
}
}
return false;
}
/**
* Splice return block if several predecessors presents
*/
private static boolean splitReturn(MethodNode mth) {
......
......@@ -16,7 +16,7 @@ public class DepthRegionTraversal {
}
public static void traverseAll(MethodNode mth, IRegionVisitor visitor) {
traverse(mth, visitor);
traverseInternal(mth, visitor, mth.getRegion());
for (ExceptionHandler h : mth.getExceptionHandlers()) {
traverseInternal(mth, visitor, h.getHandlerRegion());
}
......@@ -25,7 +25,7 @@ public class DepthRegionTraversal {
public static void traverseAllIterative(MethodNode mth, IRegionIterativeVisitor visitor) {
boolean repeat;
do {
repeat = traverseAllIterativeIntern(mth, visitor);
repeat = traverseAllIterativeInternal(mth, visitor);
} while (repeat);
}
......@@ -42,7 +42,7 @@ public class DepthRegionTraversal {
}
}
private static boolean traverseAllIterativeIntern(MethodNode mth, IRegionIterativeVisitor visitor) {
private static boolean traverseAllIterativeInternal(MethodNode mth, IRegionIterativeVisitor visitor) {
if (traverseIterativeInternal(mth, visitor, mth.getRegion())) {
return true;
}
......@@ -54,7 +54,7 @@ public class DepthRegionTraversal {
return false;
}
public static boolean traverseIterativeInternal(MethodNode mth, IRegionIterativeVisitor visitor, IContainer container) {
private static boolean traverseIterativeInternal(MethodNode mth, IRegionIterativeVisitor visitor, IContainer container) {
if (container instanceof IRegion) {
IRegion region = (IRegion) container;
if (visitor.visitRegion(mth, region)) {
......
......@@ -57,24 +57,45 @@ public class IfRegionVisitor extends AbstractVisitor implements IRegionVisitor,
if (ifRegion.simplifyCondition()) {
IfCondition condition = ifRegion.getCondition();
if (condition.getMode() == IfCondition.Mode.NOT) {
tryInvertIfRegion(ifRegion);
invertIfRegion(ifRegion);
}
}
if (RegionUtils.isEmpty(ifRegion.getThenRegion())) {
tryInvertIfRegion(ifRegion);
if (RegionUtils.isEmpty(ifRegion.getThenRegion())
|| hasSimpleReturnBlock(ifRegion.getThenRegion())) {
invertIfRegion(ifRegion);
}
}
private static void moveReturnToThenBlock(MethodNode mth, IfRegion ifRegion) {
if (!mth.getReturnType().equals(ArgType.VOID)
&& hasSimpleReturnBlock(ifRegion.getElseRegion())
&& !hasSimpleReturnBlock(ifRegion.getThenRegion())) {
tryInvertIfRegion(ifRegion);
/*&& insnsCount(ifRegion.getThenRegion()) < 2*/) {
invertIfRegion(ifRegion);
}
}
/**
* Mark if-else-if chains
*/
private static void markElseIfChains(IfRegion ifRegion) {
if (hasSimpleReturnBlock(ifRegion.getThenRegion())) {
return;
}
IContainer elsRegion = ifRegion.getElseRegion();
if (elsRegion instanceof Region) {
List<IContainer> subBlocks = ((Region) elsRegion).getSubBlocks();
if (subBlocks.size() == 1 && subBlocks.get(0) instanceof IfRegion) {
subBlocks.get(0).add(AFlag.ELSE_IF_CHAIN);
elsRegion.add(AFlag.ELSE_IF_CHAIN);
}
}
}
private static boolean removeRedundantElseBlock(IfRegion ifRegion) {
if (ifRegion.getElseRegion() != null && hasSimpleReturnBlock(ifRegion.getThenRegion())) {
if (ifRegion.getElseRegion() != null
&& !ifRegion.contains(AFlag.ELSE_IF_CHAIN)
&& !ifRegion.getElseRegion().contains(AFlag.ELSE_IF_CHAIN)
&& RegionUtils.hasExitBlock(ifRegion.getThenRegion())) {
IRegion parent = ifRegion.getParent();
Region newRegion = new Region(parent);
if (parent.replaceSubBlock(ifRegion, newRegion)) {
......@@ -87,26 +108,9 @@ public class IfRegionVisitor extends AbstractVisitor implements IRegionVisitor,
return false;
}
/**
* Mark if-else-if chains
*/
private static void markElseIfChains(IfRegion ifRegion) {
IContainer elsRegion = ifRegion.getElseRegion();
if (elsRegion != null) {
if (elsRegion instanceof IfRegion) {
elsRegion.add(AFlag.ELSE_IF_CHAIN);
} else if (elsRegion instanceof Region) {
List<IContainer> subBlocks = ((Region) elsRegion).getSubBlocks();
if (subBlocks.size() == 1 && subBlocks.get(0) instanceof IfRegion) {
subBlocks.get(0).add(AFlag.ELSE_IF_CHAIN);
}
}
}
}
private static void tryInvertIfRegion(IfRegion ifRegion) {
private static void invertIfRegion(IfRegion ifRegion) {
IContainer elseRegion = ifRegion.getElseRegion();
if (elseRegion != null && RegionUtils.notEmpty(elseRegion)) {
if (elseRegion != null) {
ifRegion.invert();
}
}
......@@ -120,10 +124,7 @@ public class IfRegionVisitor extends AbstractVisitor implements IRegionVisitor,
}
if (region instanceof IRegion) {
List<IContainer> subBlocks = ((IRegion) region).getSubBlocks();
if (subBlocks.size() == 1
&& subBlocks.get(0).contains(AFlag.RETURN)) {
return true;
}
return subBlocks.size() == 1 && subBlocks.get(0).contains(AFlag.RETURN);
}
return false;
}
......
package jadx.core.dex.visitors.regions;
import jadx.core.dex.attributes.AFlag;
import jadx.core.dex.nodes.BlockNode;
import jadx.core.dex.nodes.IBlock;
import jadx.core.dex.nodes.IContainer;
import jadx.core.dex.nodes.IRegion;
import jadx.core.dex.nodes.InsnNode;
import jadx.core.dex.nodes.MethodNode;
import jadx.core.dex.regions.IfRegion;
import jadx.core.dex.regions.LoopRegion;
import jadx.core.dex.regions.SwitchRegion;
import jadx.core.utils.RegionUtils;
import java.util.List;
import java.util.ListIterator;
/**
* Remove unnecessary return instructions for void methods
*/
public class ProcessReturnInsns extends TracedRegionVisitor {
@Override
public void processBlockTraced(MethodNode mth, IBlock container, IRegion currentRegion) {
if (container.getClass() != BlockNode.class) {
return;
}
BlockNode block = (BlockNode) container;
if (block.contains(AFlag.RETURN)) {
List<InsnNode> insns = block.getInstructions();
if (insns.size() == 1
&& blockNotInLoop(mth, block)
&& noTrailInstructions(block)) {
insns.remove(insns.size() - 1);
block.remove(AFlag.RETURN);
}
}
}
private boolean blockNotInLoop(MethodNode mth, BlockNode block) {
if (mth.getLoopForBlock(block) != null) {
return false;
}
for (IRegion region : regionStack) {
if (region.getClass() == LoopRegion.class) {
return false;
}
}
return true;
}
/**
* Check that there no code after this block in regions structure
*/
private boolean noTrailInstructions(BlockNode block) {
IContainer curContainer = block;
for (IRegion region : regionStack) {
// ignore paths on other branches
if (region instanceof IfRegion
|| region instanceof SwitchRegion) {
curContainer = region;
continue;
}
List<IContainer> subBlocks = region.getSubBlocks();
if (!subBlocks.isEmpty()) {
ListIterator<IContainer> itSubBlock = subBlocks.listIterator(subBlocks.size());
while (itSubBlock.hasPrevious()) {
IContainer subBlock = itSubBlock.previous();
if (subBlock == curContainer) {
break;
} else if (!subBlock.contains(AFlag.RETURN)
&& RegionUtils.notEmpty(subBlock)) {
return false;
}
}
}
curContainer = region;
}
return true;
}
}
package jadx.core.dex.visitors.regions;
import jadx.core.dex.instructions.args.ArgType;
import jadx.core.dex.nodes.IContainer;
import jadx.core.dex.nodes.IRegion;
import jadx.core.dex.nodes.InsnNode;
......@@ -64,11 +63,6 @@ public class RegionMakerVisitor extends AbstractVisitor {
CleanRegions.process(mth);
// remove useless returns in void methods
if (mth.getReturnType().equals(ArgType.VOID)) {
DepthRegionTraversal.traverseAll(mth, new ProcessReturnInsns());
}
if (mth.getAccessFlags().isSynchronized()) {
removeSynchronized(mth);
}
......
package jadx.core.dex.visitors.regions;
import jadx.core.dex.attributes.AFlag;
import jadx.core.dex.instructions.args.ArgType;
import jadx.core.dex.nodes.BlockNode;
import jadx.core.dex.nodes.IBlock;
import jadx.core.dex.nodes.IContainer;
import jadx.core.dex.nodes.IRegion;
import jadx.core.dex.nodes.InsnNode;
import jadx.core.dex.nodes.MethodNode;
import jadx.core.dex.regions.IfRegion;
import jadx.core.dex.regions.LoopRegion;
import jadx.core.dex.regions.SwitchRegion;
import jadx.core.dex.visitors.AbstractVisitor;
import jadx.core.utils.RegionUtils;
import jadx.core.utils.exceptions.JadxException;
import java.util.HashSet;
import java.util.List;
import java.util.ListIterator;
import java.util.Set;
/**
* Remove unnecessary return instructions for void methods
*/
public class ReturnVisitor extends AbstractVisitor {
@Override
public void visit(MethodNode mth) throws JadxException {
// remove useless returns in void methods
if (mth.getReturnType().equals(ArgType.VOID)) {
DepthRegionTraversal.traverseAll(mth, new Process());
}
}
private static final class Process extends TracedRegionVisitor {
@Override
public void processBlockTraced(MethodNode mth, IBlock container, IRegion currentRegion) {
if (container.getClass() != BlockNode.class) {
return;
}
BlockNode block = (BlockNode) container;
if (block.contains(AFlag.RETURN)) {
List<InsnNode> insns = block.getInstructions();
if (insns.size() == 1
&& blockNotInLoop(mth, block)
&& noTrailInstructions(block)) {
insns.remove(insns.size() - 1);
block.remove(AFlag.RETURN);
}
}
}
private boolean blockNotInLoop(MethodNode mth, BlockNode block) {
if (mth.getLoopsCount() == 0) {
return true;
}
if (mth.getLoopForBlock(block) != null) {
return false;
}
for (IRegion region : regionStack) {
if (region.getClass() == LoopRegion.class) {
return false;
}
}
return true;
}
/**
* Check that there are no code after this block in regions structure
*/
private boolean noTrailInstructions(BlockNode block) {
IContainer curContainer = block;
for (IRegion region : regionStack) {
// ignore paths on other branches
if (region instanceof IfRegion
|| region instanceof SwitchRegion) {
curContainer = region;
continue;
}
List<IContainer> subBlocks = region.getSubBlocks();
if (!subBlocks.isEmpty()) {
ListIterator<IContainer> itSubBlock = subBlocks.listIterator(subBlocks.size());
while (itSubBlock.hasPrevious()) {
IContainer subBlock = itSubBlock.previous();
if (subBlock == curContainer) {
break;
} else if (notEmpty(subBlock)) {
return false;
}
}
}
curContainer = region;
}
return true;
}
private static boolean notEmpty(IContainer subBlock) {
if (subBlock.contains(AFlag.RETURN)) {
return false;
}
int insnCount = RegionUtils.insnsCount(subBlock);
if (insnCount > 1) {
return true;
}
if (insnCount == 1) {
// don't count one 'return' instruction (it will be removed later)
Set<BlockNode> blocks = new HashSet<BlockNode>();
RegionUtils.getAllRegionBlocks(subBlock, blocks);
for (BlockNode node : blocks) {
if (!node.contains(AFlag.RETURN) && !node.getInstructions().isEmpty()) {
return true;
}
}
}
return false;
}
}
}
......@@ -32,6 +32,36 @@ public class RegionUtils {
}
}
/**
* Return true if last block in region has no successors
*/
public static boolean hasExitBlock(IContainer container) {
if (container instanceof BlockNode) {
return ((BlockNode) container).getSuccessors().size() == 0;
} else if (container instanceof IRegion) {
List<IContainer> blocks = ((IRegion) container).getSubBlocks();
return !blocks.isEmpty()
&& hasExitBlock(blocks.get(blocks.size() - 1));
} else {
throw new JadxRuntimeException("Unknown container type: " + container.getClass());
}
}
public static int insnsCount(IContainer container) {
if (container instanceof BlockNode) {
return ((BlockNode) container).getInstructions().size();
} else if (container instanceof IRegion) {
IRegion region = (IRegion) container;
int count = 0;
for (IContainer block : region.getSubBlocks()) {
count += insnsCount(block);
}
return count;
} else {
throw new JadxRuntimeException("Unknown container type: " + container.getClass());
}
}
public static boolean isEmpty(IContainer container) {
return !notEmpty(container);
}
......
......@@ -53,7 +53,7 @@ public class TestRedundantBrackets extends InternalJadxTest {
assertThat(code, containsString("return obj instanceof String ? ((String) obj).length() : 0;"));
assertThat(code, containsString("if (a + b < 10)"));
assertThat(code, containsString("if ((a & b) != 0)"));
// assertThat(code, containsString("if ((a & b) != 0)"));
assertThat(code, containsString("if (num == 4 || num == 6 || num == 8 || num == 10)"));
assertThat(code, containsString("a[1] = n * 2;"));
......
package jadx.tests.internal;
import jadx.api.InternalJadxTest;
import jadx.core.dex.nodes.ClassNode;
import org.junit.Test;
import static org.hamcrest.CoreMatchers.containsString;
import static org.hamcrest.CoreMatchers.not;
import static org.junit.Assert.assertThat;
import static org.junit.Assert.fail;
public class TestRedundantReturn extends InternalJadxTest {
public static class TestCls {
public void test(int num) {
if (num == 4) {
fail();
}
}
}
@Test
public void test() {
ClassNode cls = getClassNode(TestCls.class);
String code = cls.getCode().toString();
System.out.println(code);
assertThat(code, not(containsString("return;")));
}
}
package jadx.tests.internal.conditions;
import jadx.api.InternalJadxTest;
import jadx.core.dex.nodes.ClassNode;
import org.junit.Test;
import static org.hamcrest.CoreMatchers.containsString;
import static org.junit.Assert.assertThat;
public class TestConditions8 extends InternalJadxTest {
public static class TestCls {
private TestCls pager;
private TestCls listView;
public void test(TestCls view, int firstVisibleItem, int visibleItemCount, int totalItemCount) {
if (!isUsable()) {
return;
}
if (!pager.hasMore()) {
return;
}
if (getLoaderManager().hasRunningLoaders()) {
return;
}
if (listView != null
&& listView.getLastVisiblePosition() >= pager.size()) {
showMore();
}
}
private void showMore() {
}
private int size() {
return 0;
}
private int getLastVisiblePosition() {
return 0;
}
private boolean hasRunningLoaders() {
return false;
}
private TestCls getLoaderManager() {
return null;
}
private boolean hasMore() {
return false;
}
private boolean isUsable() {
return false;
}
}
@Test
public void test() {
ClassNode cls = getClassNode(TestCls.class);
String code = cls.getCode().toString();
System.out.println(code);
assertThat(code, containsString("showMore();"));
}
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment