Commit aec98644 authored by Skylot's avatar Skylot

fix: support multi-exception catch blocks (#421)

parent b0e3cfed
package jadx.core.codegen;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
......@@ -11,6 +12,7 @@ import jadx.core.dex.attributes.AType;
import jadx.core.dex.attributes.nodes.DeclareVariablesAttr;
import jadx.core.dex.attributes.nodes.ForceReturnAttr;
import jadx.core.dex.attributes.nodes.LoopLabelAttr;
import jadx.core.dex.info.ClassInfo;
import jadx.core.dex.instructions.SwitchNode;
import jadx.core.dex.instructions.args.InsnArg;
import jadx.core.dex.instructions.args.NamedArg;
......@@ -306,16 +308,23 @@ public class RegionGen extends InsnGen {
return;
}
code.startLine("} catch (");
if (handler.isCatchAll()) {
code.add("Throwable");
} else {
Iterator<ClassInfo> it = handler.getCatchTypes().iterator();
if (it.hasNext()) {
useClass(code, it.next());
}
while (it.hasNext()) {
code.add(" | ");
useClass(code, it.next());
}
}
code.add(' ');
InsnArg arg = handler.getArg();
if (arg instanceof RegisterArg) {
declareVar(code, (RegisterArg) arg);
code.add(mgen.getNameGen().assignArg((RegisterArg) arg));
} else if (arg instanceof NamedArg) {
if (handler.isCatchAll()) {
code.add("Throwable");
} else {
useClass(code, handler.getCatchType());
}
code.add(' ');
code.add(mgen.getNameGen().assignNamedArg((NamedArg) arg));
}
code.add(") {");
......
......@@ -2,12 +2,14 @@ package jadx.core.dex.info;
import java.io.File;
import org.jetbrains.annotations.NotNull;
import jadx.core.dex.instructions.args.ArgType;
import jadx.core.dex.nodes.DexNode;
import jadx.core.dex.nodes.RootNode;
import jadx.core.utils.exceptions.JadxRuntimeException;
public final class ClassInfo {
public final class ClassInfo implements Comparable<ClassInfo> {
private final ArgType type;
private String pkg;
......@@ -194,4 +196,9 @@ public final class ClassInfo {
}
return false;
}
@Override
public int compareTo(@NotNull ClassInfo o) {
return fullName.compareTo(o.fullName);
}
}
......@@ -96,16 +96,16 @@ public class MethodNode extends LineAttrNode implements ILoadable, IDexNode {
DexNode dex = parentClass.dex();
Code mthCode = dex.readCode(methodData);
regsCount = mthCode.getRegistersSize();
this.regsCount = mthCode.getRegistersSize();
initMethodTypes();
InsnDecoder decoder = new InsnDecoder(this);
decoder.decodeInsns(mthCode);
instructions = decoder.process();
codeSize = instructions.length;
this.instructions = decoder.process();
this.codeSize = instructions.length;
initTryCatches(mthCode);
initJumps();
initTryCatches(this, mthCode, instructions);
initJumps(instructions);
this.debugInfoOffset = mthCode.getDebugInfoOffset();
} catch (Exception e) {
......@@ -257,37 +257,37 @@ public class MethodNode extends LineAttrNode implements ILoadable, IDexNode {
return genericMap;
}
private void initTryCatches(Code mthCode) {
InsnNode[] insnByOffset = instructions;
private static void initTryCatches(MethodNode mth, Code mthCode, InsnNode[] insnByOffset) {
CatchHandler[] catchBlocks = mthCode.getCatchHandlers();
Try[] tries = mthCode.getTries();
if (catchBlocks.length == 0 && tries.length == 0) {
return;
}
int hc = 0;
int handlersCount = 0;
Set<Integer> addrs = new HashSet<>();
List<TryCatchBlock> catches = new ArrayList<>(catchBlocks.length);
for (CatchHandler handler : catchBlocks) {
TryCatchBlock tcBlock = new TryCatchBlock();
catches.add(tcBlock);
for (int i = 0; i < handler.getAddresses().length; i++) {
int addr = handler.getAddresses()[i];
ClassInfo type = ClassInfo.fromDex(parentClass.dex(), handler.getTypeIndexes()[i]);
tcBlock.addHandler(this, addr, type);
int[] handlerAddrArr = handler.getAddresses();
for (int i = 0; i < handlerAddrArr.length; i++) {
int addr = handlerAddrArr[i];
ClassInfo type = ClassInfo.fromDex(mth.dex(), handler.getTypeIndexes()[i]);
tcBlock.addHandler(mth, addr, type);
addrs.add(addr);
hc++;
handlersCount++;
}
int addr = handler.getCatchAllAddress();
if (addr >= 0) {
tcBlock.addHandler(this, addr, null);
tcBlock.addHandler(mth, addr, null);
addrs.add(addr);
hc++;
handlersCount++;
}
}
if (hc > 0 && hc != addrs.size()) {
if (handlersCount > 0 && handlersCount != addrs.size()) {
// resolve nested try blocks:
// inner block contains all handlers from outer block => remove these handlers from inner block
// each handler must be only in one try/catch block
......@@ -295,7 +295,7 @@ public class MethodNode extends LineAttrNode implements ILoadable, IDexNode {
for (TryCatchBlock ct2 : catches) {
if (ct1 != ct2 && ct2.containsAllHandlers(ct1)) {
for (ExceptionHandler h : ct1.getHandlers()) {
ct2.removeHandler(this, h);
ct2.removeHandler(mth, h);
h.setTryBlock(ct1);
}
}
......@@ -309,6 +309,7 @@ public class MethodNode extends LineAttrNode implements ILoadable, IDexNode {
for (ExceptionHandler eh : ct.getHandlers()) {
int addr = eh.getHandleOffset();
ExcHandlerAttr ehAttr = new ExcHandlerAttr(ct, eh);
// TODO: don't override existing attribute
insnByOffset[addr].addAttr(ehAttr);
}
}
......@@ -335,8 +336,7 @@ public class MethodNode extends LineAttrNode implements ILoadable, IDexNode {
}
}
private void initJumps() {
InsnNode[] insnByOffset = instructions;
private static void initJumps(InsnNode[] insnByOffset) {
for (int offset = 0; offset < insnByOffset.length; offset++) {
InsnNode insn = insnByOffset[offset];
if (insn == null) {
......@@ -484,7 +484,18 @@ public class MethodNode extends LineAttrNode implements ILoadable, IDexNode {
exceptionHandlers = new ArrayList<>(2);
} else {
for (ExceptionHandler h : exceptionHandlers) {
if (h == handler || h.getHandleOffset() == handler.getHandleOffset()) {
if (h.equals(handler)) {
return h;
}
if (h.getHandleOffset() == handler.getHandleOffset()) {
if (h.getTryBlock() == handler.getTryBlock()) {
for (ClassInfo catchType : handler.getCatchTypes()) {
h.addCatchType(catchType);
}
} else {
// same handlers from different try blocks
// will merge later
}
return h;
}
}
......
......@@ -30,6 +30,6 @@ public class ExcHandlerAttr implements IAttribute {
public String toString() {
return "ExcHandler: " + (handler.isFinally()
? " FINALLY"
: (handler.isCatchAll() ? "all" : handler.getCatchType()) + " " + handler.getArg());
: handler.catchTypeStr() + " " + handler.getArg());
}
}
package jadx.core.dex.trycatch;
import java.util.ArrayList;
import java.util.Collection;
import java.util.List;
import java.util.Objects;
import java.util.Set;
import java.util.TreeSet;
import org.jetbrains.annotations.Nullable;
import jadx.core.Consts;
import jadx.core.dex.info.ClassInfo;
import jadx.core.dex.instructions.args.ArgType;
import jadx.core.dex.instructions.args.InsnArg;
import jadx.core.dex.nodes.BlockNode;
import jadx.core.dex.nodes.IContainer;
import jadx.core.utils.InsnUtils;
import jadx.core.utils.Utils;
import jadx.core.utils.exceptions.JadxRuntimeException;
public class ExceptionHandler {
private final ClassInfo catchType;
private final Set<ClassInfo> catchTypes = new TreeSet<>();
private final int handleOffset;
private BlockNode handlerBlock;
......@@ -23,17 +32,57 @@ public class ExceptionHandler {
private TryCatchBlock tryBlock;
private boolean isFinally;
public ExceptionHandler(int addr, ClassInfo type) {
public ExceptionHandler(int addr, @Nullable ClassInfo type) {
this.handleOffset = addr;
this.catchType = type;
addCatchType(type);
}
/**
* Add exception type to catch block
* @param type - null for 'all' or 'Throwable' handler
*/
public void addCatchType(@Nullable ClassInfo type) {
if (type != null) {
this.catchTypes.add(type);
} else {
if (!this.catchTypes.isEmpty()) {
throw new JadxRuntimeException("Null type added to not empty exception handler: " + this);
}
}
}
public void addCatchTypes(Collection<ClassInfo> types) {
for (ClassInfo type : types) {
addCatchType(type);
}
}
public ClassInfo getCatchType() {
return catchType;
public Set<ClassInfo> getCatchTypes() {
return catchTypes;
}
public ArgType getArgType() {
if (isCatchAll()) {
return ArgType.THROWABLE;
}
Set<ClassInfo> types = getCatchTypes();
if (types.size() == 1) {
return types.iterator().next().getType();
} else {
return ArgType.THROWABLE;
}
}
public boolean isCatchAll() {
return catchType == null || catchType.getFullName().equals(Consts.CLASS_THROWABLE);
if (catchTypes.isEmpty()) {
return true;
}
for (ClassInfo classInfo : catchTypes) {
if (classInfo.getFullName().equals(Consts.CLASS_THROWABLE)) {
return true;
}
}
return false;
}
public int getHandleOffset() {
......@@ -89,35 +138,30 @@ public class ExceptionHandler {
}
@Override
public int hashCode() {
return (catchType == null ? 0 : 31 * catchType.hashCode()) + handleOffset;
}
@Override
public boolean equals(Object obj) {
if (this == obj) {
public boolean equals(Object o) {
if (this == o) {
return true;
}
if (obj == null) {
return false;
}
if (getClass() != obj.getClass()) {
return false;
}
ExceptionHandler other = (ExceptionHandler) obj;
if (catchType == null) {
if (other.catchType != null) {
return false;
}
} else if (!catchType.equals(other.catchType)) {
if (o == null || getClass() != o.getClass()) {
return false;
}
return handleOffset == other.handleOffset;
ExceptionHandler that = (ExceptionHandler) o;
return handleOffset == that.handleOffset &&
catchTypes.equals(that.catchTypes) &&
Objects.equals(tryBlock, that.tryBlock);
}
@Override
public int hashCode() {
return Objects.hash(catchTypes, handleOffset /*, tryBlock*/);
}
public String catchTypeStr() {
return catchTypes.isEmpty() ? "all" : Utils.listToString(catchTypes, " | ", ClassInfo::getShortName);
}
@Override
public String toString() {
return (catchType == null ? "all"
: catchType.getShortName()) + " -> " + InsnUtils.formatOffset(handleOffset);
return catchTypeStr() + " -> " + InsnUtils.formatOffset(handleOffset);
}
}
......@@ -5,6 +5,8 @@ import java.util.Iterator;
import java.util.LinkedList;
import java.util.List;
import org.jetbrains.annotations.Nullable;
import jadx.core.dex.attributes.AFlag;
import jadx.core.dex.attributes.AType;
import jadx.core.dex.info.ClassInfo;
......@@ -40,12 +42,14 @@ public class TryCatchBlock {
return handlers.containsAll(tb.handlers);
}
public ExceptionHandler addHandler(MethodNode mth, int addr, ClassInfo type) {
public ExceptionHandler addHandler(MethodNode mth, int addr, @Nullable ClassInfo type) {
ExceptionHandler handler = new ExceptionHandler(addr, type);
handler = mth.addExceptionHandler(handler);
handlers.add(handler);
handler.setTryBlock(this);
return handler;
ExceptionHandler addedHandler = mth.addExceptionHandler(handler);
if (addedHandler == handler || addedHandler.getTryBlock() != this) {
handlers.add(addedHandler);
}
return addedHandler;
}
public void removeHandler(MethodNode mth, ExceptionHandler handler) {
......
......@@ -438,7 +438,7 @@ public class ModVisitor extends AbstractVisitor {
// result arg used both in this insn and exception handler,
RegisterArg resArg = insn.getResult();
ArgType type = excHandler.isCatchAll() ? ArgType.THROWABLE : excHandler.getCatchType().getType();
ArgType type = excHandler.getArgType();
String name = excHandler.isCatchAll() ? "th" : "e";
if (resArg.getName() == null) {
resArg.setName(name);
......
......@@ -48,7 +48,7 @@ public class BlockExceptionHandler extends AbstractVisitor {
return;
}
ExceptionHandler excHandler = handlerAttr.getHandler();
ArgType argType = excHandler.isCatchAll() ? ArgType.THROWABLE : excHandler.getCatchType().getType();
ArgType argType = excHandler.getArgType();
if (!block.getInstructions().isEmpty()) {
InsnNode me = block.getInstructions().get(0);
if (me.getType() == InsnType.MOVE_EXCEPTION) {
......
......@@ -2,6 +2,7 @@ package jadx.core.dex.visitors.blocksmaker;
import java.util.ArrayList;
import java.util.BitSet;
import java.util.Collections;
import java.util.Iterator;
import java.util.List;
......@@ -19,6 +20,9 @@ import jadx.core.dex.nodes.Edge;
import jadx.core.dex.nodes.InsnNode;
import jadx.core.dex.nodes.MethodNode;
import jadx.core.dex.trycatch.CatchAttr;
import jadx.core.dex.trycatch.ExcHandlerAttr;
import jadx.core.dex.trycatch.ExceptionHandler;
import jadx.core.dex.trycatch.TryCatchBlock;
import jadx.core.dex.visitors.AbstractVisitor;
import jadx.core.utils.BlockUtils;
import jadx.core.utils.exceptions.JadxOverflowException;
......@@ -360,7 +364,10 @@ public class BlockProcessor extends AbstractVisitor {
throw new JadxRuntimeException("Unreachable block: " + block);
}
}
if (mergeExceptionHandlers(mth)) {
removeMarkedBlocks(mth);
return true;
}
for (BlockNode block : basicBlocks) {
if (checkLoops(mth, block)) {
return true;
......@@ -446,6 +453,85 @@ public class BlockProcessor extends AbstractVisitor {
}
/**
* Merge handlers for multi-exception catch
*/
private static boolean mergeExceptionHandlers(MethodNode mth) {
for (BlockNode block : mth.getBasicBlocks()) {
ExcHandlerAttr excHandlerAttr = block.get(AType.EXC_HANDLER);
if (excHandlerAttr != null) {
List<BlockNode> blocksForMerge = collectExcHandlerBlocks(block, excHandlerAttr);
if (mergeHandlers(mth, blocksForMerge, excHandlerAttr)) {
return true;
}
}
}
return false;
}
private static List<BlockNode> collectExcHandlerBlocks(BlockNode block, ExcHandlerAttr excHandlerAttr) {
List<BlockNode> successors = block.getSuccessors();
if (successors.size() != 1) {
return Collections.emptyList();
}
RegisterArg reg = getMoveExceptionRegister(block);
if (reg == null) {
return Collections.emptyList();
}
TryCatchBlock tryBlock = excHandlerAttr.getTryBlock();
List<BlockNode> blocksForMerge = new ArrayList<>();
BlockNode nextBlock = successors.get(0);
for (BlockNode predBlock : nextBlock.getPredecessors()) {
if (predBlock != block
&& checkOtherExcHandler(predBlock, tryBlock, reg)) {
blocksForMerge.add(predBlock);
}
}
return blocksForMerge;
}
private static boolean checkOtherExcHandler(BlockNode predBlock, TryCatchBlock tryBlock, RegisterArg reg) {
ExcHandlerAttr otherExcHandlerAttr = predBlock.get(AType.EXC_HANDLER);
if (otherExcHandlerAttr == null) {
return false;
}
TryCatchBlock otherTryBlock = otherExcHandlerAttr.getTryBlock();
if (tryBlock != otherTryBlock) {
return false;
}
RegisterArg otherReg = getMoveExceptionRegister(predBlock);
if (otherReg == null || reg.getRegNum() != otherReg.getRegNum()) {
return false;
}
return true;
}
private static RegisterArg getMoveExceptionRegister(BlockNode block) {
if (block.getInstructions().isEmpty()) {
return null;
}
InsnNode insn = block.getInstructions().get(0);
if (insn.getType() != InsnType.MOVE_EXCEPTION) {
return null;
}
return insn.getResult();
}
private static boolean mergeHandlers(MethodNode mth, List<BlockNode> blocksForMerge, ExcHandlerAttr excHandlerAttr) {
if (blocksForMerge.isEmpty()) {
return false;
}
TryCatchBlock tryBlock = excHandlerAttr.getTryBlock();
for (BlockNode block : blocksForMerge) {
ExcHandlerAttr otherExcHandlerAttr = block.get(AType.EXC_HANDLER);
ExceptionHandler excHandler = otherExcHandlerAttr.getHandler();
excHandlerAttr.getHandler().addCatchTypes(excHandler.getCatchTypes());
tryBlock.removeHandler(mth, excHandler);
detachBlock(block);
}
return true;
}
/**
* Splice return block if several predecessors presents
*/
private static boolean splitReturn(MethodNode mth) {
......@@ -543,6 +629,20 @@ public class BlockProcessor extends AbstractVisitor {
});
}
private static void detachBlock(BlockNode block) {
for (BlockNode pred : block.getPredecessors()) {
pred.getSuccessors().remove(block);
pred.updateCleanSuccessors();
}
for (BlockNode successor : block.getSuccessors()) {
successor.getPredecessors().remove(block);
}
block.add(AFlag.REMOVE);
block.getInstructions().clear();
block.getPredecessors().clear();
block.getSuccessors().clear();
}
private static void clearBlocksState(MethodNode mth) {
mth.getBasicBlocks().forEach(block -> {
block.remove(AType.LOOP);
......
......@@ -44,6 +44,7 @@ public class BlockSplitter extends AbstractVisitor {
mth.initBasicBlocks();
splitBasicBlocks(mth);
removeJumpAttr(mth);
removeInsns(mth);
removeEmptyDetachedBlocks(mth);
initBlocksInTargetNodes(mth);
......@@ -296,6 +297,14 @@ public class BlockSplitter extends AbstractVisitor {
return block;
}
private void removeJumpAttr(MethodNode mth) {
for (BlockNode block : mth.getBasicBlocks()) {
for (InsnNode insn : block.getInstructions()) {
insn.remove(AType.JUMP);
}
}
}
private static void removeInsns(MethodNode mth) {
for (BlockNode block : mth.getBasicBlocks()) {
block.getInstructions().removeIf(insn -> {
......
......@@ -39,8 +39,16 @@ public class Utils {
if (objects == null) {
return "";
}
return listToString(objects, joiner, Object::toString);
}
public static <T> String listToString(Iterable<T> objects, Function<T, String> toStr) {
return listToString(objects, ", ", toStr);
}
public static <T> String listToString(Iterable<T> objects, String joiner, Function<T, String> toStr) {
StringBuilder sb = new StringBuilder();
listToString(sb, objects, joiner, Object::toString);
listToString(sb, objects, joiner, toStr);
return sb.toString();
}
......
......@@ -24,6 +24,10 @@ public abstract class SmaliTest extends IntegrationTest {
return getClassNodeFromSmali(path + File.separatorChar + clsName, clsName);
}
protected ClassNode getClassNodeFromSmaliWithPkg(String pkg, String clsName) {
return getClassNodeFromSmali(pkg + File.separatorChar + clsName, pkg + '.' + clsName);
}
protected ClassNode getClassNodeFromSmali(String clsName) {
return getClassNodeFromSmali(clsName, clsName);
}
......
package jadx.tests.integration.trycatch;
import java.security.ProviderException;
import java.time.DateTimeException;
import org.junit.Test;
import jadx.core.dex.nodes.ClassNode;
import jadx.tests.api.IntegrationTest;
import static jadx.tests.api.utils.JadxMatchers.containsOne;
import static org.hamcrest.CoreMatchers.containsString;
import static org.hamcrest.CoreMatchers.not;
import static org.junit.Assert.assertThat;
public class TestMultiExceptionCatch extends IntegrationTest {
public static class TestCls {
public void test() {
try {
System.out.println("Test");
} catch (ProviderException | DateTimeException e) {
throw new RuntimeException(e);
}
}
}
@Test
public void test() {
ClassNode cls = getClassNode(TestCls.class);
String code = cls.getCode().toString();
assertThat(code, containsOne("try {"));
assertThat(code, containsOne("} catch (ProviderException | DateTimeException e) {"));
assertThat(code, containsOne("throw new RuntimeException(e);"));
assertThat(code, not(containsString("RuntimeException e;")));
}
}
package jadx.tests.integration.trycatch;
import org.junit.Test;
import jadx.core.dex.nodes.ClassNode;
import jadx.tests.api.SmaliTest;
import static jadx.tests.api.utils.JadxMatchers.containsOne;
import static org.hamcrest.CoreMatchers.containsString;
import static org.hamcrest.CoreMatchers.not;
import static org.junit.Assert.assertThat;
public class TestMultiExceptionCatchSameJump extends SmaliTest {
/*
public static class TestCls {
public void test() {
try {
System.out.println("Test");
} catch (ProviderException | DateTimeException e) {
throw new RuntimeException(e);
}
}
}
*/
@Test
public void test() {
ClassNode cls = getClassNodeFromSmaliWithPkg("trycatch", "TestMultiExceptionCatchSameJump");
String code = cls.getCode().toString();
assertThat(code, containsOne("try {"));
assertThat(code, containsOne("} catch (ProviderException | DateTimeException e) {"));
assertThat(code, containsOne("throw new RuntimeException(e);"));
assertThat(code, not(containsString("RuntimeException e;")));
}
}
.class public Ltrycatch/TestMultiExceptionCatchSameJump;
.super Ljava/lang/Object;
.source "TestMultiExceptionCatchSameJump.java"
.method public test()V
.locals 2
.line 17
:try_start_0
sget-object v0, Ljava/lang/System;->out:Ljava/io/PrintStream;
const-string v1, "Test"
invoke-virtual {v0, v1}, Ljava/io/PrintStream;->println(Ljava/lang/String;)V
:try_end_0
.catch Ljava/security/ProviderException; {:try_start_0 .. :try_end_0} :catch_0
.catch Ljava/time/DateTimeException; {:try_start_0 .. :try_end_0} :catch_0
.line 20
nop
.line 22
return-void
.line 18
:catch_0
move-exception v0
.line 19
.local v0, "e":Ljava/lang/RuntimeException;
new-instance v1, Ljava/lang/RuntimeException;
invoke-direct {v1, v0}, Ljava/lang/RuntimeException;-><init>(Ljava/lang/Throwable;)V
throw v1
.end method
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment