Commit bb4ef4f0 authored by Skylot's avatar Skylot

core: simplify conditions

parent fd00330e
......@@ -425,9 +425,11 @@ public class InsnGen {
case IF:
assert isFallback() : "if insn in not fallback mode";
IfNode ifInsn = (IfNode) insn;
String cond = arg(insn.getArg(0)) + " " + ifInsn.getOp().getSymbol() + " "
+ (ifInsn.isZeroCmp() ? "0" : arg(insn.getArg(1)));
code.add("if (").add(cond).add(") goto ").add(MethodGen.getLabelName(ifInsn.getTarget()));
code.add("if (");
code.add(arg(insn.getArg(0))).add(' ');
code.add(ifInsn.getOp().getSymbol()).add(' ');
code.add(arg(insn.getArg(1)));
code.add(") goto ").add(MethodGen.getLabelName(ifInsn.getTarget()));
break;
case GOTO:
......
......@@ -8,11 +8,11 @@ public class GotoNode extends InsnNode {
protected int target;
public GotoNode(int target) {
this(InsnType.GOTO, target);
this(InsnType.GOTO, target, 0);
}
protected GotoNode(InsnType type, int target) {
super(type);
protected GotoNode(InsnType type, int target, int argsCount) {
super(type, argsCount);
this.target = target;
}
......
......@@ -2,7 +2,6 @@ package jadx.core.dex.instructions;
import jadx.core.dex.instructions.args.ArgType;
import jadx.core.dex.instructions.args.InsnArg;
import jadx.core.dex.instructions.args.LiteralArg;
import jadx.core.dex.instructions.args.PrimitiveType;
import jadx.core.dex.nodes.BlockNode;
import jadx.core.utils.InsnUtils;
......@@ -14,37 +13,32 @@ import static jadx.core.utils.BlockUtils.selectOther;
public class IfNode extends GotoNode {
protected boolean zeroCmp;
private static final ArgType ARG_TYPE = ArgType.unknown(
PrimitiveType.INT, PrimitiveType.OBJECT, PrimitiveType.ARRAY,
PrimitiveType.BOOLEAN, PrimitiveType.SHORT, PrimitiveType.CHAR);
protected IfOp op;
private BlockNode thenBlock;
private BlockNode elseBlock;
public IfNode(DecodedInstruction insn, IfOp op) {
super(InsnType.IF, insn.getTarget());
this.op = op;
ArgType type = ArgType.unknown(
PrimitiveType.INT, PrimitiveType.OBJECT, PrimitiveType.ARRAY,
PrimitiveType.BOOLEAN, PrimitiveType.SHORT, PrimitiveType.CHAR);
this(op, insn.getTarget(),
InsnArg.reg(insn, 0, ARG_TYPE),
insn.getRegisterCount() == 1 ? InsnArg.lit(0, ARG_TYPE) : InsnArg.reg(insn, 1, ARG_TYPE));
}
addReg(insn, 0, type);
if (insn.getRegisterCount() == 1) {
zeroCmp = true;
} else {
zeroCmp = false;
addReg(insn, 1, type);
}
public IfNode(IfOp op, int targetOffset, InsnArg arg1, InsnArg arg2) {
super(InsnType.IF, targetOffset, 2);
this.op = op;
addArg(arg1);
addArg(arg2);
}
public IfOp getOp() {
return op;
}
public boolean isZeroCmp() {
return zeroCmp;
}
public void invertCondition() {
op = op.invert();
BlockNode tmp = thenBlock;
......@@ -53,17 +47,10 @@ public class IfNode extends GotoNode {
target = thenBlock.getStartOffset();
}
public void changeCondition(InsnArg arg1, InsnArg arg2, IfOp op) {
public void changeCondition(IfOp op, InsnArg arg1, InsnArg arg2) {
this.op = op;
this.zeroCmp = arg2.isLiteral() && ((LiteralArg) arg2).getLiteral() == 0;
setArg(0, arg1);
if (!zeroCmp) {
if (getArgsCount() == 2) {
setArg(1, arg2);
} else {
addArg(arg2);
}
}
setArg(1, arg2);
}
public void initBlocks(BlockNode curBlock) {
......@@ -87,8 +74,7 @@ public class IfNode extends GotoNode {
public String toString() {
return InsnUtils.formatOffset(offset) + ": "
+ InsnUtils.insnTypeToString(insnType)
+ getArg(0) + " " + op.getSymbol()
+ " " + (zeroCmp ? "0" : getArg(1))
+ getArg(0) + " " + op.getSymbol() + " " + getArg(1)
+ " -> " + (thenBlock != null ? thenBlock : InsnUtils.formatOffset(target));
}
}
......@@ -299,7 +299,7 @@ public abstract class ArgType {
}
public String getObject() {
throw new UnsupportedOperationException();
throw new UnsupportedOperationException("ArgType.getObject()");
}
public boolean isObject() {
......
......@@ -5,6 +5,9 @@ import jadx.core.utils.exceptions.JadxRuntimeException;
public final class LiteralArg extends InsnArg {
public static final LiteralArg TRUE = new LiteralArg(1, ArgType.BOOLEAN);
public static final LiteralArg FALSE = new LiteralArg(0, ArgType.BOOLEAN);
private final long literal;
public LiteralArg(long value, ArgType type) {
......@@ -62,7 +65,11 @@ public final class LiteralArg extends InsnArg {
@Override
public String toString() {
try {
return "(" + TypeGen.literalToString(literal, getType()) + " " + typedVar + ")";
String value = TypeGen.literalToString(literal, getType());
if (getType().equals(ArgType.BOOLEAN) && (value.equals("true") || value.equals("false"))) {
return value;
}
return "(" + value + " " + typedVar + ")";
} catch (JadxRuntimeException ex) {
// can't convert literal to string
return "(" + literal + " " + typedVar + ")";
......
......@@ -3,6 +3,7 @@ package jadx.core.dex.regions;
import jadx.core.dex.instructions.IfNode;
import jadx.core.dex.instructions.IfOp;
import jadx.core.dex.instructions.args.InsnArg;
import jadx.core.dex.instructions.args.LiteralArg;
public final class Compare {
private final IfNode insn;
......@@ -20,11 +21,7 @@ public final class Compare {
}
public InsnArg getB() {
if (insn.isZeroCmp()) {
return InsnArg.lit(0, getA().getType());
} else {
return insn.getArg(1);
}
return insn.getArg(1);
}
public Compare invert() {
......@@ -32,6 +29,15 @@ public final class Compare {
return this;
}
/**
* Change 'a != false' to 'a == true'
*/
public void normalize() {
if (getOp() == IfOp.NE && getB().isLiteral() && getB().equals(LiteralArg.FALSE)) {
insn.changeCondition(IfOp.EQ, getA(), LiteralArg.TRUE);
}
}
@Override
public String toString() {
return getA() + " " + getOp().getSymbol() + " " + getB();
......
package jadx.core.dex.regions;
import jadx.core.dex.instructions.IfNode;
import jadx.core.dex.instructions.IfOp;
import jadx.core.dex.instructions.args.InsnArg;
import jadx.core.dex.instructions.args.LiteralArg;
import jadx.core.dex.instructions.args.RegisterArg;
import jadx.core.dex.nodes.BlockNode;
import jadx.core.utils.exceptions.JadxRuntimeException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.Iterator;
import java.util.LinkedList;
import java.util.List;
......@@ -52,7 +56,7 @@ public final class IfCondition {
private IfCondition(Compare compare) {
this.mode = Mode.COMPARE;
this.compare = compare;
this.args = null;
this.args = Collections.emptyList();
}
private IfCondition(Mode mode, List<IfCondition> args) {
......@@ -64,7 +68,11 @@ public final class IfCondition {
private IfCondition(IfCondition c) {
this.mode = c.mode;
this.compare = c.compare;
this.args = new ArrayList<IfCondition>(c.args);
if (c.mode == Mode.COMPARE) {
this.args = Collections.emptyList();
} else {
this.args = new ArrayList<IfCondition>(c.args);
}
}
public Mode getMode() {
......@@ -75,6 +83,14 @@ public final class IfCondition {
return args;
}
public IfCondition first() {
return args.get(0);
}
public IfCondition second() {
return args.get(1);
}
public void addArg(IfCondition c) {
args.add(c);
}
......@@ -87,23 +103,80 @@ public final class IfCondition {
return compare;
}
public IfCondition invert() {
public static IfCondition invert(IfCondition cond) {
Mode mode = cond.getMode();
switch (mode) {
case COMPARE:
return new IfCondition(compare.invert());
return new IfCondition(cond.getCompare().invert());
case NOT:
return new IfCondition(args.get(0));
return cond.first();
case AND:
case OR:
List<IfCondition> args = cond.getArgs();
List<IfCondition> newArgs = new ArrayList<IfCondition>(args.size());
for (IfCondition arg : args) {
newArgs.add(arg.invert());
newArgs.add(invert(arg));
}
return new IfCondition(mode == Mode.AND ? Mode.OR : Mode.AND, newArgs);
}
throw new JadxRuntimeException("Unknown mode for invert: " + mode);
}
public static IfCondition not(IfCondition cond) {
if (cond.getMode() == Mode.NOT) {
return cond.first();
}
return new IfCondition(Mode.NOT, Collections.singletonList(cond));
}
public static IfCondition simplify(IfCondition cond) {
if (cond.isCompare()) {
Compare c = cond.getCompare();
if (c.getOp() == IfOp.EQ && c.getB().isLiteral() && c.getB().equals(LiteralArg.FALSE)) {
return not(new IfCondition(c.invert()));
} else {
c.normalize();
}
return cond;
}
List<IfCondition> args = null;
for (int i = 0; i < cond.getArgs().size(); i++) {
IfCondition arg = cond.getArgs().get(i);
IfCondition simpl = simplify(arg);
if (simpl != arg) {
if (args == null) {
args = new ArrayList<IfCondition>(cond.getArgs());
}
args.set(i, simpl);
}
}
if (args != null) {
// arguments was changed
cond = new IfCondition(cond.getMode(), args);
}
if (cond.getMode() == Mode.NOT && cond.first().getMode() == Mode.NOT) {
cond = cond.first().first();
}
// for condition with a lot of negations => make invert
if (cond.getMode() == Mode.OR || cond.getMode() == Mode.AND) {
int count = cond.getArgs().size();
if (count > 1) {
int negCount = 0;
for (IfCondition arg : cond.getArgs()) {
if (arg.getMode() == Mode.NOT
|| (arg.isCompare() && arg.getCompare().getOp() == IfOp.NE)) {
negCount++;
}
}
if (negCount > count / 2) {
return not(invert(cond));
}
}
}
return cond;
}
public List<RegisterArg> getRegisterArgs() {
List<RegisterArg> list = new LinkedList<RegisterArg>();
if (mode == Mode.COMPARE) {
......@@ -129,11 +202,21 @@ public final class IfCondition {
case COMPARE:
return compare.toString();
case NOT:
return "!" + args;
return "!" + first();
case AND:
return "&& " + args;
case OR:
return "|| " + args;
String op = mode == Mode.OR ? " || " : " && ";
StringBuilder sb = new StringBuilder();
sb.append('(');
for (Iterator<IfCondition> it = args.iterator(); it.hasNext(); ) {
IfCondition arg = it.next();
sb.append(arg);
if (it.hasNext()) {
sb.append(op);
}
}
sb.append(')');
return sb.toString();
}
return "??";
}
......
......@@ -61,6 +61,23 @@ public final class IfRegion extends AbstractRegion {
return ternRegion;
}
public boolean simplifyCondition() {
IfCondition cond = IfCondition.simplify(condition);
if (cond != condition) {
condition = cond;
return true;
}
return false;
}
public void invert() {
condition = IfCondition.invert(condition);
// swap regions
IContainer tmp = thenRegion;
thenRegion = elseRegion;
elseRegion = tmp;
}
@Override
public List<IContainer> getSubBlocks() {
if (ternRegion != null) {
......
package jadx.core.dex.visitors;
import jadx.core.dex.info.FieldInfo;
import jadx.core.dex.instructions.IfNode;
import jadx.core.dex.instructions.IndexInsnNode;
import jadx.core.dex.instructions.InsnType;
import jadx.core.dex.instructions.InvokeNode;
......@@ -124,15 +123,12 @@ public class ConstInlinerVisitor extends AbstractVisitor {
}
case IF: {
IfNode ifnode = (IfNode) insn;
if (!ifnode.isZeroCmp()) {
InsnArg arg0 = insn.getArg(0);
InsnArg arg1 = insn.getArg(1);
if (arg0 == litArg) {
arg0.merge(arg1);
} else {
arg1.merge(arg0);
}
InsnArg arg0 = insn.getArg(0);
InsnArg arg1 = insn.getArg(1);
if (arg0 == litArg) {
arg0.merge(arg1);
} else {
arg1.merge(arg0);
}
break;
}
......
......@@ -96,9 +96,9 @@ public class SimplifyVisitor extends AbstractVisitor {
if (f.isInsnWrap()) {
InsnNode wi = ((InsnWrapArg) f).getWrapInsn();
if (wi.getType() == InsnType.CMP_L || wi.getType() == InsnType.CMP_G) {
if (ifb.isZeroCmp()
|| ((LiteralArg) ifb.getArg(1)).getLiteral() == 0) {
ifb.changeCondition(wi.getArg(0), wi.getArg(1), ifb.getOp());
if (ifb.getArg(1).isLiteral()
&& ((LiteralArg) ifb.getArg(1)).getLiteral() == 0) {
ifb.changeCondition(ifb.getOp(), wi.getArg(0), wi.getArg(1));
} else {
LOG.warn("TODO: cmp" + ifb);
}
......
......@@ -208,7 +208,7 @@ public class RegionMaker {
condBlock = mergedIf.getIfnode();
if (!loop.getLoopBlocks().contains(mergedIf.getThenBlock())) {
// invert loop condition if it points to exit
loopRegion.setCondition(mergedIf.getCondition().invert());
loopRegion.setCondition(IfCondition.invert(mergedIf.getCondition()));
bThen = mergedIf.getElseBlock();
} else {
loopRegion.setCondition(mergedIf.getCondition());
......@@ -303,7 +303,7 @@ public class RegionMaker {
}
}
if (bThen != loopBody) {
loopRegion.setCondition(loopRegion.getCondition().invert());
loopRegion.setCondition(IfCondition.invert(loopRegion.getCondition()));
}
out = selectOther(loopBody, condBlock.getSuccessors());
AttributesList outAttrs = out.getAttributes();
......
......@@ -59,21 +59,12 @@ public class RegionMakerVisitor extends AbstractVisitor {
CleanRegions.process(mth);
// mark if-else-if chains
DepthRegionTraverser.traverseAll(mth, new AbstractRegionVisitor() {
@Override
public void leaveRegion(MethodNode mth, IRegion region) {
if (region instanceof IfRegion) {
IfRegion ifregion = (IfRegion) region;
IContainer elsRegion = ifregion.getElseRegion();
if (elsRegion instanceof IfRegion) {
elsRegion.getAttributes().add(AttributeFlag.ELSE_IF_CHAIN);
} else if (elsRegion instanceof Region) {
List<IContainer> subBlocks = ((Region) elsRegion).getSubBlocks();
if (subBlocks.size() == 1 && subBlocks.get(0) instanceof IfRegion) {
subBlocks.get(0).getAttributes().add(AttributeFlag.ELSE_IF_CHAIN);
}
}
processIfRegion((IfRegion) region);
}
}
});
......@@ -83,4 +74,24 @@ public class RegionMakerVisitor extends AbstractVisitor {
DepthRegionTraverser.traverseAll(mth, new ProcessReturnInsns());
}
}
private static void processIfRegion(IfRegion ifRegion) {
if (ifRegion.simplifyCondition()) {
// IfCondition condition = ifRegion.getCondition();
// if (condition.getMode() == IfCondition.Mode.NOT) {
// ifRegion.invert();
// }
}
// mark if-else-if chains
IContainer elsRegion = ifRegion.getElseRegion();
if (elsRegion instanceof IfRegion) {
elsRegion.getAttributes().add(AttributeFlag.ELSE_IF_CHAIN);
} else if (elsRegion instanceof Region) {
List<IContainer> subBlocks = ((Region) elsRegion).getSubBlocks();
if (subBlocks.size() == 1 && subBlocks.get(0) instanceof IfRegion) {
subBlocks.get(0).getAttributes().add(AttributeFlag.ELSE_IF_CHAIN);
}
}
}
}
......@@ -102,7 +102,7 @@ public class TernaryVisitor extends AbstractRegionVisitor implements IDexTreeVis
IfCondition condition = ifRegion.getCondition();
if (inverted) {
condition = condition.invert();
condition = IfCondition.invert(condition);
InsnArg tmp = thenArg;
thenArg = elseArg;
elseArg = tmp;
......
package jadx.core.dex.visitors.typeresolver.finish;
import jadx.core.dex.info.MethodInfo;
import jadx.core.dex.instructions.IfNode;
import jadx.core.dex.instructions.InvokeNode;
import jadx.core.dex.instructions.args.ArgType;
import jadx.core.dex.instructions.args.InsnArg;
......@@ -56,14 +55,11 @@ public class PostTypeResolver {
case IF: {
boolean change = false;
IfNode ifnode = (IfNode) insn;
if (!ifnode.isZeroCmp()) {
if (insn.getArg(1).merge(insn.getArg(0))) {
change = true;
}
if (insn.getArg(0).merge(insn.getArg(1))) {
change = true;
}
if (insn.getArg(1).merge(insn.getArg(0))) {
change = true;
}
if (insn.getArg(0).merge(insn.getArg(1))) {
change = true;
}
return change;
}
......
package jadx.tests.functional;
import jadx.core.dex.instructions.IfNode;
import jadx.core.dex.instructions.IfOp;
import jadx.core.dex.instructions.args.ArgType;
import jadx.core.dex.instructions.args.InsnArg;
import jadx.core.dex.instructions.args.LiteralArg;
import jadx.core.dex.regions.Compare;
import jadx.core.dex.regions.IfCondition;
import org.junit.Test;
import static jadx.core.dex.regions.IfCondition.Mode;
import static jadx.core.dex.regions.IfCondition.merge;
import static jadx.core.dex.regions.IfCondition.not;
import static jadx.core.dex.regions.IfCondition.simplify;
import static org.junit.Assert.assertEquals;
public class TestIfCondition {
private static IfCondition makeCondition(IfOp op, InsnArg a, InsnArg b) {
return IfCondition.fromIfNode(new IfNode(op, -1, a, b));
}
private static IfCondition makeSimpleCondition() {
return makeCondition(IfOp.EQ, mockArg(), LiteralArg.TRUE);
}
private static IfCondition makeNegCondition() {
return makeCondition(IfOp.NE, mockArg(), LiteralArg.TRUE);
}
private static InsnArg mockArg() {
return InsnArg.reg(0, ArgType.INT);
}
@Test
public void testNormalize() {
// 'a != false' => 'a == true'
InsnArg a = mockArg();
IfCondition c = makeCondition(IfOp.NE, a, LiteralArg.FALSE);
IfCondition simp = simplify(c);
assertEquals(simp.getMode(), Mode.COMPARE);
Compare compare = simp.getCompare();
assertEquals(compare.getA(), a);
assertEquals(compare.getB(), LiteralArg.TRUE);
}
@Test
public void testMerge() {
IfCondition a = makeSimpleCondition();
IfCondition b = makeSimpleCondition();
IfCondition c = merge(Mode.OR, a, b);
assertEquals(c.getMode(), Mode.OR);
assertEquals(c.first(), a);
assertEquals(c.second(), b);
}
@Test
public void testSimplifyNot() {
// !(!a) => a
IfCondition a = not(not(makeSimpleCondition()));
assertEquals(simplify(a), a);
}
@Test
public void testSimplifyNot2() {
// !(!a) => a
IfCondition a = not(makeNegCondition());
assertEquals(simplify(a), a);
}
@Test
public void testSimplify() {
// '!(!a || !b)' => 'a && b'
IfCondition a = makeSimpleCondition();
IfCondition b = makeSimpleCondition();
IfCondition c = not(merge(Mode.OR, not(a), not(b)));
IfCondition simp = simplify(c);
assertEquals(simp.getMode(), Mode.AND);
assertEquals(simp.first(), a);
assertEquals(simp.second(), b);
}
@Test
public void testSimplify2() {
// '(!a || !b) && !c' => '!((a && b) || c)'
IfCondition a = makeSimpleCondition();
IfCondition b = makeSimpleCondition();
IfCondition c = makeSimpleCondition();
IfCondition cond = merge(Mode.AND, merge(Mode.OR, not(a), not(b)), not(c));
IfCondition simp = simplify(cond);
assertEquals(simp.getMode(), Mode.NOT);
IfCondition f = simp.first();
assertEquals(f.getMode(), Mode.OR);
assertEquals(f.first().getMode(), Mode.AND);
assertEquals(f.first().first(), a);
assertEquals(f.first().second(), b);
assertEquals(f.second(), c);
}
}
package jadx.tests.internal;
import jadx.api.InternalJadxTest;
import jadx.core.dex.nodes.ClassNode;
import org.junit.Test;
import static org.hamcrest.CoreMatchers.containsString;
import static org.hamcrest.CoreMatchers.not;
import static org.junit.Assert.assertThat;
public class TestConditions extends InternalJadxTest {
public static class TestCls {
private boolean f1(boolean a, boolean b, boolean c) {
return (a && b) || c;
}
}
@Test
public void test() {
ClassNode cls = getClassNode(TestCls.class);
String code = cls.getCode().toString();
System.out.println(code);
assertThat(code, not(containsString("(!a || !b) && !c")));
assertThat(code, containsString("(a && b) || c"));
// assertThat(code, containsString("return (a && b) || c;"));
}
}
......@@ -9,7 +9,7 @@ import static org.hamcrest.CoreMatchers.containsString;
import static org.hamcrest.CoreMatchers.not;
import static org.junit.Assert.assertThat;
public class TestTmp2 extends InternalJadxTest {
public class TestConditions2 extends InternalJadxTest {
public static class TestCls extends Exception {
int c;
......
......@@ -60,6 +60,10 @@ public class TestConditions extends AbstractTest {
return num > 5 && (num < 10 || num == 7);
}
private boolean test6(boolean a, boolean b, boolean c) {
return (a && b) || c;
}
public boolean accept(String name) {
return name.startsWith("Test") && name.endsWith(".class") && !name.contains("$");
}
......@@ -87,6 +91,10 @@ public class TestConditions extends AbstractTest {
assertTrue(test5(6));
assertTrue(test5(7));
assertTrue(test5(8));
assertTrue(test6(true, true, false));
assertTrue(test6(false, false, true));
assertFalse(test6(true, false, false));
return true;
}
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment