Commit b5a9389c authored by Skylot's avatar Skylot

core: fix variables inline in 'catch' block

parent d905c96f
...@@ -72,8 +72,10 @@ public class NameGen { ...@@ -72,8 +72,10 @@ public class NameGen {
} }
public String useArg(RegisterArg arg) { public String useArg(RegisterArg arg) {
String name = makeArgName(arg); String name = arg.getName();
varNames.add(name); if (name == null) {
return getFallbackName(arg);
}
return name; return name;
} }
...@@ -96,14 +98,10 @@ public class NameGen { ...@@ -96,14 +98,10 @@ public class NameGen {
} }
private String makeArgName(RegisterArg arg) { private String makeArgName(RegisterArg arg) {
String name = arg.getName();
if (fallback) { if (fallback) {
String base = "r" + arg.getRegNum(); return getFallbackName(arg);
if (name != null && !name.equals("this")) {
return base + "_" + name;
}
return base;
} }
String name = arg.getName();
String varName; String varName;
if (name != null) { if (name != null) {
if ("this".equals(name)) { if ("this".equals(name)) {
...@@ -119,6 +117,15 @@ public class NameGen { ...@@ -119,6 +117,15 @@ public class NameGen {
return varName; return varName;
} }
private String getFallbackName(RegisterArg arg) {
String name = arg.getName();
String base = "r" + arg.getRegNum();
if (name != null && !name.equals("this")) {
return base + "_" + name;
}
return base;
}
private static String makeNameForType(ArgType type) { private static String makeNameForType(ArgType type) {
if (type.isPrimitive()) { if (type.isPrimitive()) {
return makeNameForPrimitive(type); return makeNameForPrimitive(type);
......
...@@ -152,6 +152,18 @@ public class InsnNode extends LineAttrNode { ...@@ -152,6 +152,18 @@ public class InsnNode extends LineAttrNode {
} }
} }
public boolean isConstInsn() {
switch (getType()) {
case CONST:
case CONST_STR:
case CONST_CLASS:
return true;
default:
return false;
}
}
public boolean canReorder() { public boolean canReorder() {
switch (getType()) { switch (getType()) {
case CONST: case CONST:
......
...@@ -16,7 +16,6 @@ import jadx.core.dex.trycatch.CatchAttr; ...@@ -16,7 +16,6 @@ import jadx.core.dex.trycatch.CatchAttr;
import jadx.core.dex.trycatch.ExceptionHandler; import jadx.core.dex.trycatch.ExceptionHandler;
import jadx.core.dex.trycatch.SplitterBlockAttr; import jadx.core.dex.trycatch.SplitterBlockAttr;
import jadx.core.utils.BlockUtils; import jadx.core.utils.BlockUtils;
import jadx.core.utils.EmptyBitSet;
import jadx.core.utils.exceptions.JadxRuntimeException; import jadx.core.utils.exceptions.JadxRuntimeException;
import java.util.ArrayList; import java.util.ArrayList;
...@@ -28,6 +27,8 @@ import java.util.List; ...@@ -28,6 +27,8 @@ import java.util.List;
import java.util.Map; import java.util.Map;
import java.util.Set; import java.util.Set;
import static jadx.core.utils.EmptyBitSet.EMPTY;
public class BlockMakerVisitor extends AbstractVisitor { public class BlockMakerVisitor extends AbstractVisitor {
// leave these instructions alone in block node // leave these instructions alone in block node
...@@ -36,9 +37,8 @@ public class BlockMakerVisitor extends AbstractVisitor { ...@@ -36,9 +37,8 @@ public class BlockMakerVisitor extends AbstractVisitor {
InsnType.IF, InsnType.IF,
InsnType.SWITCH, InsnType.SWITCH,
InsnType.MONITOR_ENTER, InsnType.MONITOR_ENTER,
InsnType.MONITOR_EXIT); InsnType.MONITOR_EXIT
);
private static final BitSet EMPTY_BITSET = new EmptyBitSet();
@Override @Override
public void visit(MethodNode mth) { public void visit(MethodNode mth) {
...@@ -298,7 +298,7 @@ public class BlockMakerVisitor extends AbstractVisitor { ...@@ -298,7 +298,7 @@ public class BlockMakerVisitor extends AbstractVisitor {
private static void computeDominanceFrontier(MethodNode mth) { private static void computeDominanceFrontier(MethodNode mth) {
for (BlockNode exit : mth.getExitBlocks()) { for (BlockNode exit : mth.getExitBlocks()) {
exit.setDomFrontier(EMPTY_BITSET); exit.setDomFrontier(EMPTY);
} }
for (BlockNode block : mth.getBasicBlocks()) { for (BlockNode block : mth.getBasicBlocks()) {
computeBlockDF(mth, block); computeBlockDF(mth, block);
...@@ -330,7 +330,7 @@ public class BlockMakerVisitor extends AbstractVisitor { ...@@ -330,7 +330,7 @@ public class BlockMakerVisitor extends AbstractVisitor {
} }
} }
if (domFrontier == null || domFrontier.cardinality() == 0) { if (domFrontier == null || domFrontier.cardinality() == 0) {
domFrontier = EMPTY_BITSET; domFrontier = EMPTY;
} }
block.setDomFrontier(domFrontier); block.setDomFrontier(domFrontier);
} }
......
...@@ -13,6 +13,7 @@ import jadx.core.dex.nodes.BlockNode; ...@@ -13,6 +13,7 @@ import jadx.core.dex.nodes.BlockNode;
import jadx.core.dex.nodes.InsnNode; import jadx.core.dex.nodes.InsnNode;
import jadx.core.dex.nodes.MethodNode; import jadx.core.dex.nodes.MethodNode;
import jadx.core.utils.BlockUtils; import jadx.core.utils.BlockUtils;
import jadx.core.utils.EmptyBitSet;
import jadx.core.utils.InsnList; import jadx.core.utils.InsnList;
import jadx.core.utils.exceptions.JadxRuntimeException; import jadx.core.utils.exceptions.JadxRuntimeException;
...@@ -97,7 +98,8 @@ public class CodeShrinker extends AbstractVisitor { ...@@ -97,7 +98,8 @@ public class CodeShrinker extends AbstractVisitor {
} }
private boolean canMove(int from, int to) { private boolean canMove(int from, int to) {
List<RegisterArg> movedArgs = argsList.get(from).getArgs(); ArgsInfo startInfo = argsList.get(from);
List<RegisterArg> movedArgs = startInfo.getArgs();
int start = from + 1; int start = from + 1;
if (start == to) { if (start == to) {
// previous instruction or on edge of inline border // previous instruction or on edge of inline border
...@@ -106,9 +108,17 @@ public class CodeShrinker extends AbstractVisitor { ...@@ -106,9 +108,17 @@ public class CodeShrinker extends AbstractVisitor {
if (start > to) { if (start > to) {
throw new JadxRuntimeException("Invalid inline insn positions: " + start + " - " + to); throw new JadxRuntimeException("Invalid inline insn positions: " + start + " - " + to);
} }
BitSet movedSet = new BitSet(); BitSet movedSet;
for (RegisterArg arg : movedArgs) { if (movedArgs.isEmpty()) {
movedSet.set(arg.getRegNum()); if (startInfo.insn.isConstInsn()) {
return true;
}
movedSet = EmptyBitSet.EMPTY;
} else {
movedSet = new BitSet();
for (RegisterArg arg : movedArgs) {
movedSet.set(arg.getRegNum());
}
} }
for (int i = start; i < to; i++) { for (int i = start; i < to; i++) {
ArgsInfo argsInfo = argsList.get(i); ArgsInfo argsInfo = argsList.get(i);
...@@ -188,6 +198,9 @@ public class CodeShrinker extends AbstractVisitor { ...@@ -188,6 +198,9 @@ public class CodeShrinker extends AbstractVisitor {
List<WrapInfo> wrapList = new ArrayList<WrapInfo>(); List<WrapInfo> wrapList = new ArrayList<WrapInfo>();
for (ArgsInfo argsInfo : argsList) { for (ArgsInfo argsInfo : argsList) {
List<RegisterArg> args = argsInfo.getArgs(); List<RegisterArg> args = argsInfo.getArgs();
if (args.isEmpty()) {
continue;
}
ListIterator<RegisterArg> it = args.listIterator(args.size()); ListIterator<RegisterArg> it = args.listIterator(args.size());
while (it.hasPrevious()) { while (it.hasPrevious()) {
RegisterArg arg = it.previous(); RegisterArg arg = it.previous();
...@@ -234,7 +247,7 @@ public class CodeShrinker extends AbstractVisitor { ...@@ -234,7 +247,7 @@ public class CodeShrinker extends AbstractVisitor {
} }
private static boolean canMoveBetweenBlocks(InsnNode assignInsn, BlockNode assignBlock, private static boolean canMoveBetweenBlocks(InsnNode assignInsn, BlockNode assignBlock,
BlockNode useBlock, InsnNode useInsn) { BlockNode useBlock, InsnNode useInsn) {
if (!BlockUtils.isPathExists(assignBlock, useBlock)) { if (!BlockUtils.isPathExists(assignBlock, useBlock)) {
return false; return false;
} }
......
...@@ -4,6 +4,7 @@ import jadx.core.dex.info.FieldInfo; ...@@ -4,6 +4,7 @@ import jadx.core.dex.info.FieldInfo;
import jadx.core.dex.instructions.IndexInsnNode; import jadx.core.dex.instructions.IndexInsnNode;
import jadx.core.dex.instructions.InsnType; import jadx.core.dex.instructions.InsnType;
import jadx.core.dex.instructions.InvokeNode; import jadx.core.dex.instructions.InvokeNode;
import jadx.core.dex.instructions.InvokeType;
import jadx.core.dex.instructions.args.ArgType; import jadx.core.dex.instructions.args.ArgType;
import jadx.core.dex.instructions.args.InsnArg; import jadx.core.dex.instructions.args.InsnArg;
import jadx.core.dex.instructions.args.LiteralArg; import jadx.core.dex.instructions.args.LiteralArg;
...@@ -12,7 +13,6 @@ import jadx.core.dex.instructions.args.SSAVar; ...@@ -12,7 +13,6 @@ import jadx.core.dex.instructions.args.SSAVar;
import jadx.core.dex.nodes.BlockNode; import jadx.core.dex.nodes.BlockNode;
import jadx.core.dex.nodes.InsnNode; import jadx.core.dex.nodes.InsnNode;
import jadx.core.dex.nodes.MethodNode; import jadx.core.dex.nodes.MethodNode;
import jadx.core.utils.BlockUtils;
import jadx.core.utils.InstructionRemover; import jadx.core.utils.InstructionRemover;
import jadx.core.utils.exceptions.JadxException; import jadx.core.utils.exceptions.JadxException;
...@@ -30,7 +30,7 @@ public class ConstInlinerVisitor extends AbstractVisitor { ...@@ -30,7 +30,7 @@ public class ConstInlinerVisitor extends AbstractVisitor {
for (BlockNode block : mth.getBasicBlocks()) { for (BlockNode block : mth.getBasicBlocks()) {
toRemove.clear(); toRemove.clear();
for (InsnNode insn : block.getInstructions()) { for (InsnNode insn : block.getInstructions()) {
if (checkInsn(mth, block, insn)) { if (checkInsn(mth, insn)) {
toRemove.add(insn); toRemove.add(insn);
} }
} }
...@@ -40,7 +40,7 @@ public class ConstInlinerVisitor extends AbstractVisitor { ...@@ -40,7 +40,7 @@ public class ConstInlinerVisitor extends AbstractVisitor {
} }
} }
private static boolean checkInsn(MethodNode mth, BlockNode block, InsnNode insn) { private static boolean checkInsn(MethodNode mth, InsnNode insn) {
if (insn.getType() != InsnType.CONST) { if (insn.getType() != InsnType.CONST) {
return false; return false;
} }
...@@ -48,15 +48,22 @@ public class ConstInlinerVisitor extends AbstractVisitor { ...@@ -48,15 +48,22 @@ public class ConstInlinerVisitor extends AbstractVisitor {
if (!arg.isLiteral()) { if (!arg.isLiteral()) {
return false; return false;
} }
long lit = ((LiteralArg) arg).getLiteral();
SSAVar sVar = insn.getResult().getSVar(); SSAVar sVar = insn.getResult().getSVar();
if (mth.getExceptionHandlersCount() != 0) { if (lit == 0) {
// don't inline null object if:
// - used as instance arg in invoke instruction
for (RegisterArg useArg : sVar.getUseList()) { for (RegisterArg useArg : sVar.getUseList()) {
InsnNode parentInsn = useArg.getParentInsn(); InsnNode parentInsn = useArg.getParentInsn();
if (parentInsn != null) { if (parentInsn != null) {
// TODO: speed up expensive operations InsnType insnType = parentInsn.getType();
BlockNode useBlock = BlockUtils.getBlockByInsn(mth, parentInsn); if (insnType == InsnType.INVOKE) {
if (useBlock == null || !BlockUtils.isCleanPathExists(block, useBlock)) { InvokeNode inv = (InvokeNode) parentInsn;
return false; if (inv.getInvokeType() != InvokeType.STATIC
&& inv.getArg(0) == useArg) {
return false;
}
} }
} }
} }
...@@ -66,7 +73,6 @@ public class ConstInlinerVisitor extends AbstractVisitor { ...@@ -66,7 +73,6 @@ public class ConstInlinerVisitor extends AbstractVisitor {
if (!arg.getType().isTypeKnown()) { if (!arg.getType().isTypeKnown()) {
arg.merge(resType); arg.merge(resType);
} }
long lit = ((LiteralArg) arg).getLiteral();
return replaceConst(mth, sVar, lit); return replaceConst(mth, sVar, lit);
} }
...@@ -85,6 +91,10 @@ public class ConstInlinerVisitor extends AbstractVisitor { ...@@ -85,6 +91,10 @@ public class ConstInlinerVisitor extends AbstractVisitor {
if (use.size() == 1 || arg.isTypeImmutable()) { if (use.size() == 1 || arg.isTypeImmutable()) {
// arg used only in one place // arg used only in one place
litArg = InsnArg.lit(literal, arg.getType()); litArg = InsnArg.lit(literal, arg.getType());
} else if (useInsn.getType() == InsnType.MOVE
&& !useInsn.getResult().getType().isTypeKnown()) {
// save type for 'move' instructions (hard to find type in chains of 'move')
litArg = InsnArg.lit(literal, arg.getType());
} else { } else {
// in most cases type not equal arg.getType() // in most cases type not equal arg.getType()
// just set unknown type and run type fixer // just set unknown type and run type fixer
......
...@@ -9,6 +9,7 @@ import jadx.core.dex.nodes.IRegion; ...@@ -9,6 +9,7 @@ import jadx.core.dex.nodes.IRegion;
import jadx.core.dex.nodes.InsnNode; import jadx.core.dex.nodes.InsnNode;
import jadx.core.dex.nodes.MethodNode; import jadx.core.dex.nodes.MethodNode;
import jadx.core.dex.regions.SwitchRegion; import jadx.core.dex.regions.SwitchRegion;
import jadx.core.dex.regions.TryCatchRegion;
import jadx.core.dex.regions.conditions.IfRegion; import jadx.core.dex.regions.conditions.IfRegion;
import jadx.core.dex.regions.loops.LoopRegion; import jadx.core.dex.regions.loops.LoopRegion;
import jadx.core.dex.visitors.AbstractVisitor; import jadx.core.dex.visitors.AbstractVisitor;
...@@ -72,7 +73,8 @@ public class ReturnVisitor extends AbstractVisitor { ...@@ -72,7 +73,8 @@ public class ReturnVisitor extends AbstractVisitor {
for (IRegion region : regionStack) { for (IRegion region : regionStack) {
// ignore paths on other branches // ignore paths on other branches
if (region instanceof IfRegion if (region instanceof IfRegion
|| region instanceof SwitchRegion) { || region instanceof SwitchRegion
|| region instanceof TryCatchRegion) {
curContainer = region; curContainer = region;
continue; continue;
} }
......
...@@ -45,7 +45,6 @@ public class SSATransform extends AbstractVisitor { ...@@ -45,7 +45,6 @@ public class SSATransform extends AbstractVisitor {
if (removeUselessPhi(mth)) { if (removeUselessPhi(mth)) {
renameVariables(mth); renameVariables(mth);
} }
} }
private static void placePhi(MethodNode mth, int regNum, LiveVarAnalysis la) { private static void placePhi(MethodNode mth, int regNum, LiveVarAnalysis la) {
......
...@@ -2,10 +2,12 @@ package jadx.core.utils; ...@@ -2,10 +2,12 @@ package jadx.core.utils;
import java.util.BitSet; import java.util.BitSet;
public class EmptyBitSet extends BitSet { public final class EmptyBitSet extends BitSet {
private static final long serialVersionUID = -1194884945157778639L; private static final long serialVersionUID = -1194884945157778639L;
public static final BitSet EMPTY = new EmptyBitSet();
public EmptyBitSet() { public EmptyBitSet() {
super(0); super(0);
} }
...@@ -62,7 +64,7 @@ public class EmptyBitSet extends BitSet { ...@@ -62,7 +64,7 @@ public class EmptyBitSet extends BitSet {
@Override @Override
public BitSet get(int fromIndex, int toIndex) { public BitSet get(int fromIndex, int toIndex) {
throw new UnsupportedOperationException(); return EMPTY;
} }
@Override @Override
...@@ -84,4 +86,9 @@ public class EmptyBitSet extends BitSet { ...@@ -84,4 +86,9 @@ public class EmptyBitSet extends BitSet {
public void andNot(BitSet set) { public void andNot(BitSet set) {
throw new UnsupportedOperationException(); throw new UnsupportedOperationException();
} }
@Override
public Object clone() {
return this;
}
} }
...@@ -82,7 +82,9 @@ public abstract class IntegrationTest extends TestUtils { ...@@ -82,7 +82,9 @@ public abstract class IntegrationTest extends TestUtils {
} }
// don't unload class // don't unload class
System.out.println("-----------------------------------------------------------");
System.out.println(cls.getCode()); System.out.println(cls.getCode());
System.out.println("-----------------------------------------------------------");
checkCode(cls); checkCode(cls);
compile(cls); compile(cls);
...@@ -163,6 +165,7 @@ public abstract class IntegrationTest extends TestUtils { ...@@ -163,6 +165,7 @@ public abstract class IntegrationTest extends TestUtils {
} catch (InvocationTargetException ie) { } catch (InvocationTargetException ie) {
rethrow("Decompiled check failed", ie); rethrow("Decompiled check failed", ie);
} }
System.out.println("Auto check: PASSED");
} catch (Exception e) { } catch (Exception e) {
e.printStackTrace(); e.printStackTrace();
fail("Auto check exception: " + e.getMessage()); fail("Auto check exception: " + e.getMessage());
......
package jadx.tests.integration.invoke;
import jadx.core.dex.nodes.ClassNode;
import jadx.tests.api.IntegrationTest;
import java.io.IOException;
import org.junit.Test;
import static jadx.tests.api.utils.JadxMatchers.containsOne;
import static org.hamcrest.CoreMatchers.containsString;
import static org.hamcrest.CoreMatchers.not;
import static org.junit.Assert.assertThat;
public class TestInvokeInCatch extends IntegrationTest {
public static class TestCls {
private static final String TAG = "TAG";
private void test(int[] a, int b) {
try {
exc();
} catch (IOException e) {
if (b == 1) {
log(TAG, "Error: {}", e.getMessage());
}
}
}
private static void log(String tag, String str, String... args) {
}
private void exc() throws IOException {
throw new IOException();
}
}
@Test
public void test() {
ClassNode cls = getClassNode(TestCls.class);
String code = cls.getCode().toString();
assertThat(code, containsOne("try {"));
assertThat(code, containsOne("exc();"));
assertThat(code, not(containsString("return;")));
assertThat(code, containsOne("} catch (IOException e) {"));
assertThat(code, containsOne("if (b == 1) {"));
// assertThat(code, containsOne("log(TAG, \"Error: {}\", e.getMessage());"));
assertThat(code, containsOne("log(TAG, \"Error: {}\", new String[]{e.getMessage()});"));
}
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment