Commit 5545a94a authored by Skylot's avatar Skylot

core: process nested ternary operators

parent 9e811d95
......@@ -21,6 +21,14 @@ public class PhiInsn extends InsnNode {
return (RegisterArg) super.getArg(n);
}
public boolean removeArg(RegisterArg arg) {
boolean isRemoved = super.removeArg(arg);
if (isRemoved) {
arg.getSVar().setUsedInPhi(null);
}
return isRemoved;
}
@Override
public String toString() {
return "PHI: " + getResult() + " = " + Utils.listToString(getArguments());
......
......@@ -101,6 +101,17 @@ public class InsnNode extends LineAttrNode {
return false;
}
protected boolean removeArg(InsnArg arg) {
int count = getArgsCount();
for (int i = 0; i < count; i++) {
if (arg == arguments.get(i)) {
arguments.remove(i);
return true;
}
}
return false;
}
protected void addReg(DecodedInstruction insn, int i, ArgType type) {
addArg(InsnArg.reg(insn, i, type));
}
......
......@@ -16,8 +16,6 @@ public final class IfRegion extends AbstractRegion {
private IContainer thenRegion;
private IContainer elseRegion;
private TernaryRegion ternRegion;
public IfRegion(IRegion parent, BlockNode header) {
super(parent);
assert header.getInstructions().size() == 1;
......@@ -53,14 +51,6 @@ public final class IfRegion extends AbstractRegion {
return header;
}
public void setTernRegion(TernaryRegion ternRegion) {
this.ternRegion = ternRegion;
}
public TernaryRegion getTernRegion() {
return ternRegion;
}
public boolean simplifyCondition() {
IfCondition cond = IfCondition.simplify(condition);
if (cond != condition) {
......@@ -87,9 +77,6 @@ public final class IfRegion extends AbstractRegion {
@Override
public List<IContainer> getSubBlocks() {
if (ternRegion != null) {
return ternRegion.getSubBlocks();
}
ArrayList<IContainer> all = new ArrayList<IContainer>(3);
all.add(header);
if (thenRegion != null) {
......@@ -116,9 +103,6 @@ public final class IfRegion extends AbstractRegion {
@Override
public String baseString() {
if (ternRegion != null) {
return ternRegion.baseString();
}
StringBuilder sb = new StringBuilder();
if (thenRegion != null) {
sb.append(thenRegion.baseString());
......@@ -131,9 +115,6 @@ public final class IfRegion extends AbstractRegion {
@Override
public String toString() {
if (ternRegion != null) {
return ternRegion.toString();
}
return "IF(" + condition + ") then " + thenRegion + " else " + elseRegion;
}
}
package jadx.core.dex.regions;
import jadx.core.dex.nodes.BlockNode;
import jadx.core.dex.nodes.IBlock;
import jadx.core.dex.nodes.IContainer;
import jadx.core.dex.nodes.IRegion;
import java.util.Collections;
import java.util.List;
public final class TernaryRegion extends AbstractRegion {
private final IBlock container;
public TernaryRegion(IRegion parent, BlockNode block) {
super(parent);
this.container = block;
}
public IBlock getBlock() {
return container;
}
@Override
public List<IContainer> getSubBlocks() {
return Collections.singletonList((IContainer) container);
}
@Override
public String baseString() {
return container.baseString();
}
@Override
public String toString() {
return "TERN:" + container;
}
}
......@@ -15,9 +15,11 @@ import jadx.core.dex.instructions.args.InsnArg;
import jadx.core.dex.instructions.args.InsnWrapArg;
import jadx.core.dex.instructions.args.LiteralArg;
import jadx.core.dex.instructions.mods.ConstructorInsn;
import jadx.core.dex.instructions.mods.TernaryInsn;
import jadx.core.dex.nodes.BlockNode;
import jadx.core.dex.nodes.InsnNode;
import jadx.core.dex.nodes.MethodNode;
import jadx.core.dex.regions.IfCondition;
import java.util.ArrayList;
import java.util.Collections;
......@@ -62,6 +64,9 @@ public class SimplifyVisitor extends AbstractVisitor {
case IF:
simplifyIf((IfNode) insn);
break;
case TERNARY:
simplifyTernary((TernaryInsn)insn);
break;
case INVOKE:
return convertInvoke(mth, insn);
......@@ -105,6 +110,16 @@ public class SimplifyVisitor extends AbstractVisitor {
}
}
/**
* Simplify condition in ternary operation
*/
private static void simplifyTernary(TernaryInsn insn) {
IfCondition condition = insn.getCondition();
if (condition.isCompare()) {
simplifyIf(condition.getCompare().getInsn());
}
}
private static InsnNode convertInvoke(MethodNode mth, InsnNode insn) {
MethodInfo callMth = ((InvokeNode) insn).getCallMth();
if (callMth.getDeclClass().getFullName().equals(Consts.CLASS_STRING_BUILDER)
......
......@@ -50,7 +50,7 @@ public class IfMakerHelper {
boolean badThen = !allPathsFromIf(thenBlock, info);
boolean badElse = !allPathsFromIf(elseBlock, info);
if (badThen && badElse) {
LOG.debug("Stop processing blocks after 'if': {}, method: {}", info, mth);
LOG.debug("Stop processing blocks after 'if': {}, method: {}", info.getIfBlock(), mth);
return null;
}
if (badElse) {
......
......@@ -21,6 +21,16 @@ public class IfRegionVisitor extends AbstractVisitor implements IRegionVisitor,
@Override
public void visit(MethodNode mth) {
// collapse ternary operators
DepthRegionTraversal.traverseAllIterative(mth, new IRegionIterativeVisitor() {
@Override
public boolean visitRegion(MethodNode mth, IRegion region) {
if (region instanceof IfRegion) {
return TernaryMod.makeTernaryInsn(mth, (IfRegion) region);
}
return false;
}
});
DepthRegionTraversal.traverseAll(mth, this);
DepthRegionTraversal.traverseAllIterative(mth, this);
}
......@@ -53,8 +63,6 @@ public class IfRegionVisitor extends AbstractVisitor implements IRegionVisitor,
moveReturnToThenBlock(mth, ifRegion);
moveBreakToThenBlock(ifRegion);
markElseIfChains(ifRegion);
TernaryMod.makeTernaryInsn(mth, ifRegion);
}
private static void simplifyIfCondition(IfRegion ifRegion) {
......
......@@ -106,7 +106,7 @@ public class ProcessTryCatchRegions extends AbstractRegionVisitor {
if (region.getSubBlocks().contains(dominator)) {
TryCatchBlock tb = tryBlocksMap.get(dominator);
if (!wrapBlocks(region, tb, dominator)) {
LOG.warn("Can't wrap try/catch for {}, method: {}", dominator, mth);
LOG.warn("Can't wrap try/catch for {}, method: {}", region, mth);
mth.add(AFlag.INCONSISTENT_CODE);
}
tryBlocksMap.remove(dominator);
......
......@@ -2,8 +2,10 @@ package jadx.core.dex.visitors.regions;
import jadx.core.dex.attributes.AFlag;
import jadx.core.dex.instructions.InsnType;
import jadx.core.dex.instructions.PhiInsn;
import jadx.core.dex.instructions.args.ArgType;
import jadx.core.dex.instructions.args.InsnArg;
import jadx.core.dex.instructions.args.InsnWrapArg;
import jadx.core.dex.instructions.args.RegisterArg;
import jadx.core.dex.instructions.mods.TernaryInsn;
import jadx.core.dex.nodes.BlockNode;
......@@ -12,46 +14,73 @@ import jadx.core.dex.nodes.InsnNode;
import jadx.core.dex.nodes.MethodNode;
import jadx.core.dex.regions.IfRegion;
import jadx.core.dex.regions.Region;
import jadx.core.dex.regions.TernaryRegion;
import jadx.core.dex.visitors.CodeShrinker;
import jadx.core.utils.InsnList;
import java.util.HashMap;
import java.util.Map;
public class TernaryMod {
private TernaryMod() {
}
static void makeTernaryInsn(MethodNode mth, IfRegion ifRegion) {
static boolean makeTernaryInsn(MethodNode mth, IfRegion ifRegion) {
if (ifRegion.contains(AFlag.ELSE_IF_CHAIN)) {
return;
return false;
}
IContainer thenRegion = ifRegion.getThenRegion();
IContainer elseRegion = ifRegion.getElseRegion();
if (thenRegion == null || elseRegion == null) {
return;
return false;
}
BlockNode tb = getTernaryInsnBlock(thenRegion);
BlockNode eb = getTernaryInsnBlock(elseRegion);
if (tb == null || eb == null) {
return;
return false;
}
BlockNode header = ifRegion.getHeader();
InsnNode t = tb.getInstructions().get(0);
InsnNode e = eb.getInstructions().get(0);
if (t.getResult() != null && e.getResult() != null
&& t.getResult().equalRegisterAndType(e.getResult())
&& t.getResult().getSVar().isUsedInPhi()) {
if (t.getSourceLine() != e.getSourceLine()) {
if (t.getSourceLine() != 0 && e.getSourceLine() != 0) {
// sometimes source lines incorrect
if (!checkLineStats(t, e)) {
return false;
}
} else {
// no debug info
if (containsTernary(t) || containsTernary(e)) {
// don't make nested ternary by default
// TODO: add addition checks
return false;
}
}
}
if (t.getResult() != null && e.getResult() != null) {
if (!t.getResult().equalRegisterAndType(e.getResult())
|| !t.getResult().getSVar().isUsedInPhi()) {
return false;
}
if (!ifRegion.getParent().replaceSubBlock(ifRegion, header)) {
return false;
}
InsnList.remove(tb, t);
InsnList.remove(eb, e);
RegisterArg resArg = t.getResult().getSVar().getUsedInPhi().getResult();
RegisterArg resArg;
PhiInsn phi = t.getResult().getSVar().getUsedInPhi();
if (phi.getArgsCount() == 2) {
resArg = phi.getResult();
} else {
resArg = t.getResult();
phi.removeArg(e.getResult());
}
TernaryInsn ternInsn = new TernaryInsn(ifRegion.getCondition(),
resArg, InsnArg.wrapArg(t), InsnArg.wrapArg(e));
ternInsn.setSourceLine(t.getSourceLine());
TernaryRegion tern = new TernaryRegion(ifRegion, header);
// TODO: add api for replace regions
ifRegion.setTernRegion(tern);
// remove 'if' instruction
header.getInstructions().clear();
......@@ -59,11 +88,15 @@ public class TernaryMod {
// shrink method again
CodeShrinker.shrinkMethod(mth);
return;
return true;
}
if (!mth.getReturnType().equals(ArgType.VOID)
&& t.getType() == InsnType.RETURN && e.getType() == InsnType.RETURN) {
if (!ifRegion.getParent().replaceSubBlock(ifRegion, header)) {
return false;
}
InsnList.remove(tb, t);
InsnList.remove(eb, e);
tb.remove(AFlag.RETURN);
......@@ -78,10 +111,10 @@ public class TernaryMod {
header.getInstructions().add(retInsn);
header.add(AFlag.RETURN);
ifRegion.setTernRegion(new TernaryRegion(ifRegion, header));
CodeShrinker.shrinkMethod(mth);
return true;
}
return false;
}
private static BlockNode getTernaryInsnBlock(IContainer thenRegion) {
......@@ -99,4 +132,55 @@ public class TernaryMod {
}
return null;
}
private static boolean containsTernary(InsnNode insn) {
if (insn.getType() == InsnType.TERNARY) {
return true;
}
for (int i = 0; i < insn.getArgsCount(); i++) {
InsnArg arg = insn.getArg(i);
if (arg.isInsnWrap()) {
InsnNode wrapInsn = ((InsnWrapArg) arg).getWrapInsn();
if (containsTernary(wrapInsn)) {
return true;
}
}
}
return false;
}
/**
* Return 'true' if there are several args with same source lines
*/
private static boolean checkLineStats(InsnNode t, InsnNode e) {
if (t.getResult() == null || e.getResult() == null) {
return false;
}
PhiInsn tPhi = t.getResult().getSVar().getUsedInPhi();
PhiInsn ePhi = e.getResult().getSVar().getUsedInPhi();
if (tPhi == null || ePhi == null || tPhi != ePhi) {
return false;
}
Map<Integer, Integer> map = new HashMap<Integer, Integer>(tPhi.getArgsCount());
for (InsnArg arg : tPhi.getArguments()) {
if (!arg.isRegister()) {
continue;
}
int sourceLine = ((RegisterArg) arg).getAssignInsn().getSourceLine();
if (sourceLine != 0) {
Integer count = map.get(sourceLine);
if (count != null) {
map.put(sourceLine, count + 1);
} else {
map.put(sourceLine, 1);
}
}
}
for (Map.Entry<Integer, Integer> entry : map.entrySet()) {
if (entry.getValue() >= 2) {
return true;
}
}
return false;
}
}
......@@ -52,8 +52,8 @@ public class TestRedundantBrackets extends InternalJadxTest {
assertThat(code, not(containsString("return;")));
assertThat(code, containsString("return obj instanceof String ? ((String) obj).length() : 0;"));
assertThat(code, containsString("if (a + b < 10)"));
// assertThat(code, containsString("if ((a & b) != 0)"));
assertThat(code, containsString("a + b < 10"));
assertThat(code, containsString("(a & b) != 0"));
assertThat(code, containsString("if (num == 4 || num == 6 || num == 8 || num == 10)"));
assertThat(code, containsString("a[1] = n * 2;"));
......
package jadx.tests.internal.conditions;
import jadx.api.InternalJadxTest;
import jadx.core.dex.nodes.ClassNode;
import org.junit.Test;
import static jadx.tests.utils.JadxMatchers.containsOne;
import static org.junit.Assert.assertThat;
public class TestConditions14 extends InternalJadxTest {
public static class TestCls {
public static boolean test(Object a, Object b) {
boolean r = a == null ? b != null : !a.equals(b);
if (r) {
return false;
}
System.out.println("1");
return true;
}
// public static boolean test2(Object a, Object b) {
// if (a == null ? b != null : !a.equals(b)) {
// return false;
// }
// System.out.println("2");
// return true;
// }
}
@Test
public void test() {
ClassNode cls = getClassNode(TestCls.class);
String code = cls.getCode().toString();
System.out.println(code);
assertThat(code, containsOne("boolean r = a == null ? b != null : !a.equals(b);"));
assertThat(code, containsOne("if (r) {"));
assertThat(code, containsOne("System.out.println(\"1\");"));
}
}
......@@ -5,6 +5,7 @@ import jadx.core.dex.nodes.ClassNode;
import org.junit.Test;
import static jadx.tests.utils.JadxMatchers.containsOne;
import static org.hamcrest.CoreMatchers.containsString;
import static org.hamcrest.CoreMatchers.not;
import static org.junit.Assert.assertThat;
......@@ -36,11 +37,11 @@ public class TestElseIf extends InternalJadxTest {
String code = cls.getCode().toString();
System.out.println(code);
assertThat(code, containsString("} else if (str.equals(\"b\")) {"));
assertThat(code, containsString("} else {"));
assertThat(code, containsString("int r;"));
assertThat(code, containsString("r = 1;"));
assertThat(code, containsString("r = -1;"));
assertThat(code, containsOne("} else if (str.equals(\"b\")) {"));
assertThat(code, containsOne("} else {"));
assertThat(code, containsOne("int r;"));
assertThat(code, containsOne("r = 1;"));
assertThat(code, containsOne("r = -1;"));
// no ternary operator
assertThat(code, not(containsString("?")));
assertThat(code, not(containsString(":")));
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment