Commit 49e234d9 authored by Skylot's avatar Skylot

fix: improve finally extraction

parent a587ce88
...@@ -5,6 +5,7 @@ import com.android.dx.io.instructions.DecodedInstruction; ...@@ -5,6 +5,7 @@ import com.android.dx.io.instructions.DecodedInstruction;
import jadx.core.dex.attributes.AFlag; import jadx.core.dex.attributes.AFlag;
import jadx.core.dex.instructions.args.ArgType; import jadx.core.dex.instructions.args.ArgType;
import jadx.core.dex.instructions.args.InsnArg; import jadx.core.dex.instructions.args.InsnArg;
import jadx.core.dex.instructions.args.LiteralArg;
import jadx.core.dex.instructions.args.RegisterArg; import jadx.core.dex.instructions.args.RegisterArg;
import jadx.core.dex.nodes.InsnNode; import jadx.core.dex.nodes.InsnNode;
import jadx.core.utils.InsnUtils; import jadx.core.utils.InsnUtils;
...@@ -68,7 +69,23 @@ public class ArithNode extends InsnNode { ...@@ -68,7 +69,23 @@ public class ArithNode extends InsnNode {
return false; return false;
} }
ArithNode other = (ArithNode) obj; ArithNode other = (ArithNode) obj;
return op == other.op; return op == other.op && isSameLiteral(other);
}
private boolean isSameLiteral(ArithNode other) {
InsnArg thisSecond = getArg(1);
InsnArg otherSecond = other.getArg(1);
if (thisSecond.isLiteral() != otherSecond.isLiteral()) {
return false;
}
if (!thisSecond.isLiteral()) {
// both not literals
return true;
}
// both literals
long thisLit = ((LiteralArg) thisSecond).getLiteral();
long otherLit = ((LiteralArg) otherSecond).getLiteral();
return thisLit == otherLit;
} }
@Override @Override
......
...@@ -2,6 +2,7 @@ package jadx.core.dex.visitors; ...@@ -2,6 +2,7 @@ package jadx.core.dex.visitors;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.HashSet; import java.util.HashSet;
import java.util.LinkedHashSet;
import java.util.List; import java.util.List;
import java.util.Set; import java.util.Set;
...@@ -98,16 +99,16 @@ public class MarkFinallyVisitor extends AbstractVisitor { ...@@ -98,16 +99,16 @@ public class MarkFinallyVisitor extends AbstractVisitor {
*/ */
private static boolean extractFinally(MethodNode mth, ExceptionHandler allHandler) { private static boolean extractFinally(MethodNode mth, ExceptionHandler allHandler) {
List<BlockNode> handlerBlocks = new ArrayList<>(); List<BlockNode> handlerBlocks = new ArrayList<>();
for (BlockNode block : allHandler.getBlocks()) { BlockNode handlerBlock = allHandler.getHandlerBlock();
for (BlockNode block : BlockUtils.collectBlocksDominatedByWithExcHandlers(handlerBlock, handlerBlock)) {
InsnNode lastInsn = BlockUtils.getLastInsn(block); InsnNode lastInsn = BlockUtils.getLastInsn(block);
if (lastInsn != null) { if (lastInsn != null && lastInsn.getType() == InsnType.THROW) {
InsnType insnType = lastInsn.getType(); break;
if (insnType != InsnType.MOVE_EXCEPTION && insnType != InsnType.THROW) {
handlerBlocks.add(block);
}
} }
handlerBlocks.add(block);
} }
if (handlerBlocks.isEmpty()) { if (handlerBlocks.isEmpty() || BlockUtils.isAllBlocksEmpty(handlerBlocks)) {
// remove empty catch // remove empty catch
allHandler.getTryBlock().removeHandler(mth, allHandler); allHandler.getTryBlock().removeHandler(mth, allHandler);
return true; return true;
...@@ -126,6 +127,8 @@ public class MarkFinallyVisitor extends AbstractVisitor { ...@@ -126,6 +127,8 @@ public class MarkFinallyVisitor extends AbstractVisitor {
for (BlockNode checkBlock : otherHandler.getBlocks()) { for (BlockNode checkBlock : otherHandler.getBlocks()) {
if (searchDuplicateInsns(checkBlock, extractInfo)) { if (searchDuplicateInsns(checkBlock, extractInfo)) {
break; break;
} else {
extractInfo.getFinallyInsnsSlice().resetIncomplete();
} }
} }
} }
...@@ -149,11 +152,16 @@ public class MarkFinallyVisitor extends AbstractVisitor { ...@@ -149,11 +152,16 @@ public class MarkFinallyVisitor extends AbstractVisitor {
boolean found = false; boolean found = false;
for (BlockNode splitter : splitters) { for (BlockNode splitter : splitters) {
BlockNode start = splitter.getCleanSuccessors().get(0); BlockNode start = splitter.getCleanSuccessors().get(0);
List<BlockNode> list = BlockUtils.collectBlocksDominatedBy(splitter, start); List<BlockNode> list = new ArrayList<>();
for (BlockNode block : list) { list.add(start);
list.addAll(BlockUtils.collectBlocksDominatedByWithExcHandlers(start, start));
Set<BlockNode> checkSet = new LinkedHashSet<>(list);
for (BlockNode block : checkSet) {
if (searchDuplicateInsns(block, extractInfo)) { if (searchDuplicateInsns(block, extractInfo)) {
found = true; found = true;
break; break;
} else {
extractInfo.getFinallyInsnsSlice().resetIncomplete();
} }
} }
} }
...@@ -276,12 +284,15 @@ public class MarkFinallyVisitor extends AbstractVisitor { ...@@ -276,12 +284,15 @@ public class MarkFinallyVisitor extends AbstractVisitor {
&& !checkBlocksTree(dupBlock, startBlock, dupSlice, extractInfo)) { && !checkBlocksTree(dupBlock, startBlock, dupSlice, extractInfo)) {
return null; return null;
} }
return checkSlice(dupSlice); return checkTempSlice(dupSlice);
} }
@Nullable @Nullable
private static InsnsSlice checkSlice(InsnsSlice slice) { private static InsnsSlice checkTempSlice(InsnsSlice slice) {
List<InsnNode> insnsList = slice.getInsnsList(); List<InsnNode> insnsList = slice.getInsnsList();
if (insnsList.isEmpty()) {
return null;
}
// ignore slice with only one 'if' insn // ignore slice with only one 'if' insn
if (insnsList.size() == 1) { if (insnsList.size() == 1) {
InsnNode insnNode = insnsList.get(0); InsnNode insnNode = insnsList.get(0);
...@@ -384,8 +395,8 @@ public class MarkFinallyVisitor extends AbstractVisitor { ...@@ -384,8 +395,8 @@ public class MarkFinallyVisitor extends AbstractVisitor {
InsnsSlice dupSlice, FinallyExtractInfo extractInfo) { InsnsSlice dupSlice, FinallyExtractInfo extractInfo) {
InsnsSlice finallySlice = extractInfo.getFinallyInsnsSlice(); InsnsSlice finallySlice = extractInfo.getFinallyInsnsSlice();
List<BlockNode> finallyCS = finallyBlock.getCleanSuccessors(); List<BlockNode> finallyCS = finallyBlock.getSuccessors();
List<BlockNode> dupCS = dupBlock.getCleanSuccessors(); List<BlockNode> dupCS = dupBlock.getSuccessors();
if (finallyCS.size() == dupCS.size()) { if (finallyCS.size() == dupCS.size()) {
for (int i = 0; i < finallyCS.size(); i++) { for (int i = 0; i < finallyCS.size(); i++) {
BlockNode finSBlock = finallyCS.get(i); BlockNode finSBlock = finallyCS.get(i);
...@@ -410,17 +421,21 @@ public class MarkFinallyVisitor extends AbstractVisitor { ...@@ -410,17 +421,21 @@ public class MarkFinallyVisitor extends AbstractVisitor {
private static boolean compareBlocks(BlockNode dupBlock, BlockNode finallyBlock, InsnsSlice dupSlice, FinallyExtractInfo extractInfo) { private static boolean compareBlocks(BlockNode dupBlock, BlockNode finallyBlock, InsnsSlice dupSlice, FinallyExtractInfo extractInfo) {
List<InsnNode> dupInsns = dupBlock.getInstructions(); List<InsnNode> dupInsns = dupBlock.getInstructions();
List<InsnNode> finallyInsns = finallyBlock.getInstructions(); List<InsnNode> finallyInsns = finallyBlock.getInstructions();
if (dupInsns.size() < finallyInsns.size()) { int dupInsnCount = dupInsns.size();
int finallyInsnCount = finallyInsns.size();
if (finallyInsnCount == 0) {
return dupInsnCount == 0;
}
if (dupInsnCount < finallyInsnCount) {
return false; return false;
} }
int size = finallyInsns.size(); for (int i = 0; i < finallyInsnCount; i++) {
for (int i = 0; i < size; i++) {
if (!sameInsns(dupInsns.get(i), finallyInsns.get(i))) { if (!sameInsns(dupInsns.get(i), finallyInsns.get(i))) {
return false; return false;
} }
} }
if (dupInsns.size() > finallyInsns.size()) { if (dupInsnCount > finallyInsnCount) {
dupSlice.addInsns(dupBlock, 0, finallyInsns.size()); dupSlice.addInsns(dupBlock, 0, finallyInsnCount);
dupSlice.setComplete(true); dupSlice.setComplete(true);
InsnsSlice finallyInsnsSlice = extractInfo.getFinallyInsnsSlice(); InsnsSlice finallyInsnsSlice = extractInfo.getFinallyInsnsSlice();
finallyInsnsSlice.addBlock(finallyBlock); finallyInsnsSlice.addBlock(finallyBlock);
......
...@@ -53,6 +53,13 @@ public class InsnsSlice { ...@@ -53,6 +53,13 @@ public class InsnsSlice {
return set; return set;
} }
public void resetIncomplete() {
if (!complete) {
insnsList.clear();
insnMap.clear();
}
}
public boolean isComplete() { public boolean isComplete() {
return complete; return complete;
} }
......
...@@ -461,20 +461,30 @@ public class BlockUtils { ...@@ -461,20 +461,30 @@ public class BlockUtils {
*/ */
public static List<BlockNode> collectBlocksDominatedBy(BlockNode dominator, BlockNode start) { public static List<BlockNode> collectBlocksDominatedBy(BlockNode dominator, BlockNode start) {
List<BlockNode> result = new ArrayList<>(); List<BlockNode> result = new ArrayList<>();
Set<BlockNode> visited = new HashSet<>(); collectWhileDominates(dominator, start, result, new HashSet<>(), false);
collectWhileDominates(dominator, start, result, visited);
return result; return result;
} }
private static void collectWhileDominates(BlockNode dominator, BlockNode child, List<BlockNode> result, Set<BlockNode> visited) { /**
* Collect all block dominated by 'dominator', starting from 'start', include exception handlers
*/
public static List<BlockNode> collectBlocksDominatedByWithExcHandlers(BlockNode dominator, BlockNode start) {
List<BlockNode> result = new ArrayList<>();
collectWhileDominates(dominator, start, result, new HashSet<>(), true);
return result;
}
private static void collectWhileDominates(BlockNode dominator, BlockNode child, List<BlockNode> result,
Set<BlockNode> visited, boolean includeExcHandlers) {
if (visited.contains(child)) { if (visited.contains(child)) {
return; return;
} }
visited.add(child); visited.add(child);
for (BlockNode node : child.getCleanSuccessors()) { List<BlockNode> successors = includeExcHandlers ? child.getSuccessors() : child.getCleanSuccessors();
for (BlockNode node : successors) {
if (node.isDominator(dominator)) { if (node.isDominator(dominator)) {
result.add(node); result.add(node);
collectWhileDominates(dominator, node, result, visited); collectWhileDominates(dominator, node, result, visited, includeExcHandlers);
} }
} }
} }
......
...@@ -36,7 +36,7 @@ public class TestFinallyExtract extends IntegrationTest { ...@@ -36,7 +36,7 @@ public class TestFinallyExtract extends IntegrationTest {
public void check() { public void check() {
test(); test();
assertEquals(result, 1); assertEquals(1, result);
} }
} }
......
...@@ -14,16 +14,17 @@ public class TestTryCatchFinally extends IntegrationTest { ...@@ -14,16 +14,17 @@ public class TestTryCatchFinally extends IntegrationTest {
public static class TestCls { public static class TestCls {
public boolean f; public boolean f;
@SuppressWarnings("ConstantConditions")
private boolean test(Object obj) { private boolean test(Object obj) {
this.f = false; this.f = false;
try { try {
exc(obj); exc(obj);
} catch (Exception e) { } catch (Exception e) {
e.getMessage(); e.printStackTrace();
} finally { } finally {
f = true; this.f = true;
} }
return f; return this.f;
} }
private static boolean exc(Object obj) throws Exception { private static boolean exc(Object obj) throws Exception {
...@@ -46,9 +47,9 @@ public class TestTryCatchFinally extends IntegrationTest { ...@@ -46,9 +47,9 @@ public class TestTryCatchFinally extends IntegrationTest {
assertThat(code, containsOne("exc(obj);")); assertThat(code, containsOne("exc(obj);"));
assertThat(code, containsOne("} catch (Exception e) {")); assertThat(code, containsOne("} catch (Exception e) {"));
assertThat(code, containsOne("e.getMessage();")); assertThat(code, containsOne("e.printStackTrace();"));
assertThat(code, containsOne("} finally {")); assertThat(code, containsOne("} finally {"));
assertThat(code, containsOne("f = true;")); assertThat(code, containsOne("this.f = true;"));
assertThat(code, containsOne("return this.f;")); assertThat(code, containsOne("return this.f;"));
} }
} }
...@@ -5,12 +5,35 @@ import org.junit.jupiter.api.Test; ...@@ -5,12 +5,35 @@ import org.junit.jupiter.api.Test;
import jadx.core.dex.nodes.ClassNode; import jadx.core.dex.nodes.ClassNode;
import jadx.tests.api.SmaliTest; import jadx.tests.api.SmaliTest;
import static jadx.tests.api.utils.JadxMatchers.containsOne;
import static org.hamcrest.CoreMatchers.containsString; import static org.hamcrest.CoreMatchers.containsString;
import static org.hamcrest.MatcherAssert.assertThat; import static org.hamcrest.MatcherAssert.assertThat;
import static org.hamcrest.Matchers.not; import static org.hamcrest.Matchers.not;
public class TestTryCatchFinally10 extends SmaliTest { public class TestTryCatchFinally10 extends SmaliTest {
// @formatter:off
/*
public static String test(Context context, int i) {
CommonContracts.requireNonNull(context);
InputStream inputStream = null;
try {
inputStream = context.getResources().openRawResource(i);
Scanner useDelimiter = new Scanner(inputStream).useDelimiter("\\A");
return useDelimiter.hasNext() ? useDelimiter.next() : "";
} finally {
if (inputStream != null) {
try {
inputStream.close();
} catch (IOException e) {
l.logException(LogLevel.ERROR, e);
}
}
}
}
*/
// @formatter:on
@Test @Test
public void test() { public void test() {
disableCompilation(); disableCompilation();
...@@ -18,5 +41,8 @@ public class TestTryCatchFinally10 extends SmaliTest { ...@@ -18,5 +41,8 @@ public class TestTryCatchFinally10 extends SmaliTest {
String code = cls.getCode().toString(); String code = cls.getCode().toString();
assertThat(code, not(containsString("boolean z = null;"))); assertThat(code, not(containsString("boolean z = null;")));
assertThat(code, not(containsString("} catch (Throwable")));
assertThat(code, containsOne("} finally {"));
assertThat(code, containsOne(".close();"));
} }
} }
...@@ -6,13 +6,11 @@ import java.io.OutputStream; ...@@ -6,13 +6,11 @@ import java.io.OutputStream;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import jadx.NotYetImplemented;
import jadx.core.clsp.NClass; import jadx.core.clsp.NClass;
import jadx.core.dex.nodes.ClassNode; import jadx.core.dex.nodes.ClassNode;
import jadx.tests.api.IntegrationTest; import jadx.tests.api.IntegrationTest;
import static jadx.tests.api.utils.JadxMatchers.containsOne; import static jadx.tests.api.utils.JadxMatchers.containsOne;
import static jadx.tests.api.utils.JadxMatchers.countString;
import static org.hamcrest.MatcherAssert.assertThat; import static org.hamcrest.MatcherAssert.assertThat;
public class TestTryCatchFinally2 extends IntegrationTest { public class TestTryCatchFinally2 extends IntegrationTest {
...@@ -57,13 +55,4 @@ public class TestTryCatchFinally2 extends IntegrationTest { ...@@ -57,13 +55,4 @@ public class TestTryCatchFinally2 extends IntegrationTest {
assertThat(code, containsOne("for (NClass cls : this.classes) {")); assertThat(code, containsOne("for (NClass cls : this.classes) {"));
assertThat(code, containsOne("for (NClass cls2 : this.classes) {")); assertThat(code, containsOne("for (NClass cls2 : this.classes) {"));
} }
@Test
@NotYetImplemented
public void test2() {
ClassNode cls = getClassNode(TestCls.class);
String code = cls.getCode().toString();
assertThat(code, countString(2, "for (NClass cls : classes) {"));
}
} }
...@@ -6,7 +6,6 @@ import java.util.Scanner; ...@@ -6,7 +6,6 @@ import java.util.Scanner;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import jadx.NotYetImplemented;
import jadx.core.dex.nodes.ClassNode; import jadx.core.dex.nodes.ClassNode;
import jadx.tests.api.IntegrationTest; import jadx.tests.api.IntegrationTest;
...@@ -33,14 +32,14 @@ public class TestTryCatchFinally9 extends IntegrationTest { ...@@ -33,14 +32,14 @@ public class TestTryCatchFinally9 extends IntegrationTest {
} }
@Test @Test
@NotYetImplemented("finally extraction")
public void test() { public void test() {
ClassNode cls = getClassNode(TestCls.class); ClassNode cls = getClassNode(TestCls.class);
String code = cls.getCode().toString(); String code = cls.getCode().toString();
assertThat(code, not(containsString("JADX INFO: finally extract failed"))); assertThat(code, not(containsString("JADX INFO: finally extract failed")));
assertThat(code, not(containsString("throw"))); assertThat(code, not(containsString(indent() + "throw ")));
assertThat(code, containsOne("} finally {")); assertThat(code, containsOne("} finally {"));
assertThat(code, containsOne("if (input != null) {")); assertThat(code, containsOne("if (input != null) {"));
assertThat(code, containsOne("input.close();"));
} }
} }
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment