Commit 3f08c99f authored by Skylot's avatar Skylot

core: use ternary operator

parent e3606d1b
......@@ -14,10 +14,9 @@ import jadx.core.dex.visitors.MethodInlinerVisitor;
import jadx.core.dex.visitors.ModVisitor;
import jadx.core.dex.visitors.SimplifyVisitor;
import jadx.core.dex.visitors.regions.CheckRegions;
import jadx.core.dex.visitors.regions.CleanRegions;
import jadx.core.dex.visitors.regions.PostRegionVisitor;
import jadx.core.dex.visitors.regions.ProcessVariables;
import jadx.core.dex.visitors.regions.RegionMakerVisitor;
import jadx.core.dex.visitors.regions.TernaryVisitor;
import jadx.core.dex.visitors.typeresolver.FinishTypeResolver;
import jadx.core.dex.visitors.typeresolver.TypeResolver;
import jadx.core.utils.Utils;
......@@ -68,10 +67,10 @@ public class Jadx {
passes.add(new DotGraphVisitor(outDir, false));
}
passes.add(new CodeShrinker());
passes.add(new RegionMakerVisitor());
passes.add(new PostRegionVisitor());
passes.add(new TernaryVisitor());
passes.add(new CodeShrinker());
passes.add(new SimplifyVisitor());
passes.add(new ProcessVariables());
passes.add(new CheckRegions());
......@@ -82,7 +81,6 @@ public class Jadx {
passes.add(new MethodInlinerVisitor());
passes.add(new ClassModifier());
passes.add(new CleanRegions());
}
passes.add(new CodeGen(args));
return passes;
......
package jadx.core.codegen;
import jadx.core.dex.instructions.ArithNode;
import jadx.core.dex.instructions.IfOp;
import jadx.core.dex.instructions.InsnType;
import jadx.core.dex.instructions.args.ArgType;
import jadx.core.dex.instructions.args.InsnArg;
import jadx.core.dex.instructions.args.InsnWrapArg;
import jadx.core.dex.instructions.args.LiteralArg;
import jadx.core.dex.nodes.InsnNode;
import jadx.core.dex.regions.Compare;
import jadx.core.dex.regions.IfCondition;
import jadx.core.utils.ErrorsCounter;
import jadx.core.utils.exceptions.CodegenException;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
public class ConditionGen {
private static final Logger LOG = LoggerFactory.getLogger(ConditionGen.class);
static String make(InsnGen insnGen, IfCondition condition) throws CodegenException {
switch (condition.getMode()) {
case COMPARE:
return makeCompare(insnGen, condition.getCompare());
case NOT:
return "!(" + make(insnGen, condition.getArgs().get(0)) + ")";
case AND:
case OR:
String mode = condition.getMode() == IfCondition.Mode.AND ? " && " : " || ";
StringBuilder sb = new StringBuilder();
for (IfCondition arg : condition.getArgs()) {
if (sb.length() != 0) {
sb.append(mode);
}
String s = make(insnGen, arg);
if (arg.isCompare()) {
sb.append(s);
} else {
sb.append('(').append(s).append(')');
}
}
return sb.toString();
default:
return "??" + condition.toString();
}
}
private static String makeCompare(InsnGen insnGen, Compare compare) throws CodegenException {
IfOp op = compare.getOp();
InsnArg firstArg = compare.getA();
InsnArg secondArg = compare.getB();
if (firstArg.getType().equals(ArgType.BOOLEAN)
&& secondArg.isLiteral()
&& secondArg.getType().equals(ArgType.BOOLEAN)) {
LiteralArg lit = (LiteralArg) secondArg;
if (lit.getLiteral() == 0) {
op = op.invert();
}
if (op == IfOp.EQ) {
// == true
return insnGen.arg(firstArg, false).toString();
} else if (op == IfOp.NE) {
// != true
if (isWrapNeeded(firstArg)) {
return "!(" + insnGen.arg(firstArg) + ")";
} else {
return "!" + insnGen.arg(firstArg);
}
}
LOG.warn(ErrorsCounter.formatErrorMsg(insnGen.mth, "Unsupported boolean condition " + op.getSymbol()));
}
return insnGen.arg(firstArg, isWrapNeeded(firstArg))
+ " " + op.getSymbol() + " "
+ insnGen.arg(secondArg, isWrapNeeded(secondArg));
}
private static boolean isWrapNeeded(InsnArg arg) {
if (!arg.isInsnWrap()) {
return false;
}
InsnNode insn = ((InsnWrapArg) arg).getWrapInsn();
if (insn.getType() == InsnType.ARITH) {
ArithNode arith = ((ArithNode) insn);
switch (arith.getOp()) {
case ADD:
case SUB:
case MUL:
case DIV:
case REM:
return false;
}
} else if (insn.getType() == InsnType.INVOKE) {
return false;
}
return true;
}
}
......@@ -26,6 +26,7 @@ import jadx.core.dex.instructions.args.LiteralArg;
import jadx.core.dex.instructions.args.NamedArg;
import jadx.core.dex.instructions.args.RegisterArg;
import jadx.core.dex.instructions.mods.ConstructorInsn;
import jadx.core.dex.instructions.mods.TernaryInsn;
import jadx.core.dex.nodes.ClassNode;
import jadx.core.dex.nodes.FieldNode;
import jadx.core.dex.nodes.InsnNode;
......@@ -409,6 +410,7 @@ public class InsnGen {
break;
case TERNARY:
makeTernary((TernaryInsn) insn, code, state);
break;
case ARGS:
......@@ -689,6 +691,21 @@ public class InsnGen {
}
}
private void makeTernary(TernaryInsn insn, CodeWriter code, EnumSet<IGState> state) throws CodegenException {
String cond = ConditionGen.make(this, insn.getCondition());
CodeWriter th = arg(insn.getArg(0), false);
CodeWriter els = arg(insn.getArg(1), false);
if (th.toString().equals("true") && els.toString().equals("false")) {
code.add(cond);
} else {
if (state.contains(IGState.BODY_ONLY)) {
code.add("((").add(cond).add(')').add(" ? ").add(th).add(" : ").add(els).add(")");
} else {
code.add('(').add(cond).add(')').add(" ? ").add(th).add(" : ").add(els);
}
}
}
private void makeArith(ArithNode insn, CodeWriter code, EnumSet<IGState> state) throws CodegenException {
ArithOp op = insn.getOp();
CodeWriter v1 = arg(insn.getArg(0));
......
......@@ -22,6 +22,7 @@ import jadx.core.dex.nodes.IContainer;
import jadx.core.dex.nodes.IRegion;
import jadx.core.dex.nodes.InsnNode;
import jadx.core.dex.nodes.MethodNode;
import jadx.core.dex.regions.Compare;
import jadx.core.dex.regions.IfCondition;
import jadx.core.dex.regions.IfRegion;
import jadx.core.dex.regions.LoopRegion;
......@@ -65,8 +66,7 @@ public class RegionGen extends InsnGen {
}
}
} else if (cont instanceof IfRegion) {
code.startLine();
makeIf((IfRegion) cont, code);
makeIf((IfRegion) cont, code, true);
} else if (cont instanceof SwitchRegion) {
makeSwitch((SwitchRegion) cont, code);
} else if (cont instanceof LoopRegion) {
......@@ -110,8 +110,15 @@ public class RegionGen extends InsnGen {
}
}
private void makeIf(IfRegion region, CodeWriter code) throws CodegenException {
code.add("if (").add(makeCondition(region.getCondition())).add(") {");
private void makeIf(IfRegion region, CodeWriter code, boolean newLine) throws CodegenException {
if (region.getTernRegion() != null) {
makeSimpleBlock(region.getTernRegion().getBlock(), code);
return;
}
if (newLine) {
code.startLine();
}
code.add("if (").add(ConditionGen.make(this, region.getCondition())).add(") {");
makeRegionIndent(code, region.getThenRegion());
code.startLine('}');
......@@ -124,8 +131,11 @@ public class RegionGen extends InsnGen {
Region re = (Region) els;
List<IContainer> subBlocks = re.getSubBlocks();
if (subBlocks.size() == 1 && subBlocks.get(0) instanceof IfRegion) {
makeIf((IfRegion) subBlocks.get(0), code);
return;
IfRegion ifRegion = (IfRegion) subBlocks.get(0);
if (ifRegion.getAttributes().contains(AttributeFlag.ELSE_IF_CHAIN)) {
makeIf(ifRegion, code, false);
return;
}
}
}
......@@ -158,12 +168,13 @@ public class RegionGen extends InsnGen {
return code;
}
String condStr = ConditionGen.make(this, condition);
if (region.isConditionAtEnd()) {
code.startLine("do {");
makeRegionIndent(code, region.getBody());
code.startLine("} while (").add(makeCondition(condition)).add(");");
code.startLine("} while (").add(condStr).add(");");
} else {
code.startLine("while (").add(makeCondition(condition)).add(") {");
code.startLine("while (").add(condStr).add(") {");
makeRegionIndent(code, region.getBody());
code.startLine('}');
}
......@@ -203,7 +214,7 @@ public class RegionGen extends InsnGen {
}
}
private String makeCompare(IfCondition.Compare compare) throws CodegenException {
private String makeCompare(Compare compare) throws CodegenException {
IfOp op = compare.getOp();
InsnArg firstArg = compare.getA();
InsnArg secondArg = compare.getB();
......@@ -269,7 +280,6 @@ public class RegionGen extends InsnGen {
code.startLine("default:");
makeCaseBlock(sw.getDefaultCase(), code);
}
code.startLine('}');
return code;
}
......
......@@ -19,5 +19,7 @@ public enum AttributeFlag {
SKIP_FIRST_ARG,
ANONYMOUS_CONSTRUCTOR,
ELSE_IF_CHAIN,
INCONSISTENT_CODE, // warning about incorrect decompilation
}
......@@ -20,17 +20,6 @@ public class IfNode extends GotoNode {
private BlockNode thenBlock;
private BlockNode elseBlock;
public IfNode(int targ, InsnArg then, InsnArg els) {
super(InsnType.IF, targ);
addArg(then);
if (els == null) {
zeroCmp = true;
} else {
zeroCmp = false;
addArg(els);
}
}
public IfNode(DecodedInstruction insn, IfOp op) {
super(InsnType.IF, insn.getTarget());
this.op = op;
......@@ -84,7 +73,6 @@ public class IfNode extends GotoNode {
} else {
elseBlock = selectOther(thenBlock, curBlock.getSuccessors());
}
target = thenBlock.getStartOffset();
}
public BlockNode getThenBlock() {
......
......@@ -43,6 +43,23 @@ public final class LiteralArg extends InsnArg {
}
@Override
public int hashCode() {
return (int) (literal ^ (literal >>> 32)) + 31 * getType().hashCode();
}
@Override
public boolean equals(Object o) {
if (this == o) {
return true;
}
if (o == null || getClass() != o.getClass()) {
return false;
}
LiteralArg that = (LiteralArg) o;
return literal == that.literal && getType().equals(that.getType());
}
@Override
public String toString() {
try {
return "(" + TypeGen.literalToString(literal, getType()) + " " + typedVar + ")";
......
package jadx.core.dex.instructions.mods;
import jadx.core.dex.instructions.InsnType;
import jadx.core.dex.instructions.args.InsnArg;
import jadx.core.dex.instructions.args.RegisterArg;
import jadx.core.dex.nodes.InsnNode;
import jadx.core.dex.regions.IfCondition;
import jadx.core.utils.InsnUtils;
import jadx.core.utils.Utils;
public class TernaryInsn extends InsnNode {
private final IfCondition condition;
public TernaryInsn(IfCondition condition, RegisterArg result, InsnArg th, InsnArg els) {
super(InsnType.TERNARY, 2);
this.condition = condition;
setResult(result);
addArg(th);
addArg(els);
}
public IfCondition getCondition() {
return condition;
}
@Override
public String toString() {
return InsnUtils.formatOffset(offset) + ": TERNARY"
+ getResult() + " = "
+ Utils.listToString(getArguments());
}
}
package jadx.core.dex.regions;
import jadx.core.dex.instructions.IfNode;
import jadx.core.dex.instructions.IfOp;
import jadx.core.dex.instructions.args.InsnArg;
public final class Compare {
private final IfNode insn;
public Compare(IfNode insn) {
this.insn = insn;
}
public IfOp getOp() {
return insn.getOp();
}
public InsnArg getA() {
return insn.getArg(0);
}
public InsnArg getB() {
if (insn.isZeroCmp()) {
return InsnArg.lit(0, getA().getType());
} else {
return insn.getArg(1);
}
}
public Compare invert() {
insn.invertCondition();
return this;
}
@Override
public String toString() {
return getA() + " " + getOp().getSymbol() + " " + getB();
}
}
package jadx.core.dex.regions;
import jadx.core.dex.instructions.IfNode;
import jadx.core.dex.instructions.IfOp;
import jadx.core.dex.instructions.args.InsnArg;
import jadx.core.dex.instructions.args.RegisterArg;
import jadx.core.dex.nodes.BlockNode;
import jadx.core.utils.exceptions.JadxRuntimeException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.LinkedList;
import java.util.List;
public final class IfCondition {
......@@ -37,40 +38,6 @@ public final class IfCondition {
}
}
public static final class Compare {
private final IfNode insn;
public Compare(IfNode insn) {
this.insn = insn;
}
public IfOp getOp() {
return insn.getOp();
}
public InsnArg getA() {
return insn.getArg(0);
}
public InsnArg getB() {
if (insn.isZeroCmp()) {
return InsnArg.lit(0, getA().getType());
} else {
return insn.getArg(1);
}
}
public Compare invert() {
insn.invertCondition();
return this;
}
@Override
public String toString() {
return getA() + " " + getOp().getSymbol() + " " + getB();
}
}
public static enum Mode {
COMPARE,
NOT,
......@@ -137,6 +104,25 @@ public final class IfCondition {
throw new JadxRuntimeException("Unknown mode for invert: " + mode);
}
public List<RegisterArg> getRegisterArgs() {
List<RegisterArg> list = new LinkedList<RegisterArg>();
if (mode == Mode.COMPARE) {
InsnArg a = compare.getA();
if (a.isRegister()) {
list.add((RegisterArg) a);
}
InsnArg b = compare.getA();
if (a.isRegister()) {
list.add((RegisterArg) b);
}
} else {
for (IfCondition arg : args) {
list.addAll(arg.getRegisterArgs());
}
}
return list;
}
@Override
public String toString() {
switch (mode) {
......
......@@ -16,6 +16,8 @@ public final class IfRegion extends AbstractRegion {
private IContainer thenRegion;
private IContainer elseRegion;
private TernaryRegion ternRegion;
public IfRegion(IRegion parent, BlockNode header) {
super(parent);
assert header.getInstructions().size() == 1;
......@@ -47,8 +49,23 @@ public final class IfRegion extends AbstractRegion {
this.elseRegion = elseRegion;
}
public BlockNode getHeader() {
return header;
}
public void setTernRegion(TernaryRegion ternRegion) {
this.ternRegion = ternRegion;
}
public TernaryRegion getTernRegion() {
return ternRegion;
}
@Override
public List<IContainer> getSubBlocks() {
if (ternRegion != null) {
return ternRegion.getSubBlocks();
}
ArrayList<IContainer> all = new ArrayList<IContainer>(3);
all.add(header);
if (thenRegion != null) {
......@@ -62,6 +79,9 @@ public final class IfRegion extends AbstractRegion {
@Override
public String toString() {
if (ternRegion != null) {
return ternRegion.toString();
}
return "IF(" + condition + ") then " + thenRegion + " else " + elseRegion;
}
}
package jadx.core.dex.regions;
import jadx.core.dex.nodes.BlockNode;
import jadx.core.dex.nodes.IBlock;
import jadx.core.dex.nodes.IContainer;
import jadx.core.dex.nodes.IRegion;
import java.util.Collections;
import java.util.List;
public final class TernaryRegion extends AbstractRegion {
private IBlock container;
public TernaryRegion(IRegion parent, BlockNode block) {
super(parent);
this.container = block;
}
public IBlock getBlock() {
return container;
}
@Override
public List<IContainer> getSubBlocks() {
return Collections.singletonList((IContainer) container);
}
@Override
public String toString() {
return "TERN:" + container;
}
}
......@@ -47,6 +47,7 @@ public class BlockMakerVisitor extends AbstractVisitor {
}
mth.initBasicBlocks();
makeBasicBlocks(mth);
processBlocksTree(mth);
BlockProcessingHelper.visit(mth);
mth.finishBasicBlocks();
}
......@@ -173,6 +174,9 @@ public class BlockMakerVisitor extends AbstractVisitor {
}
}
}
}
private static void processBlocksTree(MethodNode mth) {
computeDominators(mth);
markReturnBlocks(mth);
......@@ -189,7 +193,6 @@ public class BlockMakerVisitor extends AbstractVisitor {
throw new AssertionError("Can't fix method cfg: " + mth);
}
}
registerLoops(mth);
}
......@@ -369,7 +372,6 @@ public class BlockMakerVisitor extends AbstractVisitor {
if (mergeReturn(mth)) {
return true;
}
// TODO detect ternary operator
return false;
}
......@@ -468,6 +470,7 @@ public class BlockMakerVisitor extends AbstractVisitor {
insn.addArg(InsnArg.reg(arg.getRegNum(), arg.getType()));
}
insn.getAttributes().addAll(returnInsn.getAttributes());
insn.setOffset(returnInsn.getOffset());
return insn;
}
......
......@@ -6,6 +6,7 @@ import jadx.core.dex.instructions.args.InsnArg;
import jadx.core.dex.instructions.args.InsnWrapArg;
import jadx.core.dex.instructions.args.RegisterArg;
import jadx.core.dex.instructions.mods.ConstructorInsn;
import jadx.core.dex.instructions.mods.TernaryInsn;
import jadx.core.dex.nodes.BlockNode;
import jadx.core.dex.nodes.InsnNode;
import jadx.core.dex.nodes.MethodNode;
......@@ -17,16 +18,16 @@ import java.util.ArrayList;
import java.util.LinkedList;
import java.util.List;
import java.util.ListIterator;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.util.Set;
public class CodeShrinker extends AbstractVisitor {
private static final Logger LOG = LoggerFactory.getLogger(CodeShrinker.class);
@Override
public void visit(MethodNode mth) {
shrinkMethod(mth);
}
public static void shrinkMethod(MethodNode mth) {
if (mth.isNoCode() || mth.getAttributes().contains(AttributeFlag.DONT_SHRINK)) {
return;
}
......@@ -48,13 +49,20 @@ public class CodeShrinker extends AbstractVisitor {
this.argsList = argsList;
this.pos = pos;
this.inlineBorder = pos;
this.args = new LinkedList<RegisterArg>();
this.args = getArgs(insn);
}
public static List<RegisterArg> getArgs(InsnNode insn) {
LinkedList<RegisterArg> args = new LinkedList<RegisterArg>();
addArgs(insn, args);
return args;
}
private static void addArgs(InsnNode insn, List<RegisterArg> args) {
if (insn.getType() == InsnType.CONSTRUCTOR) {
args.add(((ConstructorInsn) insn).getInstanceArg());
} else if (insn.getType() == InsnType.TERNARY) {
args.addAll(((TernaryInsn) insn).getCondition().getRegisterArgs());
}
for (InsnArg arg : insn.getArguments()) {
if (arg.isRegister()) {
......@@ -85,6 +93,7 @@ public class CodeShrinker extends AbstractVisitor {
}
private boolean canMove(int from, int to) {
List<RegisterArg> movedArgs = argsList.get(from).getArgs();
from++;
if (from == to) {
// previous instruction or on edge of inline border
......@@ -98,13 +107,18 @@ public class CodeShrinker extends AbstractVisitor {
if (argsInfo.getInlinedInsn() == this) {
continue;
}
if (!argsInfo.insn.canReorder()) {
InsnNode curInsn = argsInfo.insn;
if (!curInsn.canReorder() || usedArgAssign(curInsn, movedArgs)) {
return false;
}
}
return true;
}
private static boolean usedArgAssign(InsnNode insn, List<RegisterArg> args) {
return insn.getResult() != null && args.contains(insn.getResult());
}
public WrapInfo inline(int assignInsnPos, RegisterArg arg) {
ArgsInfo argsInfo = argsList.get(assignInsnPos);
argsInfo.inlinedInsn = this;
......@@ -152,7 +166,10 @@ public class CodeShrinker extends AbstractVisitor {
}
}
private void shrinkBlock(MethodNode mth, BlockNode block) {
private static void shrinkBlock(MethodNode mth, BlockNode block) {
if (block.getInstructions().isEmpty()) {
return;
}
InsnList insnList = new InsnList(block.getInstructions());
int insnCount = insnList.size();
List<ArgsInfo> argsList = new ArrayList<ArgsInfo>(insnCount);
......@@ -186,35 +203,36 @@ public class CodeShrinker extends AbstractVisitor {
}
} else {
// another block
if (block.getPredecessors().size() == 1) {
BlockNode assignBlock = BlockUtils.getBlockByInsn(mth, assignInsn);
if (canMoveBetweenBlocks(assignInsn, assignBlock, block, argsInfo.getInsn())) {
arg.wrapInstruction(assignInsn);
InsnList.remove(assignBlock, assignInsn);
}
BlockNode assignBlock = BlockUtils.getBlockByInsn(mth, assignInsn);
if (assignBlock != null
&& canMoveBetweenBlocks(assignInsn, assignBlock, block, argsInfo.getInsn())) {
arg.wrapInstruction(assignInsn);
InsnList.remove(assignBlock, assignInsn);
}
}
}
}
for (WrapInfo wrapInfo : wrapList) {
wrapInfo.getArg().wrapInstruction(wrapInfo.getInsn());
}
for (WrapInfo wrapInfo : wrapList) {
insnList.remove(wrapInfo.getInsn());
if (!wrapList.isEmpty()) {
for (WrapInfo wrapInfo : wrapList) {
wrapInfo.getArg().wrapInstruction(wrapInfo.getInsn());
}
for (WrapInfo wrapInfo : wrapList) {
insnList.remove(wrapInfo.getInsn());
}
}
}
private boolean canMoveBetweenBlocks(InsnNode assignInsn, BlockNode assignBlock,
BlockNode useBlock, InsnNode useInsn) {
if (!useBlock.getPredecessors().contains(assignBlock)
&& !BlockUtils.isOnlyOnePathExists(assignBlock, useBlock)) {
private static boolean canMoveBetweenBlocks(InsnNode assignInsn, BlockNode assignBlock,
BlockNode useBlock, InsnNode useInsn) {
if (!BlockUtils.isPathExists(assignBlock, useBlock)) {
return false;
}
List<RegisterArg> args = ArgsInfo.getArgs(assignInsn);
boolean startCheck = false;
for (InsnNode insn : assignBlock.getInstructions()) {
if (startCheck) {
if (!insn.canReorder()) {
if (!insn.canReorder() || ArgsInfo.usedArgAssign(insn, args)) {
return false;
}
}
......@@ -222,26 +240,28 @@ public class CodeShrinker extends AbstractVisitor {
startCheck = true;
}
}
BlockNode next = assignBlock.getCleanSuccessors().get(0);
while (next != useBlock) {
for (InsnNode insn : assignBlock.getInstructions()) {
if (!insn.canReorder()) {
Set<BlockNode> pathsBlocks = BlockUtils.getAllPathsBlocks(assignBlock, useBlock);
pathsBlocks.remove(assignBlock);
pathsBlocks.remove(useBlock);
for (BlockNode block : pathsBlocks) {
for (InsnNode insn : block.getInstructions()) {
if (!insn.canReorder() || ArgsInfo.usedArgAssign(insn, args)) {
return false;
}
}
next = next.getCleanSuccessors().get(0);
}
for (InsnNode insn : useBlock.getInstructions()) {
if (insn == useInsn) {
return true;
}
if (!insn.canReorder()) {
if (!insn.canReorder() || ArgsInfo.usedArgAssign(insn, args)) {
return false;
}
}
throw new JadxRuntimeException("Can't process instruction move : " + assignBlock);
}
@Deprecated
public static InsnArg inlineArgument(MethodNode mth, RegisterArg arg) {
InsnNode assignInsn = arg.getAssignInsn();
if (assignInsn == null) {
......
......@@ -33,8 +33,7 @@ public class CheckRegions extends AbstractVisitor {
public void processBlock(MethodNode mth, IBlock container) {
if (container instanceof BlockNode) {
blocksInRegions.add((BlockNode) container);
} else {
LOG.warn("Not block node : " + container.getClass().getSimpleName());
}
}
};
......
......@@ -5,19 +5,16 @@ import jadx.core.dex.nodes.IContainer;
import jadx.core.dex.nodes.IRegion;
import jadx.core.dex.nodes.MethodNode;
import jadx.core.dex.regions.Region;
import jadx.core.dex.visitors.AbstractVisitor;
import jadx.core.utils.exceptions.JadxException;
import java.util.Iterator;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
public class CleanRegions extends AbstractVisitor {
public class CleanRegions {
private static final Logger LOG = LoggerFactory.getLogger(CleanRegions.class);
@Override
public void visit(MethodNode mth) throws JadxException {
public static void process(MethodNode mth) {
if (mth.isNoCode() || mth.getBasicBlocks().size() == 0) {
return;
}
......
package jadx.core.dex.visitors.regions;
import jadx.core.dex.instructions.args.ArgType;
import jadx.core.dex.nodes.IContainer;
import jadx.core.dex.nodes.MethodNode;
import jadx.core.dex.visitors.AbstractVisitor;
import jadx.core.utils.exceptions.JadxException;
public class PostRegionVisitor extends AbstractVisitor {
@Override
public void visit(MethodNode mth) throws JadxException {
IContainer startRegion = mth.getRegion();
if (mth.isNoCode() || startRegion == null) {
return;
}
DepthRegionTraverser.traverse(mth, new ProcessTryCatchRegions(mth), startRegion);
if (mth.getLoopsCount() != 0) {
DepthRegionTraverser.traverse(mth, new ProcessLoopRegions(), startRegion);
}
if (mth.getReturnType().equals(ArgType.VOID)) {
DepthRegionTraverser.traverseAll(mth, new ProcessReturnInsns());
}
}
}
package jadx.core.dex.visitors.regions;
import jadx.core.dex.nodes.IRegion;
import jadx.core.dex.nodes.MethodNode;
import jadx.core.dex.regions.LoopRegion;
public class ProcessLoopRegions extends AbstractRegionVisitor {
@Override
public void enterRegion(MethodNode mth, IRegion region) {
if (region instanceof LoopRegion) {
LoopRegion loop = (LoopRegion) region;
loop.mergePreCondition();
}
}
}
......@@ -59,7 +59,8 @@ public class ProcessReturnInsns extends TracedRegionVisitor {
IContainer subBlock = itSubBlock.previous();
if (subBlock == curContainer) {
break;
} else if (RegionUtils.notEmpty(subBlock)) {
} else if (!subBlock.getAttributes().contains(AttributeFlag.RETURN)
&& RegionUtils.notEmpty(subBlock)) {
return false;
}
}
......
......@@ -459,6 +459,9 @@ public class RegionMaker {
if (elseBlock != null) {
if (stack.containsExit(elseBlock)) {
elseBlock = null;
} else if (elseBlock.getAttributes().contains(AttributeFlag.RETURN)) {
out = elseBlock;
elseBlock = null;
}
}
......
package jadx.core.dex.visitors.regions;
import jadx.core.dex.attributes.AttributeFlag;
import jadx.core.dex.instructions.args.ArgType;
import jadx.core.dex.nodes.IContainer;
import jadx.core.dex.nodes.IRegion;
import jadx.core.dex.nodes.MethodNode;
import jadx.core.dex.regions.IfRegion;
import jadx.core.dex.regions.LoopRegion;
import jadx.core.dex.regions.Region;
import jadx.core.dex.trycatch.ExceptionHandler;
import jadx.core.dex.visitors.AbstractVisitor;
import jadx.core.utils.exceptions.JadxException;
import java.util.List;
/**
* Pack blocks into regions for code generation
*/
......@@ -27,5 +36,51 @@ public class RegionMakerVisitor extends AbstractVisitor {
rm.processExcHandler(handler, state);
}
}
postProcessRegions(mth);
}
private static void postProcessRegions(MethodNode mth) {
// make try-catch regions
DepthRegionTraverser.traverse(mth, new ProcessTryCatchRegions(mth), mth.getRegion());
// merge conditions in loops
if (mth.getLoopsCount() != 0) {
DepthRegionTraverser.traverseAll(mth, new AbstractRegionVisitor() {
@Override
public void enterRegion(MethodNode mth, IRegion region) {
if (region instanceof LoopRegion) {
LoopRegion loop = (LoopRegion) region;
loop.mergePreCondition();
}
}
});
}
CleanRegions.process(mth);
// mark if-else-if chains
DepthRegionTraverser.traverseAll(mth, new AbstractRegionVisitor() {
@Override
public void leaveRegion(MethodNode mth, IRegion region) {
if (region instanceof IfRegion) {
IfRegion ifregion = (IfRegion) region;
IContainer elsRegion = ifregion.getElseRegion();
if (elsRegion instanceof IfRegion) {
elsRegion.getAttributes().add(AttributeFlag.ELSE_IF_CHAIN);
} else if (elsRegion instanceof Region) {
List<IContainer> subBlocks = ((Region) elsRegion).getSubBlocks();
if (subBlocks.size() == 1 && subBlocks.get(0) instanceof IfRegion) {
subBlocks.get(0).getAttributes().add(AttributeFlag.ELSE_IF_CHAIN);
}
}
}
}
});
// remove useless returns in void methods
if (mth.getReturnType().equals(ArgType.VOID)) {
DepthRegionTraverser.traverseAll(mth, new ProcessReturnInsns());
}
}
}
package jadx.core.dex.visitors.regions;
import jadx.core.dex.attributes.AttributeFlag;
import jadx.core.dex.instructions.InsnType;
import jadx.core.dex.instructions.args.ArgType;
import jadx.core.dex.instructions.args.InsnArg;
import jadx.core.dex.instructions.args.LiteralArg;
import jadx.core.dex.instructions.mods.TernaryInsn;
import jadx.core.dex.nodes.BlockNode;
import jadx.core.dex.nodes.ClassNode;
import jadx.core.dex.nodes.IContainer;
import jadx.core.dex.nodes.IRegion;
import jadx.core.dex.nodes.InsnNode;
import jadx.core.dex.nodes.MethodNode;
import jadx.core.dex.regions.IfCondition;
import jadx.core.dex.regions.IfRegion;
import jadx.core.dex.regions.Region;
import jadx.core.dex.regions.TernaryRegion;
import jadx.core.dex.visitors.CodeShrinker;
import jadx.core.dex.visitors.IDexTreeVisitor;
import jadx.core.utils.InsnList;
import jadx.core.utils.exceptions.JadxException;
import java.util.List;
public class TernaryVisitor extends AbstractRegionVisitor implements IDexTreeVisitor {
private static final LiteralArg FALSE_ARG = InsnArg.lit(0, ArgType.BOOLEAN);
private static final LiteralArg TRUE_ARG = InsnArg.lit(1, ArgType.BOOLEAN);
@Override
public boolean visit(ClassNode cls) throws JadxException {
return true;
}
@Override
public void visit(MethodNode mth) {
DepthRegionTraverser.traverseAll(mth, this);
}
@Override
public void enterRegion(MethodNode mth, IRegion region) {
if (!(region instanceof IfRegion)) {
return;
}
if (region.getAttributes().contains(AttributeFlag.ELSE_IF_CHAIN)) {
return;
}
IfRegion ifRegion = (IfRegion) region;
IContainer thenRegion = ifRegion.getThenRegion();
IContainer elseRegion = ifRegion.getElseRegion();
if (thenRegion == null || elseRegion == null) {
return;
}
BlockNode tb = getTernaryInsnBlock(thenRegion);
BlockNode eb = getTernaryInsnBlock(elseRegion);
if (tb == null || eb == null) {
return;
}
BlockNode header = ifRegion.getHeader();
InsnNode t = tb.getInstructions().get(0);
InsnNode e = eb.getInstructions().get(0);
if (t.getResult() != null && e.getResult() != null
&& t.getResult().getTypedVar() == e.getResult().getTypedVar()) {
InsnList.remove(tb, t);
InsnList.remove(eb, e);
TernaryInsn ternInsn = new TernaryInsn(ifRegion.getCondition(),
t.getResult(), InsnArg.wrapArg(t), InsnArg.wrapArg(e));
TernaryRegion tern = new TernaryRegion(ifRegion, header);
// TODO: add api for replace regions
ifRegion.setTernRegion(tern);
// remove 'if' instruction
header.getInstructions().clear();
header.getInstructions().add(ternInsn);
// unbind result args
List<InsnArg> useList = ternInsn.getResult().getTypedVar().getUseList();
useList.remove(t.getResult());
useList.remove(e.getResult());
useList.add(ternInsn.getResult());
// shrink method again
CodeShrinker.shrinkMethod(mth);
return;
}
if (!mth.getReturnType().equals(ArgType.VOID)
&& t.getType() == InsnType.RETURN && e.getType() == InsnType.RETURN) {
boolean inverted = false;
InsnArg thenArg = t.getArg(0);
InsnArg elseArg = e.getArg(0);
if (thenArg.equals(FALSE_ARG) && elseArg.equals(TRUE_ARG)) {
inverted = true;
}
InsnList.remove(tb, t);
InsnList.remove(eb, e);
tb.getAttributes().remove(AttributeFlag.RETURN);
eb.getAttributes().remove(AttributeFlag.RETURN);
IfCondition condition = ifRegion.getCondition();
if (inverted) {
condition = condition.invert();
InsnArg tmp = thenArg;
thenArg = elseArg;
elseArg = tmp;
}
TernaryInsn ternInsn = new TernaryInsn(condition, null, thenArg, elseArg);
InsnNode retInsn = new InsnNode(InsnType.RETURN, 1);
retInsn.addArg(InsnArg.wrapArg(ternInsn));
header.getInstructions().clear();
header.getInstructions().add(retInsn);
header.getAttributes().add(AttributeFlag.RETURN);
ifRegion.setTernRegion(new TernaryRegion(ifRegion, header));
CodeShrinker.shrinkMethod(mth);
}
}
private static BlockNode getTernaryInsnBlock(IContainer thenRegion) {
if (thenRegion instanceof Region) {
Region r = (Region) thenRegion;
if (r.getSubBlocks().size() == 1) {
IContainer container = r.getSubBlocks().get(0);
if (container instanceof BlockNode) {
BlockNode block = (BlockNode) container;
if (block.getInstructions().size() == 1) {
return block;
}
}
}
}
return null;
}
}
package jadx.core.utils;
import jadx.core.dex.attributes.AttributeFlag;
import jadx.core.dex.attributes.AttributeType;
import jadx.core.dex.nodes.BlockNode;
import jadx.core.dex.nodes.IContainer;
......@@ -15,14 +16,13 @@ public class RegionUtils {
public static boolean hasExitEdge(IContainer container) {
if (container instanceof BlockNode) {
return ((BlockNode) container).getSuccessors().size() != 0;
BlockNode block = (BlockNode) container;
return block.getSuccessors().size() != 0
&& !block.getAttributes().contains(AttributeFlag.RETURN);
} else if (container instanceof IRegion) {
IRegion region = (IRegion) container;
List<IContainer> blocks = region.getSubBlocks();
if (blocks.isEmpty()) {
return false;
}
return hasExitEdge(blocks.get(blocks.size() - 1));
return !blocks.isEmpty() && hasExitEdge(blocks.get(blocks.size() - 1));
} else {
throw new JadxRuntimeException("Unknown container type: " + container.getClass());
}
......
package jadx.tests.internal;
import jadx.api.InternalJadxTest;
import jadx.core.dex.nodes.ClassNode;
import org.junit.Test;
import static org.hamcrest.CoreMatchers.containsString;
import static org.hamcrest.CoreMatchers.not;
import static org.junit.Assert.assertThat;
public class TestElseIf extends InternalJadxTest {
public static class TestCls {
public int testIfElse(String str) {
int r;
if (str.equals("a")) {
r = 1;
} else if (str.equals("b")) {
r = 2;
} else if (str.equals("3")) {
r = 3;
} else if (str.equals("$")) {
r = 4;
} else {
r = -1;
}
r = r * 10;
return Math.abs(r);
}
}
@Test
public void test() {
ClassNode cls = getClassNode(TestCls.class);
String code = cls.getCode().toString();
assertThat(code, containsString("} else if (str.equals(\"b\")) {"));
assertThat(code, containsString("} else {"));
assertThat(code, containsString("r = -1;"));
// no ternary operator
assertThat(code, not(containsString("?")));
assertThat(code, not(containsString(":")));
}
}
......@@ -6,6 +6,8 @@ import jadx.core.dex.nodes.ClassNode;
import org.junit.Test;
import static org.hamcrest.CoreMatchers.containsString;
import static org.hamcrest.CoreMatchers.either;
import static org.hamcrest.CoreMatchers.not;
import static org.junit.Assert.assertThat;
public class TestRedundantBrackets extends InternalJadxTest {
......@@ -47,10 +49,12 @@ public class TestRedundantBrackets extends InternalJadxTest {
@Test
public void test() {
ClassNode cls = getClassNode(TestCls.class);
String code = cls.getCode().toString();
// assertThat(code, not(containsString("(-1)")));
assertThat(code, containsString("if (obj instanceof String)"));
assertThat(code, not(containsString("(-1)")));
assertThat(code, not(containsString("return;")));
assertThat(code, either(containsString("if (obj instanceof String) {"))
.or(containsString("return (obj instanceof String) ? ")));
assertThat(code, containsString("if (a + b < 10)"));
assertThat(code, containsString("if ((a & b) != 0)"));
assertThat(code, containsString("if (num == 4 || num == 6 || num == 8 || num == 10)"));
......
......@@ -53,10 +53,11 @@ public class TestReturnWrapping extends InternalJadxTest {
public void test() {
ClassNode cls = getClassNode(TestCls.class);
String code = cls.getCode().toString();
assertThat(code, containsString("return 255;"));
assertThat(code, containsString("return arg0 + 1;"));
//assertThat(code, containsString("return Integer.toHexString(i);"));
assertThat(code, containsString("return arg0.toString() + ret.toString();"));
assertThat(code, containsString("return (i > 128) ? arg0.toString() + ret.toString() : Integer.valueOf(i);"));
assertThat(code, containsString("return arg0 + 2;"));
assertThat(code, containsString("arg0 -= 951;"));
}
......
package jadx.tests.internal;
import jadx.api.InternalJadxTest;
import jadx.core.dex.nodes.ClassNode;
import org.junit.Test;
import static org.hamcrest.CoreMatchers.containsString;
import static org.hamcrest.CoreMatchers.either;
import static org.hamcrest.CoreMatchers.not;
import static org.junit.Assert.assertThat;
import static org.junit.Assert.assertTrue;
public class TestTernary extends InternalJadxTest {
public static class TestCls {
public boolean test1(int a) {
return a != 2;
}
public void test2(int a) {
assertTrue(a == 3);
}
public int test3(int a) {
return a > 0 ? 1 : (a + 2) * 3;
}
}
@Test
public void test() {
ClassNode cls = getClassNode(TestCls.class);
String code = cls.getCode().toString();
assertThat(code, not(containsString("else")));
assertThat(code, containsString("return a != 2;"));
assertThat(code, containsString("assertTrue(a == 3)"));
assertThat(code, either(containsString("return a > 0 ? 1 : (a + 2) * 3;"))
.or(containsString("return (a > 0) ? 1 : (a + 2) * 3;")));
}
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment