Commit b2f0f025 authored by Skylot's avatar Skylot

core: fix incorrectly removed 'return' in 'switch' block (fix #70)

parent 71f24911
...@@ -35,6 +35,7 @@ import java.util.Map; ...@@ -35,6 +35,7 @@ import java.util.Map;
import java.util.Set; import java.util.Set;
import org.jetbrains.annotations.NotNull; import org.jetbrains.annotations.NotNull;
import org.jetbrains.annotations.Nullable;
import org.slf4j.Logger; import org.slf4j.Logger;
import org.slf4j.LoggerFactory; import org.slf4j.LoggerFactory;
...@@ -442,6 +443,7 @@ public class MethodNode extends LineAttrNode implements ILoadable { ...@@ -442,6 +443,7 @@ public class MethodNode extends LineAttrNode implements ILoadable {
loops.add(loop); loops.add(loop);
} }
@Nullable
public LoopInfo getLoopForBlock(BlockNode block) { public LoopInfo getLoopForBlock(BlockNode block) {
if (loops.isEmpty()) { if (loops.isEmpty()) {
return null; return null;
......
...@@ -30,6 +30,7 @@ import jadx.core.utils.ErrorsCounter; ...@@ -30,6 +30,7 @@ import jadx.core.utils.ErrorsCounter;
import jadx.core.utils.InstructionRemover; import jadx.core.utils.InstructionRemover;
import jadx.core.utils.RegionUtils; import jadx.core.utils.RegionUtils;
import jadx.core.utils.exceptions.JadxOverflowException; import jadx.core.utils.exceptions.JadxOverflowException;
import jadx.core.utils.exceptions.JadxRuntimeException;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.BitSet; import java.util.BitSet;
...@@ -690,7 +691,9 @@ public class RegionMaker { ...@@ -690,7 +691,9 @@ public class RegionMaker {
Map<BlockNode, List<Object>> blocksMap = new LinkedHashMap<BlockNode, List<Object>>(len); Map<BlockNode, List<Object>> blocksMap = new LinkedHashMap<BlockNode, List<Object>>(len);
for (Map.Entry<Integer, List<Object>> entry : casesMap.entrySet()) { for (Map.Entry<Integer, List<Object>> entry : casesMap.entrySet()) {
BlockNode c = getBlockByOffset(entry.getKey(), block.getSuccessors()); BlockNode c = getBlockByOffset(entry.getKey(), block.getSuccessors());
assert c != null; if (c == null) {
throw new JadxRuntimeException("Switch block not found by offset: " + entry.getKey());
}
blocksMap.put(c, entry.getValue()); blocksMap.put(c, entry.getValue());
} }
BlockNode defCase = getBlockByOffset(insn.getDefaultCaseOffset(), block.getSuccessors()); BlockNode defCase = getBlockByOffset(insn.getDefaultCaseOffset(), block.getSuccessors());
......
...@@ -9,6 +9,7 @@ import jadx.core.dex.nodes.IContainer; ...@@ -9,6 +9,7 @@ import jadx.core.dex.nodes.IContainer;
import jadx.core.dex.nodes.IRegion; import jadx.core.dex.nodes.IRegion;
import jadx.core.dex.nodes.InsnNode; import jadx.core.dex.nodes.InsnNode;
import jadx.core.dex.nodes.MethodNode; import jadx.core.dex.nodes.MethodNode;
import jadx.core.dex.regions.SwitchRegion;
import jadx.core.dex.regions.loops.LoopRegion; import jadx.core.dex.regions.loops.LoopRegion;
import jadx.core.dex.visitors.AbstractVisitor; import jadx.core.dex.visitors.AbstractVisitor;
import jadx.core.utils.exceptions.JadxException; import jadx.core.utils.exceptions.JadxException;
...@@ -31,6 +32,13 @@ public class ReturnVisitor extends AbstractVisitor { ...@@ -31,6 +32,13 @@ public class ReturnVisitor extends AbstractVisitor {
} }
private static final class ReturnRemoverVisitor extends TracedRegionVisitor { private static final class ReturnRemoverVisitor extends TracedRegionVisitor {
@Override
public boolean enterRegion(MethodNode mth, IRegion region) {
super.enterRegion(mth, region);
return !(region instanceof SwitchRegion);
}
@Override @Override
public void processBlockTraced(MethodNode mth, IBlock container, IRegion currentRegion) { public void processBlockTraced(MethodNode mth, IBlock container, IRegion currentRegion) {
if (container.getClass() != BlockNode.class) { if (container.getClass() != BlockNode.class) {
......
package jadx.tests.integration.switches;
import jadx.core.dex.nodes.ClassNode;
import jadx.tests.api.IntegrationTest;
import org.junit.Test;
import static jadx.tests.api.utils.JadxMatchers.countString;
import static org.hamcrest.Matchers.is;
import static org.junit.Assert.assertThat;
public class TestSwitch3 extends IntegrationTest {
public static class TestCls {
private int i;
void test(int a) {
switch (a) {
case 1:
i = 1;
return;
case 2:
case 3:
i = 2;
return;
default:
i = 4;
break;
}
i = 5;
}
public void check() {
test(1);
assertThat(i, is(1));
test(2);
assertThat(i, is(2));
test(3);
assertThat(i, is(2));
test(4);
assertThat(i, is(5));
test(10);
assertThat(i, is(5));
}
}
@Test
public void test() {
ClassNode cls = getClassNode(TestCls.class);
String code = cls.getCode().toString();
assertThat(code, countString(0, "break;"));
assertThat(code, countString(3, "return;"));
}
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment