Commit 5cbf71bd authored by Skylot's avatar Skylot

core: remove unnecessary return instructions for void methods

parent a85d382e
......@@ -437,6 +437,10 @@ public class MethodNode extends LineAttrNode implements ILoadable {
return null;
}
public int getLoopsCount() {
return loops.size();
}
public ExceptionHandler addExceptionHandler(ExceptionHandler handler) {
if (exceptionHandlers == null) {
exceptionHandlers = new ArrayList<ExceptionHandler>(2);
......
package jadx.core.dex.visitors.regions;
import jadx.core.dex.instructions.InsnType;
import jadx.core.dex.instructions.args.ArgType;
import jadx.core.dex.nodes.BlockNode;
import jadx.core.dex.nodes.IContainer;
import jadx.core.dex.nodes.InsnNode;
import jadx.core.dex.nodes.MethodNode;
import jadx.core.dex.regions.Region;
import jadx.core.dex.visitors.AbstractVisitor;
import jadx.core.utils.exceptions.JadxException;
import java.util.List;
public class PostRegionVisitor extends AbstractVisitor {
@Override
public void visit(MethodNode mth) throws JadxException {
if (mth.isNoCode() || mth.getRegion() == null)
IContainer startRegion = mth.getRegion();
if (mth.isNoCode() || startRegion == null) {
return;
DepthRegionTraverser.traverse(mth, new MarkTryCatchRegions(mth), mth.getRegion());
DepthRegionTraverser.traverse(mth, new FinishRegions(), mth.getRegion());
removeReturn(mth);
}
/**
* Remove useless return at end
*/
private void removeReturn(MethodNode mth) {
if (!mth.getReturnType().equals(ArgType.VOID))
return;
if (!(mth.getRegion() instanceof Region))
return;
Region rootRegion = (Region) mth.getRegion();
if (rootRegion.getSubBlocks().isEmpty())
return;
IContainer lastCont = rootRegion.getSubBlocks().get(rootRegion.getSubBlocks().size() - 1);
if (lastCont instanceof BlockNode) {
BlockNode lastBlock = (BlockNode) lastCont;
List<InsnNode> insns = lastBlock.getInstructions();
int last = insns.size() - 1;
if (last >= 0
&& insns.get(last).getType() == InsnType.RETURN
&& insns.get(last).getArgsCount() == 0) {
insns.remove(last);
DepthRegionTraverser.traverse(mth, new ProcessTryCatchRegions(mth), startRegion);
if (mth.getLoopsCount() != 0) {
DepthRegionTraverser.traverse(mth, new ProcessLoopRegions(), startRegion);
}
if (mth.getReturnType().equals(ArgType.VOID)) {
DepthRegionTraverser.traverseAll(mth, new ProcessReturnInsns());
}
}
}
package jadx.core.dex.visitors.regions;
import jadx.core.dex.nodes.IRegion;
import jadx.core.dex.nodes.MethodNode;
import jadx.core.dex.regions.LoopRegion;
public class ProcessLoopRegions extends AbstractRegionVisitor {
@Override
public void enterRegion(MethodNode mth, IRegion region) {
if (region instanceof LoopRegion) {
LoopRegion loop = (LoopRegion) region;
loop.mergePreCondition();
}
}
}
package jadx.core.dex.visitors.regions;
import jadx.core.dex.attributes.AttributeFlag;
import jadx.core.dex.nodes.BlockNode;
import jadx.core.dex.nodes.IBlock;
import jadx.core.dex.nodes.IContainer;
import jadx.core.dex.nodes.IRegion;
import jadx.core.dex.nodes.InsnNode;
import jadx.core.dex.nodes.MethodNode;
import jadx.core.dex.regions.LoopRegion;
import jadx.core.utils.RegionUtils;
import java.util.Iterator;
import java.util.List;
import java.util.ListIterator;
/**
* Remove unnecessary return instructions for void methods
*/
public class ProcessReturnInsns extends TracedRegionVisitor {
public class FinishRegions extends TracedRegionVisitor {
@Override
public void processBlockTraced(MethodNode mth, IBlock container, IRegion currentRegion) {
if (container.getClass() != BlockNode.class)
if (container.getClass() != BlockNode.class) {
return;
BlockNode block = (BlockNode) container;
// remove last return in void functions
/*
}
BlockNode block = (BlockNode) container;
if (block.getCleanSuccessors().isEmpty()
&& mth.getReturnType().equals(ArgType.VOID)) {
if (block.getAttributes().contains(AttributeFlag.RETURN)) {
List<InsnNode> insns = block.getInstructions();
int lastIndex = insns.size() - 1;
if (lastIndex != -1) {
InsnNode last = insns.get(lastIndex);
if (last.getType() == InsnType.RETURN
&& blockNotInLoop(mth, block)) {
insns.remove(lastIndex);
}
if (insns.size() == 1
&& blockNotInLoop(mth, block)
&& noTrailInstructions(block)) {
insns.remove(insns.size() - 1);
}
}
*/
}
private boolean blockNotInLoop(MethodNode mth, BlockNode block) {
if (mth.getLoopForBlock(block) != null)
if (mth.getLoopForBlock(block) != null) {
return false;
for (Iterator<IRegion> it = regionStack.descendingIterator(); it.hasNext(); ) {
IRegion region = it.next();
if (region.getClass() == LoopRegion.class)
}
for (IRegion region : regionStack) {
if (region.getClass() == LoopRegion.class) {
return false;
}
}
return true;
}
@Override
public void leaveRegion(MethodNode mth, IRegion region) {
if (region instanceof LoopRegion) {
LoopRegion loop = (LoopRegion) region;
loop.mergePreCondition();
/**
* Check that there no code after this block in regions structure
*/
private boolean noTrailInstructions(BlockNode block) {
IContainer curContainer = block;
for (IRegion region : regionStack) {
List<IContainer> subBlocks = region.getSubBlocks();
if (!subBlocks.isEmpty()) {
ListIterator<IContainer> itSubBlock = subBlocks.listIterator(subBlocks.size());
while (itSubBlock.hasPrevious()) {
IContainer subBlock = itSubBlock.previous();
if (subBlock == curContainer) {
break;
} else if (RegionUtils.notEmpty(subBlock)) {
return false;
}
super.leaveRegion(mth, region);
}
}
curContainer = region;
}
return true;
}
}
......@@ -26,28 +26,31 @@ import org.slf4j.LoggerFactory;
/**
* Extract blocks to separate try/catch region
*/
public class MarkTryCatchRegions extends AbstractRegionVisitor {
private static final Logger LOG = LoggerFactory.getLogger(MarkTryCatchRegions.class);
public class ProcessTryCatchRegions extends AbstractRegionVisitor {
private static final Logger LOG = LoggerFactory.getLogger(ProcessTryCatchRegions.class);
private static final boolean DEBUG = false;
static {
if (DEBUG)
LOG.debug("Debug enabled for " + MarkTryCatchRegions.class);
if (DEBUG) {
LOG.debug("Debug enabled for " + ProcessTryCatchRegions.class);
}
}
private final Map<BlockNode, TryCatchBlock> tryBlocksMap = new HashMap<BlockNode, TryCatchBlock>(2);
public MarkTryCatchRegions(MethodNode mth) {
if (mth.isNoCode() || mth.getExceptionHandlers() == null)
public ProcessTryCatchRegions(MethodNode mth) {
if (mth.isNoCode() || mth.getExceptionHandlers() == null) {
return;
}
Set<TryCatchBlock> tryBlocks = new HashSet<TryCatchBlock>();
// collect all try/catch blocks
for (BlockNode block : mth.getBasicBlocks()) {
CatchAttr c = (CatchAttr) block.getAttributes().get(AttributeType.CATCH_BLOCK);
if (c != null)
if (c != null) {
tryBlocks.add(c.getTryBlock());
}
}
// for each try block search nearest dominator block
for (TryCatchBlock tb : tryBlocks) {
......@@ -71,9 +74,10 @@ public class MarkTryCatchRegions extends AbstractRegionVisitor {
bs.andNot(block.getDoms());
}
domBlocks = BlockUtils.bitsetToBlocks(mth, bs);
if (domBlocks.size() != 1)
if (domBlocks.size() != 1) {
throw new JadxRuntimeException(
"Exception block dominator not found, method:" + mth + ". bs: " + bs);
}
BlockNode domBlock = domBlocks.get(0);
......@@ -83,18 +87,16 @@ public class MarkTryCatchRegions extends AbstractRegionVisitor {
}
}
if (DEBUG && !tryBlocksMap.isEmpty())
LOG.debug("MarkTryCatchRegions: \n {} \n {}", mth, tryBlocksMap);
if (DEBUG && !tryBlocksMap.isEmpty()) {
LOG.debug("ProcessTryCatchRegions: \n {} \n {}", mth, tryBlocksMap);
}
}
@Override
public void leaveRegion(MethodNode mth, IRegion region) {
if (tryBlocksMap.isEmpty())
if (tryBlocksMap.isEmpty() || !(region instanceof Region)) {
return;
if (!(region instanceof Region))
return;
}
// search dominator blocks in this region (don't need to go deeper)
for (BlockNode dominator : tryBlocksMap.keySet()) {
if (region.getSubBlocks().contains(dominator)) {
......@@ -132,8 +134,9 @@ public class MarkTryCatchRegions extends AbstractRegionVisitor {
}
}
if (newRegion.getSubBlocks().size() != 0) {
if (DEBUG)
LOG.debug("MarkTryCatchRegions mark: {}", newRegion);
if (DEBUG) {
LOG.debug("ProcessTryCatchRegions mark: {}", newRegion);
}
// replace first node by region
IContainer firstNode = newRegion.getSubBlocks().get(0);
int i = region.getSubBlocks().indexOf(firstNode);
......
......@@ -6,6 +6,7 @@ import jadx.core.dex.nodes.ClassNode;
import org.junit.Test;
import static org.hamcrest.CoreMatchers.containsString;
import static org.hamcrest.CoreMatchers.not;
import static org.junit.Assert.assertThat;
public class TestTryCatch extends InternalJadxTest {
......@@ -28,5 +29,6 @@ public class TestTryCatch extends InternalJadxTest {
assertThat(code, containsString("try {"));
assertThat(code, containsString("Thread.sleep(50);"));
assertThat(code, containsString("} catch (InterruptedException e) {"));
assertThat(code, not(containsString("return")));
}
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment