Commit 49e234d9 authored by Skylot's avatar Skylot

fix: improve finally extraction

parent a587ce88
......@@ -5,6 +5,7 @@ import com.android.dx.io.instructions.DecodedInstruction;
import jadx.core.dex.attributes.AFlag;
import jadx.core.dex.instructions.args.ArgType;
import jadx.core.dex.instructions.args.InsnArg;
import jadx.core.dex.instructions.args.LiteralArg;
import jadx.core.dex.instructions.args.RegisterArg;
import jadx.core.dex.nodes.InsnNode;
import jadx.core.utils.InsnUtils;
......@@ -68,7 +69,23 @@ public class ArithNode extends InsnNode {
return false;
}
ArithNode other = (ArithNode) obj;
return op == other.op;
return op == other.op && isSameLiteral(other);
}
private boolean isSameLiteral(ArithNode other) {
InsnArg thisSecond = getArg(1);
InsnArg otherSecond = other.getArg(1);
if (thisSecond.isLiteral() != otherSecond.isLiteral()) {
return false;
}
if (!thisSecond.isLiteral()) {
// both not literals
return true;
}
// both literals
long thisLit = ((LiteralArg) thisSecond).getLiteral();
long otherLit = ((LiteralArg) otherSecond).getLiteral();
return thisLit == otherLit;
}
@Override
......
......@@ -2,6 +2,7 @@ package jadx.core.dex.visitors;
import java.util.ArrayList;
import java.util.HashSet;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Set;
......@@ -98,16 +99,16 @@ public class MarkFinallyVisitor extends AbstractVisitor {
*/
private static boolean extractFinally(MethodNode mth, ExceptionHandler allHandler) {
List<BlockNode> handlerBlocks = new ArrayList<>();
for (BlockNode block : allHandler.getBlocks()) {
BlockNode handlerBlock = allHandler.getHandlerBlock();
for (BlockNode block : BlockUtils.collectBlocksDominatedByWithExcHandlers(handlerBlock, handlerBlock)) {
InsnNode lastInsn = BlockUtils.getLastInsn(block);
if (lastInsn != null) {
InsnType insnType = lastInsn.getType();
if (insnType != InsnType.MOVE_EXCEPTION && insnType != InsnType.THROW) {
handlerBlocks.add(block);
}
if (lastInsn != null && lastInsn.getType() == InsnType.THROW) {
break;
}
handlerBlocks.add(block);
}
if (handlerBlocks.isEmpty()) {
if (handlerBlocks.isEmpty() || BlockUtils.isAllBlocksEmpty(handlerBlocks)) {
// remove empty catch
allHandler.getTryBlock().removeHandler(mth, allHandler);
return true;
......@@ -126,6 +127,8 @@ public class MarkFinallyVisitor extends AbstractVisitor {
for (BlockNode checkBlock : otherHandler.getBlocks()) {
if (searchDuplicateInsns(checkBlock, extractInfo)) {
break;
} else {
extractInfo.getFinallyInsnsSlice().resetIncomplete();
}
}
}
......@@ -149,11 +152,16 @@ public class MarkFinallyVisitor extends AbstractVisitor {
boolean found = false;
for (BlockNode splitter : splitters) {
BlockNode start = splitter.getCleanSuccessors().get(0);
List<BlockNode> list = BlockUtils.collectBlocksDominatedBy(splitter, start);
for (BlockNode block : list) {
List<BlockNode> list = new ArrayList<>();
list.add(start);
list.addAll(BlockUtils.collectBlocksDominatedByWithExcHandlers(start, start));
Set<BlockNode> checkSet = new LinkedHashSet<>(list);
for (BlockNode block : checkSet) {
if (searchDuplicateInsns(block, extractInfo)) {
found = true;
break;
} else {
extractInfo.getFinallyInsnsSlice().resetIncomplete();
}
}
}
......@@ -276,12 +284,15 @@ public class MarkFinallyVisitor extends AbstractVisitor {
&& !checkBlocksTree(dupBlock, startBlock, dupSlice, extractInfo)) {
return null;
}
return checkSlice(dupSlice);
return checkTempSlice(dupSlice);
}
@Nullable
private static InsnsSlice checkSlice(InsnsSlice slice) {
private static InsnsSlice checkTempSlice(InsnsSlice slice) {
List<InsnNode> insnsList = slice.getInsnsList();
if (insnsList.isEmpty()) {
return null;
}
// ignore slice with only one 'if' insn
if (insnsList.size() == 1) {
InsnNode insnNode = insnsList.get(0);
......@@ -384,8 +395,8 @@ public class MarkFinallyVisitor extends AbstractVisitor {
InsnsSlice dupSlice, FinallyExtractInfo extractInfo) {
InsnsSlice finallySlice = extractInfo.getFinallyInsnsSlice();
List<BlockNode> finallyCS = finallyBlock.getCleanSuccessors();
List<BlockNode> dupCS = dupBlock.getCleanSuccessors();
List<BlockNode> finallyCS = finallyBlock.getSuccessors();
List<BlockNode> dupCS = dupBlock.getSuccessors();
if (finallyCS.size() == dupCS.size()) {
for (int i = 0; i < finallyCS.size(); i++) {
BlockNode finSBlock = finallyCS.get(i);
......@@ -410,17 +421,21 @@ public class MarkFinallyVisitor extends AbstractVisitor {
private static boolean compareBlocks(BlockNode dupBlock, BlockNode finallyBlock, InsnsSlice dupSlice, FinallyExtractInfo extractInfo) {
List<InsnNode> dupInsns = dupBlock.getInstructions();
List<InsnNode> finallyInsns = finallyBlock.getInstructions();
if (dupInsns.size() < finallyInsns.size()) {
int dupInsnCount = dupInsns.size();
int finallyInsnCount = finallyInsns.size();
if (finallyInsnCount == 0) {
return dupInsnCount == 0;
}
if (dupInsnCount < finallyInsnCount) {
return false;
}
int size = finallyInsns.size();
for (int i = 0; i < size; i++) {
for (int i = 0; i < finallyInsnCount; i++) {
if (!sameInsns(dupInsns.get(i), finallyInsns.get(i))) {
return false;
}
}
if (dupInsns.size() > finallyInsns.size()) {
dupSlice.addInsns(dupBlock, 0, finallyInsns.size());
if (dupInsnCount > finallyInsnCount) {
dupSlice.addInsns(dupBlock, 0, finallyInsnCount);
dupSlice.setComplete(true);
InsnsSlice finallyInsnsSlice = extractInfo.getFinallyInsnsSlice();
finallyInsnsSlice.addBlock(finallyBlock);
......
......@@ -53,6 +53,13 @@ public class InsnsSlice {
return set;
}
public void resetIncomplete() {
if (!complete) {
insnsList.clear();
insnMap.clear();
}
}
public boolean isComplete() {
return complete;
}
......
......@@ -461,20 +461,30 @@ public class BlockUtils {
*/
public static List<BlockNode> collectBlocksDominatedBy(BlockNode dominator, BlockNode start) {
List<BlockNode> result = new ArrayList<>();
Set<BlockNode> visited = new HashSet<>();
collectWhileDominates(dominator, start, result, visited);
collectWhileDominates(dominator, start, result, new HashSet<>(), false);
return result;
}
private static void collectWhileDominates(BlockNode dominator, BlockNode child, List<BlockNode> result, Set<BlockNode> visited) {
/**
* Collect all block dominated by 'dominator', starting from 'start', include exception handlers
*/
public static List<BlockNode> collectBlocksDominatedByWithExcHandlers(BlockNode dominator, BlockNode start) {
List<BlockNode> result = new ArrayList<>();
collectWhileDominates(dominator, start, result, new HashSet<>(), true);
return result;
}
private static void collectWhileDominates(BlockNode dominator, BlockNode child, List<BlockNode> result,
Set<BlockNode> visited, boolean includeExcHandlers) {
if (visited.contains(child)) {
return;
}
visited.add(child);
for (BlockNode node : child.getCleanSuccessors()) {
List<BlockNode> successors = includeExcHandlers ? child.getSuccessors() : child.getCleanSuccessors();
for (BlockNode node : successors) {
if (node.isDominator(dominator)) {
result.add(node);
collectWhileDominates(dominator, node, result, visited);
collectWhileDominates(dominator, node, result, visited, includeExcHandlers);
}
}
}
......
......@@ -36,7 +36,7 @@ public class TestFinallyExtract extends IntegrationTest {
public void check() {
test();
assertEquals(result, 1);
assertEquals(1, result);
}
}
......
......@@ -14,16 +14,17 @@ public class TestTryCatchFinally extends IntegrationTest {
public static class TestCls {
public boolean f;
@SuppressWarnings("ConstantConditions")
private boolean test(Object obj) {
this.f = false;
try {
exc(obj);
} catch (Exception e) {
e.getMessage();
e.printStackTrace();
} finally {
f = true;
this.f = true;
}
return f;
return this.f;
}
private static boolean exc(Object obj) throws Exception {
......@@ -46,9 +47,9 @@ public class TestTryCatchFinally extends IntegrationTest {
assertThat(code, containsOne("exc(obj);"));
assertThat(code, containsOne("} catch (Exception e) {"));
assertThat(code, containsOne("e.getMessage();"));
assertThat(code, containsOne("e.printStackTrace();"));
assertThat(code, containsOne("} finally {"));
assertThat(code, containsOne("f = true;"));
assertThat(code, containsOne("this.f = true;"));
assertThat(code, containsOne("return this.f;"));
}
}
......@@ -5,12 +5,35 @@ import org.junit.jupiter.api.Test;
import jadx.core.dex.nodes.ClassNode;
import jadx.tests.api.SmaliTest;
import static jadx.tests.api.utils.JadxMatchers.containsOne;
import static org.hamcrest.CoreMatchers.containsString;
import static org.hamcrest.MatcherAssert.assertThat;
import static org.hamcrest.Matchers.not;
public class TestTryCatchFinally10 extends SmaliTest {
// @formatter:off
/*
public static String test(Context context, int i) {
CommonContracts.requireNonNull(context);
InputStream inputStream = null;
try {
inputStream = context.getResources().openRawResource(i);
Scanner useDelimiter = new Scanner(inputStream).useDelimiter("\\A");
return useDelimiter.hasNext() ? useDelimiter.next() : "";
} finally {
if (inputStream != null) {
try {
inputStream.close();
} catch (IOException e) {
l.logException(LogLevel.ERROR, e);
}
}
}
}
*/
// @formatter:on
@Test
public void test() {
disableCompilation();
......@@ -18,5 +41,8 @@ public class TestTryCatchFinally10 extends SmaliTest {
String code = cls.getCode().toString();
assertThat(code, not(containsString("boolean z = null;")));
assertThat(code, not(containsString("} catch (Throwable")));
assertThat(code, containsOne("} finally {"));
assertThat(code, containsOne(".close();"));
}
}
......@@ -6,13 +6,11 @@ import java.io.OutputStream;
import org.junit.jupiter.api.Test;
import jadx.NotYetImplemented;
import jadx.core.clsp.NClass;
import jadx.core.dex.nodes.ClassNode;
import jadx.tests.api.IntegrationTest;
import static jadx.tests.api.utils.JadxMatchers.containsOne;
import static jadx.tests.api.utils.JadxMatchers.countString;
import static org.hamcrest.MatcherAssert.assertThat;
public class TestTryCatchFinally2 extends IntegrationTest {
......@@ -57,13 +55,4 @@ public class TestTryCatchFinally2 extends IntegrationTest {
assertThat(code, containsOne("for (NClass cls : this.classes) {"));
assertThat(code, containsOne("for (NClass cls2 : this.classes) {"));
}
@Test
@NotYetImplemented
public void test2() {
ClassNode cls = getClassNode(TestCls.class);
String code = cls.getCode().toString();
assertThat(code, countString(2, "for (NClass cls : classes) {"));
}
}
......@@ -6,7 +6,6 @@ import java.util.Scanner;
import org.junit.jupiter.api.Test;
import jadx.NotYetImplemented;
import jadx.core.dex.nodes.ClassNode;
import jadx.tests.api.IntegrationTest;
......@@ -33,14 +32,14 @@ public class TestTryCatchFinally9 extends IntegrationTest {
}
@Test
@NotYetImplemented("finally extraction")
public void test() {
ClassNode cls = getClassNode(TestCls.class);
String code = cls.getCode().toString();
assertThat(code, not(containsString("JADX INFO: finally extract failed")));
assertThat(code, not(containsString("throw")));
assertThat(code, not(containsString(indent() + "throw ")));
assertThat(code, containsOne("} finally {"));
assertThat(code, containsOne("if (input != null) {"));
assertThat(code, containsOne("input.close();"));
}
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment