Commit 46d3992b authored by Skylot's avatar Skylot

core: fix 'finally' extract (fix #53 and #54)

parent 164123f5
...@@ -4,7 +4,7 @@ import jadx.api.IJadxArgs; ...@@ -4,7 +4,7 @@ import jadx.api.IJadxArgs;
import jadx.core.codegen.CodeGen; import jadx.core.codegen.CodeGen;
import jadx.core.dex.visitors.ClassModifier; import jadx.core.dex.visitors.ClassModifier;
import jadx.core.dex.visitors.CodeShrinker; import jadx.core.dex.visitors.CodeShrinker;
import jadx.core.dex.visitors.ConstInlinerVisitor; import jadx.core.dex.visitors.ConstInlineVisitor;
import jadx.core.dex.visitors.DebugInfoVisitor; import jadx.core.dex.visitors.DebugInfoVisitor;
import jadx.core.dex.visitors.DotGraphVisitor; import jadx.core.dex.visitors.DotGraphVisitor;
import jadx.core.dex.visitors.EnumVisitor; import jadx.core.dex.visitors.EnumVisitor;
...@@ -73,7 +73,7 @@ public class Jadx { ...@@ -73,7 +73,7 @@ public class Jadx {
passes.add(DotGraphVisitor.dumpRaw(outDir)); passes.add(DotGraphVisitor.dumpRaw(outDir));
} }
passes.add(new ConstInlinerVisitor()); passes.add(new ConstInlineVisitor());
passes.add(new FinishTypeInference()); passes.add(new FinishTypeInference());
passes.add(new EliminatePhiNodes()); passes.add(new EliminatePhiNodes());
......
...@@ -442,11 +442,6 @@ public class InsnGen { ...@@ -442,11 +442,6 @@ public class InsnGen {
addArg(code, insn.getArg(0)); addArg(code, insn.getArg(0));
break; break;
case PHI:
assert isFallback();
code.add("PHI(").add(String.valueOf(insn.getArgsCount())).add(")");
break;
/* fallback mode instructions */ /* fallback mode instructions */
case IF: case IF:
assert isFallback() : "if insn in not fallback mode"; assert isFallback() : "if insn in not fallback mode";
......
...@@ -4,35 +4,93 @@ import jadx.core.dex.attributes.AFlag; ...@@ -4,35 +4,93 @@ import jadx.core.dex.attributes.AFlag;
import jadx.core.dex.instructions.args.ArgType; import jadx.core.dex.instructions.args.ArgType;
import jadx.core.dex.instructions.args.InsnArg; import jadx.core.dex.instructions.args.InsnArg;
import jadx.core.dex.instructions.args.RegisterArg; import jadx.core.dex.instructions.args.RegisterArg;
import jadx.core.dex.nodes.BlockNode;
import jadx.core.dex.nodes.InsnNode; import jadx.core.dex.nodes.InsnNode;
import jadx.core.utils.InstructionRemover;
import jadx.core.utils.Utils; import jadx.core.utils.Utils;
import jadx.core.utils.exceptions.JadxRuntimeException;
public class PhiInsn extends InsnNode { import java.util.IdentityHashMap;
import java.util.Map;
import org.jetbrains.annotations.NotNull;
public final class PhiInsn extends InsnNode {
private final Map<RegisterArg, BlockNode> blockBinds;
public PhiInsn(int regNum, int predecessors) { public PhiInsn(int regNum, int predecessors) {
super(InsnType.PHI, predecessors); super(InsnType.PHI, predecessors);
this.blockBinds = new IdentityHashMap<RegisterArg, BlockNode>(predecessors);
setResult(InsnArg.reg(regNum, ArgType.UNKNOWN)); setResult(InsnArg.reg(regNum, ArgType.UNKNOWN));
for (int i = 0; i < predecessors; i++) {
addReg(regNum, ArgType.UNKNOWN);
}
add(AFlag.DONT_INLINE); add(AFlag.DONT_INLINE);
} }
public RegisterArg bindArg(BlockNode pred) {
RegisterArg arg = InsnArg.reg(getResult().getRegNum(), getResult().getType());
bindArg(arg, pred);
return arg;
}
public void bindArg(RegisterArg arg, BlockNode pred) {
if (blockBinds.containsValue(pred)) {
throw new JadxRuntimeException("Duplicate predecessors in PHI insn: " + pred + ", " + this);
}
addArg(arg);
blockBinds.put(arg, pred);
}
public BlockNode getBlockByArg(RegisterArg arg) {
return blockBinds.get(arg);
}
public Map<RegisterArg, BlockNode> getBlockBinds() {
return blockBinds;
}
@Override @Override
@NotNull
public RegisterArg getArg(int n) { public RegisterArg getArg(int n) {
return (RegisterArg) super.getArg(n); return (RegisterArg) super.getArg(n);
} }
public boolean removeArg(RegisterArg arg) { @Override
boolean isRemoved = super.removeArg(arg); public boolean removeArg(InsnArg arg) {
if (isRemoved) { if (!(arg instanceof RegisterArg)) {
arg.getSVar().setUsedInPhi(null); return false;
}
RegisterArg reg = (RegisterArg) arg;
if (super.removeArg(reg)) {
blockBinds.remove(reg);
InstructionRemover.fixUsedInPhiFlag(reg);
return true;
}
return false;
}
@Override
public boolean replaceArg(InsnArg from, InsnArg to) {
if (!(from instanceof RegisterArg) || !(to instanceof RegisterArg)) {
return false;
}
BlockNode pred = getBlockByArg((RegisterArg) from);
if (pred == null) {
throw new JadxRuntimeException("Unknown predecessor block by arg " + from + " in PHI: " + this);
}
if (removeArg(from)) {
bindArg((RegisterArg) to, pred);
} }
return isRemoved; return true;
}
@Override
public void setArg(int n, InsnArg arg) {
throw new JadxRuntimeException("Unsupported operation for PHI node");
} }
@Override @Override
public String toString() { public String toString() {
return "PHI: " + getResult() + " = " + Utils.listToString(getArguments()); return "PHI: " + getResult() + " = " + Utils.listToString(getArguments())
+ " binds: " + blockBinds;
} }
} }
...@@ -12,6 +12,7 @@ import jadx.core.dex.nodes.FieldNode; ...@@ -12,6 +12,7 @@ import jadx.core.dex.nodes.FieldNode;
import jadx.core.dex.nodes.InsnNode; import jadx.core.dex.nodes.InsnNode;
import jadx.core.dex.nodes.parser.FieldValueAttr; import jadx.core.dex.nodes.parser.FieldValueAttr;
import org.jetbrains.annotations.NotNull;
import org.slf4j.Logger; import org.slf4j.Logger;
import org.slf4j.LoggerFactory; import org.slf4j.LoggerFactory;
...@@ -44,7 +45,7 @@ public class RegisterArg extends InsnArg implements Named { ...@@ -44,7 +45,7 @@ public class RegisterArg extends InsnArg implements Named {
return sVar; return sVar;
} }
void setSVar(SSAVar sVar) { void setSVar(@NotNull SSAVar sVar) {
this.sVar = sVar; this.sVar = sVar;
} }
...@@ -162,7 +163,7 @@ public class RegisterArg extends InsnArg implements Named { ...@@ -162,7 +163,7 @@ public class RegisterArg extends InsnArg implements Named {
@Override @Override
public int hashCode() { public int hashCode() {
return (regNum * 31 + type.hashCode()) * 31 + (sVar != null ? sVar.hashCode() : 0); return regNum * 31 + type.hashCode();
} }
@Override @Override
......
...@@ -111,6 +111,10 @@ public class InsnNode extends LineAttrNode { ...@@ -111,6 +111,10 @@ public class InsnNode extends LineAttrNode {
for (int i = 0; i < count; i++) { for (int i = 0; i < count; i++) {
if (arg == arguments.get(i)) { if (arg == arguments.get(i)) {
arguments.remove(i); arguments.remove(i);
if (arg instanceof RegisterArg) {
RegisterArg reg = (RegisterArg) arg;
reg.getSVar().removeUse(reg);
}
return true; return true;
} }
} }
......
...@@ -64,11 +64,26 @@ public class TryCatchBlock { ...@@ -64,11 +64,26 @@ public class TryCatchBlock {
private void unbindHandler(ExceptionHandler handler) { private void unbindHandler(ExceptionHandler handler) {
for (BlockNode block : handler.getBlocks()) { for (BlockNode block : handler.getBlocks()) {
block.add(AFlag.SKIP); block.add(AFlag.SKIP);
ExcHandlerAttr excHandlerAttr = block.get(AType.EXC_HANDLER);
if (excHandlerAttr != null) {
if (excHandlerAttr.getHandler().equals(handler)) {
block.remove(AType.EXC_HANDLER);
}
}
SplitterBlockAttr splitter = handler.getHandlerBlock().get(AType.SPLITTER_BLOCK);
if (splitter != null) {
splitter.getBlock().remove(AType.SPLITTER_BLOCK);
}
} }
} }
private void removeWholeBlock(MethodNode mth) { private void removeWholeBlock(MethodNode mth) {
// self destruction // self destruction
for (Iterator<ExceptionHandler> it = handlers.iterator(); it.hasNext(); ) {
ExceptionHandler h = it.next();
unbindHandler(h);
it.remove();
}
for (InsnNode insn : insns) { for (InsnNode insn : insns) {
insn.removeAttr(attr); insn.removeAttr(attr);
} }
...@@ -83,9 +98,22 @@ public class TryCatchBlock { ...@@ -83,9 +98,22 @@ public class TryCatchBlock {
insn.addAttr(attr); insn.addAttr(attr);
} }
public void removeInsn(InsnNode insn) { public void removeInsn(MethodNode mth, InsnNode insn) {
insns.remove(insn); insns.remove(insn);
insn.remove(AType.CATCH_BLOCK); insn.remove(AType.CATCH_BLOCK);
if (insns.isEmpty()) {
removeWholeBlock(mth);
}
}
public void removeBlock(MethodNode mth, BlockNode block) {
for (InsnNode insn : block.getInstructions()) {
insns.remove(insn);
insn.remove(AType.CATCH_BLOCK);
}
if (insns.isEmpty()) {
removeWholeBlock(mth);
}
} }
public Iterable<InsnNode> getInsns() { public Iterable<InsnNode> getInsns() {
......
...@@ -23,7 +23,7 @@ import jadx.core.utils.exceptions.JadxException; ...@@ -23,7 +23,7 @@ import jadx.core.utils.exceptions.JadxException;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.List; import java.util.List;
public class ConstInlinerVisitor extends AbstractVisitor { public class ConstInlineVisitor extends AbstractVisitor {
@Override @Override
public void visit(MethodNode mth) throws JadxException { public void visit(MethodNode mth) throws JadxException {
...@@ -38,14 +38,12 @@ public class ConstInlinerVisitor extends AbstractVisitor { ...@@ -38,14 +38,12 @@ public class ConstInlinerVisitor extends AbstractVisitor {
toRemove.add(insn); toRemove.add(insn);
} }
} }
if (!toRemove.isEmpty()) {
InstructionRemover.removeAll(mth, block, toRemove); InstructionRemover.removeAll(mth, block, toRemove);
} }
} }
}
private static boolean checkInsn(MethodNode mth, InsnNode insn) { private static boolean checkInsn(MethodNode mth, InsnNode insn) {
if (insn.getType() != InsnType.CONST) { if (insn.getType() != InsnType.CONST || insn.contains(AFlag.DONT_INLINE)) {
return false; return false;
} }
InsnArg arg = insn.getArg(0); InsnArg arg = insn.getArg(0);
......
...@@ -33,7 +33,7 @@ public class FallbackModeVisitor extends AbstractVisitor { ...@@ -33,7 +33,7 @@ public class FallbackModeVisitor extends AbstractVisitor {
case CONST_CLASS: case CONST_CLASS:
case CMP_L: case CMP_L:
case CMP_G: case CMP_G:
catchAttr.getTryBlock().removeInsn(insn); catchAttr.getTryBlock().removeInsn(mth, insn);
break; break;
default: default:
......
...@@ -9,7 +9,7 @@ import jadx.core.dex.instructions.args.RegisterArg; ...@@ -9,7 +9,7 @@ import jadx.core.dex.instructions.args.RegisterArg;
import jadx.core.dex.nodes.BlockNode; import jadx.core.dex.nodes.BlockNode;
import jadx.core.dex.nodes.InsnNode; import jadx.core.dex.nodes.InsnNode;
import jadx.core.dex.nodes.MethodNode; import jadx.core.dex.nodes.MethodNode;
import jadx.core.dex.trycatch.ExcHandlerAttr; import jadx.core.dex.trycatch.CatchAttr;
import jadx.core.dex.trycatch.ExceptionHandler; import jadx.core.dex.trycatch.ExceptionHandler;
import jadx.core.dex.trycatch.SplitterBlockAttr; import jadx.core.dex.trycatch.SplitterBlockAttr;
import jadx.core.dex.trycatch.TryCatchBlock; import jadx.core.dex.trycatch.TryCatchBlock;
...@@ -18,7 +18,6 @@ import jadx.core.dex.visitors.blocksmaker.helpers.BlocksPair; ...@@ -18,7 +18,6 @@ import jadx.core.dex.visitors.blocksmaker.helpers.BlocksPair;
import jadx.core.dex.visitors.blocksmaker.helpers.BlocksRemoveInfo; import jadx.core.dex.visitors.blocksmaker.helpers.BlocksRemoveInfo;
import jadx.core.dex.visitors.ssa.LiveVarAnalysis; import jadx.core.dex.visitors.ssa.LiveVarAnalysis;
import jadx.core.utils.BlockUtils; import jadx.core.utils.BlockUtils;
import jadx.core.utils.InstructionRemover;
import jadx.core.utils.exceptions.JadxRuntimeException; import jadx.core.utils.exceptions.JadxRuntimeException;
import java.util.ArrayList; import java.util.ArrayList;
...@@ -35,7 +34,6 @@ import org.slf4j.Logger; ...@@ -35,7 +34,6 @@ import org.slf4j.Logger;
import org.slf4j.LoggerFactory; import org.slf4j.LoggerFactory;
import static jadx.core.dex.visitors.blocksmaker.BlockSplitter.connect; import static jadx.core.dex.visitors.blocksmaker.BlockSplitter.connect;
import static jadx.core.dex.visitors.blocksmaker.BlockSplitter.insertBlockBetween;
import static jadx.core.dex.visitors.blocksmaker.BlockSplitter.removeConnection; import static jadx.core.dex.visitors.blocksmaker.BlockSplitter.removeConnection;
public class BlockFinallyExtract extends AbstractVisitor { public class BlockFinallyExtract extends AbstractVisitor {
...@@ -48,10 +46,8 @@ public class BlockFinallyExtract extends AbstractVisitor { ...@@ -48,10 +46,8 @@ public class BlockFinallyExtract extends AbstractVisitor {
} }
boolean reloadBlocks = false; boolean reloadBlocks = false;
List<BlockNode> basicBlocks = mth.getBasicBlocks(); for (ExceptionHandler excHandler : mth.getExceptionHandlers()) {
for (int i = 0; i < basicBlocks.size(); i++) { if (processExceptionHandler(mth, excHandler)) {
BlockNode block = basicBlocks.get(i);
if (processExceptionHandler(mth, block)) {
reloadBlocks = true; reloadBlocks = true;
} }
} }
...@@ -61,13 +57,7 @@ public class BlockFinallyExtract extends AbstractVisitor { ...@@ -61,13 +57,7 @@ public class BlockFinallyExtract extends AbstractVisitor {
} }
} }
private static boolean processExceptionHandler(MethodNode mth, BlockNode block) { private static boolean processExceptionHandler(MethodNode mth, ExceptionHandler excHandler) {
ExcHandlerAttr handlerAttr = block.get(AType.EXC_HANDLER);
if (handlerAttr == null) {
return false;
}
ExceptionHandler excHandler = handlerAttr.getHandler();
// check if handler has exit edge to block not from this handler // check if handler has exit edge to block not from this handler
boolean noExitNode = true; boolean noExitNode = true;
boolean reThrowRemoved = false; boolean reThrowRemoved = false;
...@@ -82,16 +72,16 @@ public class BlockFinallyExtract extends AbstractVisitor { ...@@ -82,16 +72,16 @@ public class BlockFinallyExtract extends AbstractVisitor {
&& size != 0 && size != 0
&& insns.get(size - 1).getType() == InsnType.THROW) { && insns.get(size - 1).getType() == InsnType.THROW) {
reThrowRemoved = true; reThrowRemoved = true;
InstructionRemover.remove(mth, excBlock, size - 1); insns.remove(size - 1);
} }
} }
if (reThrowRemoved && noExitNode if (reThrowRemoved && noExitNode
&& extractFinally(mth, block, excHandler)) { && extractFinally(mth, excHandler)) {
return true; return true;
} }
int totalSize = countInstructions(excHandler); int totalSize = countInstructions(excHandler);
if (totalSize == 0 && reThrowRemoved && noExitNode) { if (totalSize == 0 && reThrowRemoved && noExitNode) {
handlerAttr.getTryBlock().removeHandler(mth, excHandler); excHandler.getTryBlock().removeHandler(mth, excHandler);
} }
return false; return false;
} }
...@@ -99,7 +89,7 @@ public class BlockFinallyExtract extends AbstractVisitor { ...@@ -99,7 +89,7 @@ public class BlockFinallyExtract extends AbstractVisitor {
/** /**
* Search and remove common code from 'catch' and 'handlers'. * Search and remove common code from 'catch' and 'handlers'.
*/ */
private static boolean extractFinally(MethodNode mth, BlockNode handlerBlock, ExceptionHandler handler) { private static boolean extractFinally(MethodNode mth, ExceptionHandler handler) {
int count = handler.getBlocks().size(); int count = handler.getBlocks().size();
BitSet bs = new BitSet(count); BitSet bs = new BitSet(count);
List<BlockNode> blocks = new ArrayList<BlockNode>(count); List<BlockNode> blocks = new ArrayList<BlockNode>(count);
...@@ -171,21 +161,105 @@ public class BlockFinallyExtract extends AbstractVisitor { ...@@ -171,21 +161,105 @@ public class BlockFinallyExtract extends AbstractVisitor {
return false; return false;
} }
// 'finally' extract confirmed /* 'finally' extract confirmed, run remove steps */
LiveVarAnalysis laBefore = null;
boolean runReMap = isReMapNeeded(removes);
if (runReMap) {
laBefore = new LiveVarAnalysis(mth);
laBefore.runAnalysis();
}
for (BlocksRemoveInfo removeInfo : removes) { for (BlocksRemoveInfo removeInfo : removes) {
if (!applyRemove(mth, removeInfo)) { if (!applyRemove(mth, removeInfo)) {
return false; return false;
} }
} }
handler.setFinally(true);
LiveVarAnalysis laAfter = null;
// remove 'move-exception' instruction // remove 'move-exception' instruction
if (BlockUtils.checkLastInsnType(handlerBlock, InsnType.MOVE_EXCEPTION)) { BlockNode handlerBlock = handler.getHandlerBlock();
InstructionRemover.remove(mth, handlerBlock, handlerBlock.getInstructions().size() - 1); InsnNode me = BlockUtils.getLastInsn(handlerBlock);
if (me != null && me.getType() == InsnType.MOVE_EXCEPTION) {
boolean replaced = false;
List<InsnNode> insnsList = handlerBlock.getInstructions();
if (!handlerBlock.getCleanSuccessors().isEmpty()) {
laAfter = new LiveVarAnalysis(mth);
laAfter.runAnalysis();
RegisterArg resArg = me.getResult();
BlockNode succ = handlerBlock.getCleanSuccessors().get(0);
if (laAfter.isLive(succ.getId(), resArg.getRegNum())) {
// kill variable
InsnNode kill = new InsnNode(InsnType.NOP, 0);
kill.setResult(resArg);
kill.add(AFlag.REMOVE);
insnsList.set(insnsList.size() - 1, kill);
replaced = true;
}
}
if (!replaced) {
insnsList.remove(insnsList.size() - 1);
handlerBlock.add(AFlag.SKIP); handlerBlock.add(AFlag.SKIP);
} }
}
// generate 'move' instruction for mapped register pairs
if (runReMap) {
if (laAfter == null) {
laAfter = new LiveVarAnalysis(mth);
laAfter.runAnalysis();
}
performVariablesReMap(mth, removes, laBefore, laAfter);
}
handler.setFinally(true);
return true; return true;
} }
private static void performVariablesReMap(MethodNode mth, List<BlocksRemoveInfo> removes,
LiveVarAnalysis laBefore, LiveVarAnalysis laAfter) {
BitSet processed = new BitSet(mth.getRegsCount());
for (BlocksRemoveInfo removeInfo : removes) {
processed.clear();
BlockNode insertBlock = removeInfo.getStart().getSecond();
if (removeInfo.getRegMap().isEmpty() || insertBlock == null) {
continue;
}
for (Map.Entry<RegisterArg, RegisterArg> entry : removeInfo.getRegMap().entrySet()) {
RegisterArg from = entry.getKey();
int regNum = from.getRegNum();
if (!processed.get(regNum)) {
if (laBefore.isLive(insertBlock.getId(), regNum)) {
// remap variable
RegisterArg to = entry.getValue();
InsnNode move = new InsnNode(InsnType.MOVE, 1);
move.setResult(to);
move.addArg(from);
insertBlock.getInstructions().add(move);
} else if (laAfter.isLive(insertBlock.getId(), regNum)) {
// kill variable
InsnNode kill = new InsnNode(InsnType.NOP, 0);
kill.setResult(from);
kill.add(AFlag.REMOVE);
insertBlock.getInstructions().add(0, kill);
}
processed.set(regNum);
}
}
}
}
private static boolean isReMapNeeded(List<BlocksRemoveInfo> removes) {
for (BlocksRemoveInfo removeInfo : removes) {
if (!removeInfo.getRegMap().isEmpty()) {
return true;
}
}
return false;
}
private static BlocksRemoveInfo removeInsns(MethodNode mth, BlockNode remBlock, List<BlockNode> blocks, BitSet bs) { private static BlocksRemoveInfo removeInsns(MethodNode mth, BlockNode remBlock, List<BlockNode> blocks, BitSet bs) {
if (blocks.isEmpty()) { if (blocks.isEmpty()) {
return null; return null;
...@@ -223,14 +297,36 @@ public class BlockFinallyExtract extends AbstractVisitor { ...@@ -223,14 +297,36 @@ public class BlockFinallyExtract extends AbstractVisitor {
return null; return null;
} }
// first - fast check // first - fast check
int delta = remInsns.size() - startInsns.size(); int startPos = remInsns.size() - startInsns.size();
if (!checkInsns(remInsns, startInsns, delta, null)) { int endPos = 0;
if (!checkInsns(remInsns, startInsns, startPos, null)) {
if (checkInsns(remInsns, startInsns, 0, null)) {
startPos = 0;
endPos = startInsns.size();
} else {
boolean found = false;
for (int i = 1; i < startPos; i++) {
if (checkInsns(remInsns, startInsns, i, null)) {
startPos = i;
endPos = startInsns.size() + i;
found = true;
break;
}
}
if (!found) {
return null; return null;
} }
BlocksRemoveInfo removeInfo = new BlocksRemoveInfo(new BlocksPair(remBlock, startBlock)); }
removeInfo.setStartSplitIndex(delta); }
BlocksPair startPair = new BlocksPair(remBlock, startBlock);
BlocksRemoveInfo removeInfo = new BlocksRemoveInfo(startPair);
removeInfo.setStartSplitIndex(startPos);
removeInfo.setEndSplitIndex(endPos);
if (endPos != 0) {
removeInfo.setEnd(startPair);
}
// second - run checks again for collect registers mapping // second - run checks again for collect registers mapping
if (!checkInsns(remInsns, startInsns, delta, removeInfo)) { if (!checkInsns(remInsns, startInsns, startPos, removeInfo)) {
return null; return null;
} }
return removeInfo; return removeInfo;
...@@ -255,18 +351,23 @@ public class BlockFinallyExtract extends AbstractVisitor { ...@@ -255,18 +351,23 @@ public class BlockFinallyExtract extends AbstractVisitor {
&& !sameBlocks(remBlock, startBlock, removeInfo)) { && !sameBlocks(remBlock, startBlock, removeInfo)) {
return false; return false;
} }
removeInfo.getProcessed().add(new BlocksPair(remBlock, startBlock)); BlocksPair currentPair = new BlocksPair(remBlock, startBlock);
removeInfo.getProcessed().add(currentPair);
List<BlockNode> baseCS = startBlock.getCleanSuccessors(); List<BlockNode> baseCS = startBlock.getCleanSuccessors();
List<BlockNode> remCS = remBlock.getCleanSuccessors(); List<BlockNode> remCS = remBlock.getCleanSuccessors();
if (baseCS.size() != remCS.size()) { if (baseCS.size() != remCS.size()) {
removeInfo.getOuts().add(new BlocksPair(remBlock, startBlock)); removeInfo.getOuts().add(currentPair);
return true; return true;
} }
for (int i = 0; i < baseCS.size(); i++) { for (int i = 0; i < baseCS.size(); i++) {
BlockNode sBlock = baseCS.get(i); BlockNode sBlock = baseCS.get(i);
BlockNode rBlock = remCS.get(i); BlockNode rBlock = remCS.get(i);
if (bs.get(sBlock.getId())) { if (bs.get(sBlock.getId())) {
if (removeInfo.getEndSplitIndex() != 0) {
// end block is not correct
return false;
}
if (!checkBlocksTree(rBlock, sBlock, removeInfo, bs)) { if (!checkBlocksTree(rBlock, sBlock, removeInfo, bs)) {
return false; return false;
} }
...@@ -277,18 +378,22 @@ public class BlockFinallyExtract extends AbstractVisitor { ...@@ -277,18 +378,22 @@ public class BlockFinallyExtract extends AbstractVisitor {
return true; return true;
} }
private static boolean sameBlocks(BlockNode remBlock, BlockNode startBlock, BlocksRemoveInfo removeInfo) { private static boolean sameBlocks(BlockNode remBlock, BlockNode finallyBlock, BlocksRemoveInfo removeInfo) {
List<InsnNode> first = remBlock.getInstructions(); List<InsnNode> first = remBlock.getInstructions();
List<InsnNode> second = startBlock.getInstructions(); List<InsnNode> second = finallyBlock.getInstructions();
if (first.size() != second.size()) { if (first.size() < second.size()) {
return false; return false;
} }
int size = first.size(); int size = second.size();
for (int i = 0; i < size; i++) { for (int i = 0; i < size; i++) {
if (!sameInsns(first.get(i), second.get(i), removeInfo)) { if (!sameInsns(first.get(i), second.get(i), removeInfo)) {
return false; return false;
} }
} }
if (first.size() > second.size()) {
removeInfo.setEndSplitIndex(second.size());
removeInfo.setEnd(new BlocksPair(remBlock, finallyBlock));
}
return true; return true;
} }
...@@ -332,27 +437,43 @@ public class BlockFinallyExtract extends AbstractVisitor { ...@@ -332,27 +437,43 @@ public class BlockFinallyExtract extends AbstractVisitor {
LOG.warn("Finally extract failed: remBlock pred: {}, {}, method: {}", remBlock, remBlock.getPredecessors(), mth); LOG.warn("Finally extract failed: remBlock pred: {}, {}, method: {}", remBlock, remBlock.getPredecessors(), mth);
return false; return false;
} }
BlockNode remBlockPred = remBlock.getPredecessors().get(0); BlockNode remBlockPred = remBlock.getPredecessors().get(0);
int splitIndex = removeInfo.getStartSplitIndex(); removeInfo.setStartPredecessor(remBlockPred);
if (splitIndex > 0) {
// split start block (remBlock) int startSplitIndex = removeInfo.getStartSplitIndex();
BlockNode newBlock = insertBlockBetween(mth, remBlockPred, remBlock); int endSplitIndex = removeInfo.getEndSplitIndex();
for (int i = 0; i < splitIndex; i++) { if (removeInfo.getStart().equals(removeInfo.getEnd())) {
InsnNode insnNode = remBlock.getInstructions().get(i); removeInfo.setEndSplitIndex(endSplitIndex - startSplitIndex);
insnNode.add(AFlag.SKIP);
newBlock.getInstructions().add(insnNode);
} }
Iterator<InsnNode> it = remBlock.getInstructions().iterator(); // split start block (remBlock)
if (startSplitIndex > 0) {
remBlock = splitBlock(mth, remBlock, startSplitIndex);
// change start block in removeInfo
removeInfo.getProcessed().remove(removeInfo.getStart());
BlocksPair newStart = new BlocksPair(remBlock, startBlock);
removeInfo.setStart(newStart);
removeInfo.getProcessed().add(newStart);
}
// split end block
if (endSplitIndex > 0) {
BlocksPair end = removeInfo.getEnd();
BlockNode newOut = splitBlock(mth, end.getFirst(), endSplitIndex);
for (BlockNode s : newOut.getSuccessors()) {
BlocksPair replaceOut = null;
Iterator<BlocksPair> it = removeInfo.getOuts().iterator();
while (it.hasNext()) { while (it.hasNext()) {
InsnNode insnNode = it.next(); BlocksPair outPair = it.next();
if (insnNode.contains(AFlag.SKIP)) { if (outPair.getFirst().equals(s)) {
it.remove(); it.remove();
replaceOut = new BlocksPair(newOut, outPair.getSecond());
break;
} }
} }
for (InsnNode insnNode : newBlock.getInstructions()) { if (replaceOut != null) {
insnNode.remove(AFlag.SKIP); removeInfo.getOuts().add(replaceOut);
}
} }
remBlockPred = newBlock;
} }
BlocksPair out = removeInfo.getOuts().iterator().next(); BlocksPair out = removeInfo.getOuts().iterator().next();
...@@ -377,8 +498,8 @@ public class BlockFinallyExtract extends AbstractVisitor { ...@@ -377,8 +498,8 @@ public class BlockFinallyExtract extends AbstractVisitor {
BlockNode pred = filtPreds.get(0); BlockNode pred = filtPreds.get(0);
BlockNode repl = removeInfo.getBySecond(pred); BlockNode repl = removeInfo.getBySecond(pred);
if (repl == null) { if (repl == null) {
throw new JadxRuntimeException("Block not found by " + pred LOG.error("Block not found by {}, in {}, method: {}", pred, removeInfo, mth);
+ ", in " + removeInfo + ", method: " + mth); return false;
} }
removeConnection(pred, rOut); removeConnection(pred, rOut);
addIgnoredEdge(repl, rOut); addIgnoredEdge(repl, rOut);
...@@ -396,39 +517,56 @@ public class BlockFinallyExtract extends AbstractVisitor { ...@@ -396,39 +517,56 @@ public class BlockFinallyExtract extends AbstractVisitor {
connect(pred, rOut); connect(pred, rOut);
} }
// generate 'move' instruction for mapped register pairs
if (!removeInfo.getRegMap().isEmpty()) {
// TODO: very expensive operation
LiveVarAnalysis la = new LiveVarAnalysis(mth);
la.runAnalysis();
for (Map.Entry<RegisterArg, RegisterArg> entry : removeInfo.getRegMap().entrySet()) {
RegisterArg from = entry.getKey();
if (la.isLive(remBlockPred.getId(), from.getRegNum())) {
RegisterArg to = entry.getValue();
InsnNode move = new InsnNode(InsnType.MOVE, 1);
move.setResult(to);
move.addArg(from);
remBlockPred.getInstructions().add(move);
}
}
}
// mark blocks for remove // mark blocks for remove
markForRemove(remBlock); markForRemove(mth, remBlock);
for (BlocksPair pair : removeInfo.getProcessed()) { for (BlocksPair pair : removeInfo.getProcessed()) {
markForRemove(pair.getFirst()); markForRemove(mth, pair.getFirst());
BlockNode second = pair.getSecond(); BlockNode second = pair.getSecond();
second.updateCleanSuccessors(); second.updateCleanSuccessors();
} }
return true; return true;
} }
private static BlockNode splitBlock(MethodNode mth, BlockNode block, int splitIndex) {
BlockNode newBlock = BlockSplitter.startNewBlock(mth, -1);
newBlock.getSuccessors().addAll(block.getSuccessors());
for (BlockNode s : new ArrayList<BlockNode>(block.getSuccessors())) {
removeConnection(block, s);
connect(newBlock, s);
}
block.getSuccessors().clear();
connect(block, newBlock);
block.updateCleanSuccessors();
newBlock.updateCleanSuccessors();
List<InsnNode> insns = block.getInstructions();
int size = insns.size();
for (int i = splitIndex; i < size; i++) {
InsnNode insnNode = insns.get(i);
insnNode.add(AFlag.SKIP);
newBlock.getInstructions().add(insnNode);
}
Iterator<InsnNode> it = insns.iterator();
while (it.hasNext()) {
InsnNode insnNode = it.next();
if (insnNode.contains(AFlag.SKIP)) {
it.remove();
}
}
for (InsnNode insnNode : newBlock.getInstructions()) {
insnNode.remove(AFlag.SKIP);
}
return newBlock;
}
/** /**
* Unbind block for removing. * Unbind block for removing.
*/ */
private static void markForRemove(BlockNode block) { private static void markForRemove(MethodNode mth, BlockNode block) {
for (BlockNode p : block.getPredecessors()) { for (BlockNode p : block.getPredecessors()) {
p.getSuccessors().remove(block); p.getSuccessors().remove(block);
p.updateCleanSuccessors();
} }
for (BlockNode s : block.getSuccessors()) { for (BlockNode s : block.getSuccessors()) {
s.getPredecessors().remove(block); s.getPredecessors().remove(block);
...@@ -436,6 +574,17 @@ public class BlockFinallyExtract extends AbstractVisitor { ...@@ -436,6 +574,17 @@ public class BlockFinallyExtract extends AbstractVisitor {
block.getPredecessors().clear(); block.getPredecessors().clear();
block.getSuccessors().clear(); block.getSuccessors().clear();
block.add(AFlag.REMOVE); block.add(AFlag.REMOVE);
block.remove(AFlag.SKIP);
CatchAttr catchAttr = block.get(AType.CATCH_BLOCK);
if (catchAttr != null) {
catchAttr.getTryBlock().removeBlock(mth, block);
for (BlockNode skipBlock : mth.getBasicBlocks()) {
if (skipBlock.contains(AFlag.SKIP)) {
markForRemove(mth, skipBlock);
}
}
}
} }
private static void addIgnoredEdge(BlockNode from, BlockNode toBlock) { private static void addIgnoredEdge(BlockNode from, BlockNode toBlock) {
...@@ -500,7 +649,7 @@ public class BlockFinallyExtract extends AbstractVisitor { ...@@ -500,7 +649,7 @@ public class BlockFinallyExtract extends AbstractVisitor {
for (BlockNode remPred : mb.getPredecessors()) { for (BlockNode remPred : mb.getPredecessors()) {
connect(remPred, origReturnBlock); connect(remPred, origReturnBlock);
} }
markForRemove(mb); markForRemove(mth, mb);
edgeAttr.getBlocks().remove(mb); edgeAttr.getBlocks().remove(mb);
} }
} }
......
...@@ -10,6 +10,7 @@ import jadx.core.dex.nodes.BlockNode; ...@@ -10,6 +10,7 @@ import jadx.core.dex.nodes.BlockNode;
import jadx.core.dex.nodes.Edge; import jadx.core.dex.nodes.Edge;
import jadx.core.dex.nodes.InsnNode; import jadx.core.dex.nodes.InsnNode;
import jadx.core.dex.nodes.MethodNode; import jadx.core.dex.nodes.MethodNode;
import jadx.core.dex.trycatch.CatchAttr;
import jadx.core.dex.visitors.AbstractVisitor; import jadx.core.dex.visitors.AbstractVisitor;
import jadx.core.utils.BlockUtils; import jadx.core.utils.BlockUtils;
import jadx.core.utils.exceptions.JadxRuntimeException; import jadx.core.utils.exceptions.JadxRuntimeException;
...@@ -396,6 +397,10 @@ public class BlockProcessor extends AbstractVisitor { ...@@ -396,6 +397,10 @@ public class BlockProcessor extends AbstractVisitor {
|| !block.getSuccessors().isEmpty()) { || !block.getSuccessors().isEmpty()) {
LOG.error("Block {} not deleted, method: {}", block, mth); LOG.error("Block {} not deleted, method: {}", block, mth);
} else { } else {
CatchAttr catchAttr = block.get(AType.CATCH_BLOCK);
if (catchAttr != null) {
catchAttr.getTryBlock().removeBlock(mth, block);
}
it.remove(); it.remove();
} }
} }
......
...@@ -14,9 +14,14 @@ public final class BlocksRemoveInfo { ...@@ -14,9 +14,14 @@ public final class BlocksRemoveInfo {
private final Set<BlocksPair> processed = new HashSet<BlocksPair>(); private final Set<BlocksPair> processed = new HashSet<BlocksPair>();
private final Set<BlocksPair> outs = new HashSet<BlocksPair>(); private final Set<BlocksPair> outs = new HashSet<BlocksPair>();
private final Map<RegisterArg, RegisterArg> regMap = new HashMap<RegisterArg, RegisterArg>(); private final Map<RegisterArg, RegisterArg> regMap = new HashMap<RegisterArg, RegisterArg>();
private final BlocksPair start;
private BlocksPair start;
private BlocksPair end;
private int startSplitIndex; private int startSplitIndex;
private int endSplitIndex;
private BlockNode startPredecessor;
public BlocksRemoveInfo(BlocksPair start) { public BlocksRemoveInfo(BlocksPair start) {
this.start = start; this.start = start;
...@@ -34,6 +39,18 @@ public final class BlocksRemoveInfo { ...@@ -34,6 +39,18 @@ public final class BlocksRemoveInfo {
return start; return start;
} }
public void setStart(BlocksPair start) {
this.start = start;
}
public BlocksPair getEnd() {
return end;
}
public void setEnd(BlocksPair end) {
this.end = end;
}
public int getStartSplitIndex() { public int getStartSplitIndex() {
return startSplitIndex; return startSplitIndex;
} }
...@@ -42,6 +59,22 @@ public final class BlocksRemoveInfo { ...@@ -42,6 +59,22 @@ public final class BlocksRemoveInfo {
this.startSplitIndex = startSplitIndex; this.startSplitIndex = startSplitIndex;
} }
public int getEndSplitIndex() {
return endSplitIndex;
}
public void setEndSplitIndex(int endSplitIndex) {
this.endSplitIndex = endSplitIndex;
}
public void setStartPredecessor(BlockNode startPredecessor) {
this.startPredecessor = startPredecessor;
}
public BlockNode getStartPredecessor() {
return startPredecessor;
}
public Map<RegisterArg, RegisterArg> getRegMap() { public Map<RegisterArg, RegisterArg> getRegMap() {
return regMap; return regMap;
} }
...@@ -69,6 +102,7 @@ public final class BlocksRemoveInfo { ...@@ -69,6 +102,7 @@ public final class BlocksRemoveInfo {
@Override @Override
public String toString() { public String toString() {
return "BRI start: " + start return "BRI start: " + start
+ ", end: " + end
+ ", list: " + processed + ", list: " + processed
+ ", outs: " + outs + ", outs: " + outs
+ ", regMap: " + regMap + ", regMap: " + regMap
......
...@@ -878,7 +878,6 @@ public class RegionMaker { ...@@ -878,7 +878,6 @@ public class RegionMaker {
} }
} }
// TODO add blocks common for several handlers to some region
private void processExcHandler(ExceptionHandler handler, Set<BlockNode> exits) { private void processExcHandler(ExceptionHandler handler, Set<BlockNode> exits) {
BlockNode start = handler.getHandlerBlock(); BlockNode start = handler.getHandlerBlock();
if (start == null) { if (start == null) {
......
...@@ -12,6 +12,7 @@ import jadx.core.dex.nodes.BlockNode; ...@@ -12,6 +12,7 @@ import jadx.core.dex.nodes.BlockNode;
import jadx.core.dex.nodes.InsnNode; import jadx.core.dex.nodes.InsnNode;
import jadx.core.dex.nodes.MethodNode; import jadx.core.dex.nodes.MethodNode;
import jadx.core.dex.visitors.AbstractVisitor; import jadx.core.dex.visitors.AbstractVisitor;
import jadx.core.utils.InsnList;
import jadx.core.utils.InstructionRemover; import jadx.core.utils.InstructionRemover;
import jadx.core.utils.exceptions.JadxException; import jadx.core.utils.exceptions.JadxException;
import jadx.core.utils.exceptions.JadxRuntimeException; import jadx.core.utils.exceptions.JadxRuntimeException;
...@@ -20,6 +21,7 @@ import java.util.ArrayList; ...@@ -20,6 +21,7 @@ import java.util.ArrayList;
import java.util.Arrays; import java.util.Arrays;
import java.util.BitSet; import java.util.BitSet;
import java.util.Deque; import java.util.Deque;
import java.util.Iterator;
import java.util.LinkedList; import java.util.LinkedList;
import java.util.List; import java.util.List;
...@@ -41,10 +43,18 @@ public class SSATransform extends AbstractVisitor { ...@@ -41,10 +43,18 @@ public class SSATransform extends AbstractVisitor {
placePhi(mth, i, la); placePhi(mth, i, la);
} }
renameVariables(mth); renameVariables(mth);
fixLastTryCatchAssign(mth);
if (removeUselessPhi(mth)) { fixLastAssignInTry(mth);
renameVariables(mth); removeBlockerInsns(mth);
boolean repeatFix;
int k = 0;
do {
repeatFix = fixUselessPhi(mth);
if (k++ > 50) {
throw new JadxRuntimeException("Phi nodes fix limit reached!");
} }
} while (repeatFix);
} }
private static void placePhi(MethodNode mth, int regNum, LiveVarAnalysis la) { private static void placePhi(MethodNode mth, int regNum, LiveVarAnalysis la) {
...@@ -65,7 +75,7 @@ public class SSATransform extends AbstractVisitor { ...@@ -65,7 +75,7 @@ public class SSATransform extends AbstractVisitor {
for (int id = domFrontier.nextSetBit(0); id >= 0; id = domFrontier.nextSetBit(id + 1)) { for (int id = domFrontier.nextSetBit(0); id >= 0; id = domFrontier.nextSetBit(id + 1)) {
if (!hasPhi.get(id) && la.isLive(id, regNum)) { if (!hasPhi.get(id) && la.isLive(id, regNum)) {
BlockNode df = blocks.get(id); BlockNode df = blocks.get(id);
addPhi(df, regNum); addPhi(mth, df, regNum);
hasPhi.set(id); hasPhi.set(id);
if (!processed.get(id)) { if (!processed.get(id)) {
processed.set(id); processed.set(id);
...@@ -76,19 +86,31 @@ public class SSATransform extends AbstractVisitor { ...@@ -76,19 +86,31 @@ public class SSATransform extends AbstractVisitor {
} }
} }
private static void addPhi(BlockNode block, int regNum) { private static void addPhi(MethodNode mth, BlockNode block, int regNum) {
PhiListAttr phiList = block.get(AType.PHI_LIST); PhiListAttr phiList = block.get(AType.PHI_LIST);
if (phiList == null) { if (phiList == null) {
phiList = new PhiListAttr(); phiList = new PhiListAttr();
block.addAttr(phiList); block.addAttr(phiList);
} }
PhiInsn phiInsn = new PhiInsn(regNum, block.getPredecessors().size()); int size = block.getPredecessors().size();
if (mth.getEnterBlock() == block) {
for (RegisterArg arg : mth.getArguments(true)) {
if (arg.getRegNum() == regNum) {
size++;
break;
}
}
}
PhiInsn phiInsn = new PhiInsn(regNum, size);
phiList.getList().add(phiInsn); phiList.getList().add(phiInsn);
phiInsn.setOffset(block.getStartOffset()); phiInsn.setOffset(block.getStartOffset());
block.getInstructions().add(0, phiInsn); block.getInstructions().add(0, phiInsn);
} }
private static void renameVariables(MethodNode mth) { private static void renameVariables(MethodNode mth) {
if (!mth.getSVars().isEmpty()) {
throw new JadxRuntimeException("SSA rename variables already executed");
}
int regsCount = mth.getRegsCount(); int regsCount = mth.getRegsCount();
SSAVar[] vars = new SSAVar[regsCount]; SSAVar[] vars = new SSAVar[regsCount];
int[] versions = new int[regsCount]; int[] versions = new int[regsCount];
...@@ -97,7 +119,25 @@ public class SSATransform extends AbstractVisitor { ...@@ -97,7 +119,25 @@ public class SSATransform extends AbstractVisitor {
int regNum = arg.getRegNum(); int regNum = arg.getRegNum();
vars[regNum] = mth.makeNewSVar(regNum, versions, arg); vars[regNum] = mth.makeNewSVar(regNum, versions, arg);
} }
renameVar(mth, vars, versions, mth.getEnterBlock()); BlockNode enterBlock = mth.getEnterBlock();
initPhiInEnterBlock(vars, enterBlock);
renameVar(mth, vars, versions, enterBlock);
}
private static void initPhiInEnterBlock(SSAVar[] vars, BlockNode enterBlock) {
PhiListAttr phiList = enterBlock.get(AType.PHI_LIST);
if (phiList != null) {
for (PhiInsn phiInsn : phiList.getList()) {
int regNum = phiInsn.getResult().getRegNum();
SSAVar var = vars[regNum];
if (var == null) {
continue;
}
RegisterArg arg = phiInsn.bindArg(enterBlock);
var.use(arg);
var.setUsedInPhi(phiInsn);
}
}
} }
private static void renameVar(MethodNode mth, SSAVar[] vars, int[] vers, BlockNode block) { private static void renameVar(MethodNode mth, SSAVar[] vars, int[] vers, BlockNode block) {
...@@ -129,20 +169,14 @@ public class SSATransform extends AbstractVisitor { ...@@ -129,20 +169,14 @@ public class SSATransform extends AbstractVisitor {
if (phiList == null) { if (phiList == null) {
continue; continue;
} }
int j = s.getPredecessors().indexOf(block);
if (j == -1) {
throw new JadxRuntimeException("Can't find predecessor for " + block + " " + s);
}
for (PhiInsn phiInsn : phiList.getList()) { for (PhiInsn phiInsn : phiList.getList()) {
if (j >= phiInsn.getArgsCount()) {
continue;
}
int regNum = phiInsn.getResult().getRegNum(); int regNum = phiInsn.getResult().getRegNum();
SSAVar var = vars[regNum]; SSAVar var = vars[regNum];
if (var == null) { if (var == null) {
continue; continue;
} }
var.use(phiInsn.getArg(j)); RegisterArg arg = phiInsn.bindArg(block);
var.use(arg);
var.setUsedInPhi(phiInsn); var.setUsedInPhi(phiInsn);
} }
} }
...@@ -152,14 +186,23 @@ public class SSATransform extends AbstractVisitor { ...@@ -152,14 +186,23 @@ public class SSATransform extends AbstractVisitor {
System.arraycopy(inputVars, 0, vars, 0, vars.length); System.arraycopy(inputVars, 0, vars, 0, vars.length);
} }
private static void fixLastTryCatchAssign(MethodNode mth) { /**
* Fix last try/catch assign instruction
*/
private static void fixLastAssignInTry(MethodNode mth) {
for (BlockNode block : mth.getBasicBlocks()) { for (BlockNode block : mth.getBasicBlocks()) {
PhiListAttr phiList = block.get(AType.PHI_LIST); PhiListAttr phiList = block.get(AType.PHI_LIST);
if (phiList == null || !block.contains(AType.EXC_HANDLER)) { if (phiList != null && block.contains(AType.EXC_HANDLER)) {
continue;
}
for (PhiInsn phi : phiList.getList()) { for (PhiInsn phi : phiList.getList()) {
for (int i = 0; i < phi.getArgsCount(); i++) { fixPhiInTryCatch(phi);
}
}
}
}
private static void fixPhiInTryCatch(PhiInsn phi) {
int argsCount = phi.getArgsCount();
for (int i = 0; i < argsCount; i++) {
RegisterArg arg = phi.getArg(i); RegisterArg arg = phi.getArg(i);
InsnNode parentInsn = arg.getAssignInsn(); InsnNode parentInsn = arg.getAssignInsn();
if (parentInsn != null if (parentInsn != null
...@@ -169,10 +212,32 @@ public class SSATransform extends AbstractVisitor { ...@@ -169,10 +212,32 @@ public class SSATransform extends AbstractVisitor {
} }
} }
} }
private static boolean removeBlockerInsns(MethodNode mth) {
boolean removed = false;
for (BlockNode block : mth.getBasicBlocks()) {
PhiListAttr phiList = block.get(AType.PHI_LIST);
if (phiList == null) {
continue;
}
// check if args must be removed
for (PhiInsn phi : phiList.getList()) {
for (int i = 0; i < phi.getArgsCount(); i++) {
RegisterArg arg = phi.getArg(i);
InsnNode parentInsn = arg.getAssignInsn();
if (parentInsn != null && parentInsn.contains(AFlag.REMOVE)) {
phi.removeArg(arg);
InstructionRemover.remove(mth, block, parentInsn);
removed = true;
}
}
}
} }
return removed;
} }
private static boolean removeUselessPhi(MethodNode mth) { private static boolean fixUselessPhi(MethodNode mth) {
boolean changed = false;
List<PhiInsn> insnToRemove = new ArrayList<PhiInsn>(); List<PhiInsn> insnToRemove = new ArrayList<PhiInsn>();
for (SSAVar var : mth.getSVars()) { for (SSAVar var : mth.getSVars()) {
// phi result not used // phi result not used
...@@ -180,6 +245,7 @@ public class SSATransform extends AbstractVisitor { ...@@ -180,6 +245,7 @@ public class SSATransform extends AbstractVisitor {
InsnNode assignInsn = var.getAssign().getParentInsn(); InsnNode assignInsn = var.getAssign().getParentInsn();
if (assignInsn != null && assignInsn.getType() == InsnType.PHI) { if (assignInsn != null && assignInsn.getType() == InsnType.PHI) {
insnToRemove.add((PhiInsn) assignInsn); insnToRemove.add((PhiInsn) assignInsn);
changed = true;
} }
} }
} }
...@@ -188,41 +254,53 @@ public class SSATransform extends AbstractVisitor { ...@@ -188,41 +254,53 @@ public class SSATransform extends AbstractVisitor {
if (phiList == null) { if (phiList == null) {
continue; continue;
} }
for (PhiInsn phi : phiList.getList()) { Iterator<PhiInsn> it = phiList.getList().iterator();
removePhiWithSameArgs(phi, insnToRemove); while (it.hasNext()) {
PhiInsn phi = it.next();
if (fixPhiWithSameArgs(mth, block, phi)) {
it.remove();
changed = true;
} }
} }
return removePhiList(mth, insnToRemove); }
removePhiList(mth, insnToRemove);
return changed;
} }
private static void removePhiWithSameArgs(PhiInsn phi, List<PhiInsn> insnToRemove) { private static boolean fixPhiWithSameArgs(MethodNode mth, BlockNode block, PhiInsn phi) {
if (phi.getArgsCount() <= 1) { if (phi.getArgsCount() == 0) {
insnToRemove.add(phi); for (RegisterArg useArg : phi.getResult().getSVar().getUseList()) {
return; InsnNode useInsn = useArg.getParentInsn();
if (useInsn != null && useInsn.getType() == InsnType.PHI) {
phi.removeArg(useArg);
}
} }
InstructionRemover.remove(mth, block, phi);
return true;
}
boolean allSame = phi.getArgsCount() == 1 || isSameArgs(phi);
if (!allSame) {
return false;
}
return replacePhiWithMove(mth, block, phi, phi.getArg(0));
}
private static boolean isSameArgs(PhiInsn phi) {
boolean allSame = true; boolean allSame = true;
SSAVar var = phi.getArg(0).getSVar(); SSAVar var = null;
for (int i = 1; i < phi.getArgsCount(); i++) { for (int i = 0; i < phi.getArgsCount(); i++) {
if (var != phi.getArg(i).getSVar()) { RegisterArg arg = phi.getArg(i);
if (var == null) {
var = arg.getSVar();
} else if (var != arg.getSVar()) {
allSame = false; allSame = false;
break; break;
} }
} }
if (allSame) { return allSame;
// replace
insnToRemove.add(phi);
SSAVar assign = phi.getResult().getSVar();
for (RegisterArg arg : new ArrayList<RegisterArg>(assign.getUseList())) {
assign.removeUse(arg);
var.use(arg);
}
}
} }
private static boolean removePhiList(MethodNode mth, List<PhiInsn> insnToRemove) { private static boolean removePhiList(MethodNode mth, List<PhiInsn> insnToRemove) {
if (insnToRemove.isEmpty()) {
return false;
}
for (BlockNode block : mth.getBasicBlocks()) { for (BlockNode block : mth.getBasicBlocks()) {
PhiListAttr phiList = block.get(AType.PHI_LIST); PhiListAttr phiList = block.get(AType.PHI_LIST);
if (phiList == null) { if (phiList == null) {
...@@ -232,6 +310,9 @@ public class SSATransform extends AbstractVisitor { ...@@ -232,6 +310,9 @@ public class SSATransform extends AbstractVisitor {
for (PhiInsn phiInsn : insnToRemove) { for (PhiInsn phiInsn : insnToRemove) {
if (list.remove(phiInsn)) { if (list.remove(phiInsn)) {
for (InsnArg arg : phiInsn.getArguments()) { for (InsnArg arg : phiInsn.getArguments()) {
if (arg == null) {
continue;
}
SSAVar sVar = ((RegisterArg) arg).getSVar(); SSAVar sVar = ((RegisterArg) arg).getSVar();
if (sVar != null) { if (sVar != null) {
sVar.setUsedInPhi(null); sVar.setUsedInPhi(null);
...@@ -247,4 +328,67 @@ public class SSATransform extends AbstractVisitor { ...@@ -247,4 +328,67 @@ public class SSATransform extends AbstractVisitor {
insnToRemove.clear(); insnToRemove.clear();
return true; return true;
} }
private static boolean replacePhiWithMove(MethodNode mth, BlockNode block, PhiInsn phi, RegisterArg arg) {
List<InsnNode> insns = block.getInstructions();
int phiIndex = InsnList.getIndex(insns, phi);
if (phiIndex == -1) {
return false;
}
SSAVar assign = phi.getResult().getSVar();
SSAVar argVar = arg.getSVar();
if (argVar != null) {
argVar.removeUse(arg);
argVar.setUsedInPhi(null);
}
// try inline
if (inlinePhiInsn(mth, block, phi)) {
insns.remove(phiIndex);
} else {
assign.setUsedInPhi(null);
InsnNode m = new InsnNode(InsnType.MOVE, 1);
m.add(AFlag.SYNTHETIC);
m.setResult(phi.getResult());
m.addArg(arg);
arg.getSVar().use(arg);
insns.set(phiIndex, m);
}
return true;
}
private static boolean inlinePhiInsn(MethodNode mth, BlockNode block, PhiInsn phi) {
SSAVar resVar = phi.getResult().getSVar();
if (resVar == null) {
return false;
}
RegisterArg arg = phi.getArg(0);
if (arg.getSVar() == null) {
return false;
}
List<RegisterArg> useList = resVar.getUseList();
for (RegisterArg useArg : new ArrayList<RegisterArg>(useList)) {
InsnNode useInsn = useArg.getParentInsn();
if (useInsn == null || useInsn == phi) {
return false;
}
useArg.getSVar().removeUse(useArg);
RegisterArg inlArg = arg.duplicate();
if (!useInsn.replaceArg(useArg, inlArg)) {
return false;
}
inlArg.getSVar().use(inlArg);
inlArg.setName(useArg.getName());
inlArg.setType(useArg.getType());
}
if (block.contains(AType.EXC_HANDLER)) {
// don't inline into exception handler
InsnNode assignInsn = arg.getAssignInsn();
if (assignInsn != null) {
assignInsn.add(AFlag.DONT_INLINE);
}
}
InstructionRemover.unbindInsn(mth, phi);
return true;
}
} }
...@@ -71,7 +71,10 @@ public class TypeInference extends AbstractVisitor { ...@@ -71,7 +71,10 @@ public class TypeInference extends AbstractVisitor {
for (int i = 0; i < phi.getArgsCount(); i++) { for (int i = 0; i < phi.getArgsCount(); i++) {
RegisterArg arg = phi.getArg(i); RegisterArg arg = phi.getArg(i);
arg.setType(type); arg.setType(type);
arg.getSVar().setName(phi.getResult().getName()); SSAVar sVar = arg.getSVar();
if (sVar != null) {
sVar.setName(phi.getResult().getName());
}
} }
} }
......
...@@ -29,12 +29,11 @@ public final class InsnList implements Iterable<InsnNode> { ...@@ -29,12 +29,11 @@ public final class InsnList implements Iterable<InsnNode> {
} }
public static int getIndex(List<InsnNode> list, InsnNode insn) { public static int getIndex(List<InsnNode> list, InsnNode insn) {
int i = 0; int size = list.size();
for (InsnNode curObj : list) { for (int i = 0; i < size; i++) {
if (curObj == insn) { if (list.get(i) == insn) {
return i; return i;
} }
i++;
} }
return -1; return -1;
} }
......
package jadx.core.utils; package jadx.core.utils;
import jadx.core.dex.attributes.AFlag; import jadx.core.dex.attributes.AFlag;
import jadx.core.dex.instructions.InsnType;
import jadx.core.dex.instructions.PhiInsn;
import jadx.core.dex.instructions.args.InsnArg; import jadx.core.dex.instructions.args.InsnArg;
import jadx.core.dex.instructions.args.InsnWrapArg; import jadx.core.dex.instructions.args.InsnWrapArg;
import jadx.core.dex.instructions.args.RegisterArg; import jadx.core.dex.instructions.args.RegisterArg;
...@@ -63,16 +65,40 @@ public class InstructionRemover { ...@@ -63,16 +65,40 @@ public class InstructionRemover {
} }
public static void unbindInsn(MethodNode mth, InsnNode insn) { public static void unbindInsn(MethodNode mth, InsnNode insn) {
RegisterArg r = insn.getResult(); unbindResult(mth, insn);
if (r != null && r.getSVar() != null) {
mth.removeSVar(r.getSVar());
}
for (InsnArg arg : insn.getArguments()) { for (InsnArg arg : insn.getArguments()) {
unbindArgUsage(mth, arg); unbindArgUsage(mth, arg);
} }
if (insn.getType() == InsnType.PHI) {
for (InsnArg arg : insn.getArguments()) {
if (arg instanceof RegisterArg) {
fixUsedInPhiFlag((RegisterArg) arg);
}
}
}
insn.add(AFlag.INCONSISTENT_CODE); insn.add(AFlag.INCONSISTENT_CODE);
} }
public static void fixUsedInPhiFlag(RegisterArg useReg) {
PhiInsn usedIn = null;
for (RegisterArg reg : useReg.getSVar().getUseList()) {
InsnNode parentInsn = reg.getParentInsn();
if (parentInsn != null
&& parentInsn.getType() == InsnType.PHI
&& parentInsn.containsArg(useReg)) {
usedIn = (PhiInsn) parentInsn;
}
}
useReg.getSVar().setUsedInPhi(usedIn);
}
public static void unbindResult(MethodNode mth, InsnNode insn) {
RegisterArg r = insn.getResult();
if (r != null && r.getSVar() != null) {
mth.removeSVar(r.getSVar());
}
}
public static void unbindArgUsage(MethodNode mth, InsnArg arg) { public static void unbindArgUsage(MethodNode mth, InsnArg arg) {
if (arg instanceof RegisterArg) { if (arg instanceof RegisterArg) {
RegisterArg reg = (RegisterArg) arg; RegisterArg reg = (RegisterArg) arg;
...@@ -122,6 +148,9 @@ public class InstructionRemover { ...@@ -122,6 +148,9 @@ public class InstructionRemover {
} }
public static void removeAll(MethodNode mth, BlockNode block, List<InsnNode> insns) { public static void removeAll(MethodNode mth, BlockNode block, List<InsnNode> insns) {
if (insns.isEmpty()) {
return;
}
removeAll(mth, block.getInstructions(), insns); removeAll(mth, block.getInstructions(), insns);
} }
......
...@@ -154,6 +154,6 @@ public class ManifestAttributes { ...@@ -154,6 +154,6 @@ public class ManifestAttributes {
return sb.deleteCharAt(sb.length() - 1).toString(); return sb.deleteCharAt(sb.length() - 1).toString();
} }
} }
return "UNKNOWN_DATA_" + Integer.toHexString(value); return "UNKNOWN_DATA_0x" + Integer.toHexString(value);
} }
} }
...@@ -13,7 +13,7 @@ public class TestArgInline extends IntegrationTest { ...@@ -13,7 +13,7 @@ public class TestArgInline extends IntegrationTest {
public static class TestCls { public static class TestCls {
public void method(int a) { public void test(int a) {
while (a < 10) { while (a < 10) {
int b = a + 1; int b = a + 1;
a = b; a = b;
......
...@@ -52,7 +52,7 @@ public class TestContinueInLoop2 extends IntegrationTest { ...@@ -52,7 +52,7 @@ public class TestContinueInLoop2 extends IntegrationTest {
TryCatchBlock catchBlock = catchAttr.getTryBlock(); TryCatchBlock catchBlock = catchAttr.getTryBlock();
if (handlerBlock != catchBlock) { if (handlerBlock != catchBlock) {
handlerBlock.merge(mth, catchBlock); handlerBlock.merge(mth, catchBlock);
catchBlock.removeInsn(insn); catchBlock.removeInsn(mth, insn);
} }
} }
} }
......
package jadx.tests.integration.trycatch;
import jadx.core.dex.nodes.ClassNode;
import jadx.tests.api.IntegrationTest;
import org.junit.Test;
import static jadx.tests.api.utils.JadxMatchers.containsOne;
import static org.hamcrest.CoreMatchers.not;
import static org.junit.Assert.assertThat;
public class TestFinally extends IntegrationTest {
public static class TestCls {
private static final String DISPLAY_NAME = "name";
String test(Context context, Object uri) {
Cursor cursor = null;
try {
String[] projection = {DISPLAY_NAME};
cursor = context.query(uri, projection);
int columnIndex = cursor.getColumnIndexOrThrow(DISPLAY_NAME);
cursor.moveToFirst();
return cursor.getString(columnIndex);
} finally {
if (cursor != null) {
cursor.close();
}
}
}
private class Context {
public Cursor query(Object o, String[] s) {
return null;
}
}
private class Cursor {
public void close() {
}
public void moveToFirst() {
}
public int getColumnIndexOrThrow(String s) {
return 0;
}
public String getString(int i) {
return null;
}
}
}
@Test
public void test() {
ClassNode cls = getClassNode(TestCls.class);
String code = cls.getCode().toString();
assertThat(code, containsOne("cursor.getString(columnIndex);"));
assertThat(code, not(containsOne("String str = true;")));
}
}
package jadx.tests.integration.trycatch;
import jadx.core.dex.nodes.ClassNode;
import jadx.tests.api.IntegrationTest;
import java.io.ByteArrayInputStream;
import java.io.IOException;
import java.io.InputStream;
import org.junit.Test;
import static jadx.tests.api.utils.JadxMatchers.containsOne;
import static org.junit.Assert.assertThat;
public class TestFinally2 extends IntegrationTest {
public static class TestCls {
public Result test(byte[] data) throws IOException {
InputStream inputStream = null;
try {
inputStream = getInputStream(data);
decode(inputStream);
return new Result(400);
} finally {
closeQuietly(inputStream);
}
}
public static final class Result {
private final int mCode;
public Result(int code) {
mCode = code;
}
public int getCode() {
return mCode;
}
}
private InputStream getInputStream(byte[] data) throws IOException {
return new ByteArrayInputStream(data);
}
private int decode(InputStream inputStream) throws IOException {
return inputStream.available();
}
private void closeQuietly(InputStream is) {
}
}
@Test
public void test() {
ClassNode cls = getClassNode(TestCls.class);
String code = cls.getCode().toString();
assertThat(code, containsOne("decode(inputStream);"));
// TODO
// assertThat(code, not(containsOne("result =")));
}
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment