Commit 5a68d3be authored by Skylot's avatar Skylot

core: restore for-each loop over array

parent 195eeceb
......@@ -179,7 +179,7 @@ public class InsnGen {
mgen.getClassGen().useClass(code, cls);
}
private void useType(CodeWriter code, ArgType type) {
protected void useType(CodeWriter code, ArgType type) {
mgen.getClassGen().useType(code, type);
}
......@@ -200,7 +200,7 @@ public class InsnGen {
if (flag != Flags.INLINE) {
code.startLineWithNum(insn.getSourceLine());
}
if (insn.getResult() != null && insn.getType() != InsnType.ARITH_ONEARG) {
if (insn.getResult() != null && !insn.contains(AFlag.ARITH_ONEARG)) {
assignVar(code, insn);
code.add(" = ");
}
......@@ -257,10 +257,6 @@ public class InsnGen {
makeArith((ArithNode) insn, code, state);
break;
case ARITH_ONEARG:
makeArithOneArg((ArithNode) insn, code);
break;
case NEG: {
boolean wrap = state.contains(Flags.BODY_ONLY);
if (wrap) {
......@@ -761,6 +757,10 @@ public class InsnGen {
}
private void makeArith(ArithNode insn, CodeWriter code, EnumSet<Flags> state) throws CodegenException {
if (insn.contains(AFlag.ARITH_ONEARG)) {
makeArithOneArg(insn, code);
return;
}
// wrap insn in brackets for save correct operation order
boolean wrap = state.contains(Flags.BODY_ONLY) && !insn.contains(AFlag.DONT_WRAP);
if (wrap) {
......@@ -778,7 +778,7 @@ public class InsnGen {
private void makeArithOneArg(ArithNode insn, CodeWriter code) throws CodegenException {
ArithOp op = insn.getOp();
InsnArg arg = insn.getArg(0);
InsnArg arg = insn.getArg(1);
// "++" or "--"
if (arg.isLiteral() && (op == ArithOp.ADD || op == ArithOp.SUB)) {
LiteralArg lit = (LiteralArg) arg;
......
......@@ -19,6 +19,7 @@ import jadx.core.dex.regions.SwitchRegion;
import jadx.core.dex.regions.SynchronizedRegion;
import jadx.core.dex.regions.conditions.IfCondition;
import jadx.core.dex.regions.conditions.IfRegion;
import jadx.core.dex.regions.loops.ForEachLoop;
import jadx.core.dex.regions.loops.IndexLoop;
import jadx.core.dex.regions.loops.LoopRegion;
import jadx.core.dex.regions.loops.LoopType;
......@@ -187,6 +188,17 @@ public class RegionGen extends InsnGen {
code.startLine('}');
return code;
}
if (type instanceof ForEachLoop) {
ForEachLoop forEachLoop = (ForEachLoop) type;
code.startLine("for (");
declareVar(code, forEachLoop.getVarArg());
code.add(" : ");
addArg(code, forEachLoop.getIterableArg(), false);
code.add(") {");
makeRegionIndent(code, region.getBody());
code.startLine('}');
return code;
}
throw new JadxRuntimeException("Unknown loop type: " + type.getClass());
}
if (region.isConditionAtEnd()) {
......
......@@ -24,6 +24,7 @@ public enum AFlag {
ELSE_IF_CHAIN,
WRAPPED,
ARITH_ONEARG,
INCONSISTENT_CODE, // warning about incorrect decompilation
}
package jadx.core.dex.instructions;
import jadx.core.dex.attributes.AFlag;
import jadx.core.dex.instructions.args.ArgType;
import jadx.core.dex.instructions.args.InsnArg;
import jadx.core.dex.instructions.args.RegisterArg;
......@@ -51,10 +52,8 @@ public class ArithNode extends InsnNode {
}
public ArithNode(ArithOp op, RegisterArg res, InsnArg a) {
super(InsnType.ARITH_ONEARG, 1);
this.op = op;
setResult(res);
addArg(a);
this(op, res, res, a);
add(AFlag.ARITH_ONEARG);
}
public ArithOp getOp() {
......@@ -85,7 +84,7 @@ public class ArithNode extends InsnNode {
+ getResult() + " = "
+ getArg(0) + " "
+ op.getSymbol() + " "
+ (getArgsCount() == 2 ? getArg(1) : "");
+ getArg(1);
}
}
......@@ -54,7 +54,6 @@ public enum InsnType {
CONTINUE,
STR_CONCAT, // strings concatenation
ARITH_ONEARG,
TERNARY,
ARGS, // just generate arguments
......
......@@ -35,6 +35,12 @@ public class InsnNode extends LineAttrNode {
}
}
public static InsnNode wrapArg(InsnArg arg) {
InsnNode insn = new InsnNode(InsnType.ARGS, 1);
insn.addArg(arg);
return insn;
}
public void setResult(RegisterArg res) {
if (res != null) {
res.setParentInsn(this);
......
package jadx.core.dex.regions.loops;
import jadx.core.dex.instructions.args.InsnArg;
import jadx.core.dex.instructions.args.RegisterArg;
public class ForEachLoop extends LoopType {
private final RegisterArg varArg;
private final InsnArg iterableArg;
public ForEachLoop(RegisterArg varArg, InsnArg iterableArg) {
this.varArg = varArg;
this.iterableArg = iterableArg;
}
public RegisterArg getVarArg() {
return varArg;
}
public InsnArg getIterableArg() {
return iterableArg;
}
}
......@@ -55,7 +55,7 @@ public class ConstInlinerVisitor extends AbstractVisitor {
if (parentInsn != null) {
// TODO: speed up expensive operations
BlockNode useBlock = BlockUtils.getBlockByInsn(mth, parentInsn);
if (!BlockUtils.isCleanPathExists(block, useBlock)) {
if (useBlock == null || !BlockUtils.isCleanPathExists(block, useBlock)) {
return false;
}
}
......
......@@ -3,7 +3,6 @@ package jadx.core.dex.visitors;
import jadx.core.dex.attributes.AFlag;
import jadx.core.dex.attributes.nodes.MethodInlineAttr;
import jadx.core.dex.info.AccessInfo;
import jadx.core.dex.instructions.InsnType;
import jadx.core.dex.nodes.BlockNode;
import jadx.core.dex.nodes.InsnNode;
import jadx.core.dex.nodes.MethodNode;
......@@ -34,10 +33,8 @@ public class MethodInlineVisitor extends AbstractVisitor {
// synthetic field getter
BlockNode block = mth.getBasicBlocks().get(1);
InsnNode insn = block.getInstructions().get(0);
InsnNode inl = new InsnNode(InsnType.ARGS, 1);
// set arg from 'return' instruction
inl.addArg(insn.getArg(0));
addInlineAttr(mth, inl);
addInlineAttr(mth, InsnNode.wrapArg(insn.getArg(0)));
} else {
// synthetic field setter or method invoke
if (firstBlock.getInstructions().size() == 1) {
......
......@@ -132,7 +132,6 @@ public class PrepareForCodeGen extends AbstractVisitor {
RegisterArg res = arith.getResult();
InsnArg arg = arith.getArg(0);
boolean replace = false;
if (res.equals(arg)) {
replace = true;
} else if (arg.isRegister()) {
......@@ -140,9 +139,7 @@ public class PrepareForCodeGen extends AbstractVisitor {
replace = res.equalRegisterAndType(regArg);
}
if (replace) {
ArithNode newArith = new ArithNode(arith.getOp(), res, arith.getArg(1));
InsnArg.updateParentInsn(arith, newArith);
list.set(i, newArith);
arith.add(AFlag.ARITH_ONEARG);
}
}
}
......
package jadx.core.dex.visitors.regions;
import jadx.core.dex.attributes.AFlag;
import jadx.core.dex.instructions.ArithNode;
import jadx.core.dex.instructions.ArithOp;
import jadx.core.dex.instructions.IfOp;
import jadx.core.dex.instructions.InsnType;
import jadx.core.dex.instructions.PhiInsn;
import jadx.core.dex.instructions.args.InsnArg;
import jadx.core.dex.instructions.args.InsnWrapArg;
import jadx.core.dex.instructions.args.LiteralArg;
import jadx.core.dex.instructions.args.RegisterArg;
import jadx.core.dex.instructions.args.SSAVar;
import jadx.core.dex.nodes.BlockNode;
import jadx.core.dex.nodes.IBlock;
import jadx.core.dex.nodes.IRegion;
import jadx.core.dex.nodes.InsnNode;
import jadx.core.dex.nodes.MethodNode;
import jadx.core.dex.regions.conditions.Compare;
import jadx.core.dex.regions.conditions.IfCondition;
import jadx.core.dex.regions.loops.ForEachLoop;
import jadx.core.dex.regions.loops.IndexLoop;
import jadx.core.dex.regions.loops.LoopRegion;
import jadx.core.dex.regions.loops.LoopType;
import jadx.core.dex.visitors.AbstractVisitor;
import jadx.core.dex.visitors.CodeShrinker;
import jadx.core.utils.BlockUtils;
import jadx.core.utils.InstructionRemover;
import jadx.core.utils.RegionUtils;
import java.util.List;
......@@ -43,8 +56,7 @@ public class LoopRegionVisitor extends AbstractVisitor implements IRegionVisitor
if (condition == null) {
return;
}
List<RegisterArg> args = condition.getRegisterArgs();
if (checkForIndexedLoop(mth, loopRegion, args)) {
if (checkForIndexedLoop(mth, loopRegion, condition)) {
return;
}
}
......@@ -52,7 +64,7 @@ public class LoopRegionVisitor extends AbstractVisitor implements IRegionVisitor
/**
* Check for indexed loop.
*/
private static boolean checkForIndexedLoop(MethodNode mth, LoopRegion loopRegion, List<RegisterArg> condArgs) {
private static boolean checkForIndexedLoop(MethodNode mth, LoopRegion loopRegion, IfCondition condition) {
InsnNode incrInsn = RegionUtils.getLastInsn(loopRegion);
if (incrInsn == null) {
return false;
......@@ -70,6 +82,7 @@ public class LoopRegionVisitor extends AbstractVisitor implements IRegionVisitor
return false;
}
RegisterArg arg = phiInsn.getResult();
List<RegisterArg> condArgs = condition.getRegisterArgs();
if (!condArgs.contains(arg) || arg.getSVar().isUsedInPhi()) {
return false;
}
......@@ -81,12 +94,96 @@ public class LoopRegionVisitor extends AbstractVisitor implements IRegionVisitor
if (!usedOnlyInLoop(mth, loopRegion, arg)) {
return false;
}
// all checks passed
initInsn.add(AFlag.SKIP);
incrInsn.add(AFlag.SKIP);
LoopType arrForEach = checkArrayForEach(mth, initInsn, incrInsn, condition);
if (arrForEach != null) {
loopRegion.setType(arrForEach);
return true;
}
loopRegion.setType(new IndexLoop(initInsn, incrInsn));
return true;
}
private static LoopType checkArrayForEach(MethodNode mth, InsnNode initInsn, InsnNode incrInsn, IfCondition condition) {
if (!(incrInsn instanceof ArithNode)) {
return null;
}
ArithNode arithNode = (ArithNode) incrInsn;
if (arithNode.getOp() != ArithOp.ADD) {
return null;
}
InsnArg lit = incrInsn.getArg(1);
if (!lit.isLiteral() || ((LiteralArg) lit).getLiteral() != 1) {
return null;
}
if (initInsn.getType() != InsnType.CONST
|| !initInsn.getArg(0).isLiteral()
|| ((LiteralArg) initInsn.getArg(0)).getLiteral() != 0) {
return null;
}
InsnArg condArg = incrInsn.getArg(0);
if (!condArg.isRegister()) {
return null;
}
SSAVar sVar = ((RegisterArg) condArg).getSVar();
List<RegisterArg> args = sVar.getUseList();
if (args.size() != 3 || args.get(2) != condArg) {
return null;
}
condArg = args.get(0);
RegisterArg arrIndex = args.get(1);
InsnNode arrGetInsn = arrIndex.getParentInsn();
if (arrGetInsn == null || arrGetInsn.getType() != InsnType.AGET) {
return null;
}
if (!condition.isCompare()) {
return null;
}
Compare compare = condition.getCompare();
if (compare.getOp() != IfOp.LT || compare.getA() != condArg) {
return null;
}
InsnNode len;
InsnArg bCondArg = compare.getB();
if (bCondArg.isInsnWrap()) {
len = ((InsnWrapArg) bCondArg).getWrapInsn();
} else if (bCondArg.isRegister()) {
len = ((RegisterArg) bCondArg).getAssignInsn();
} else {
return null;
}
if (len == null || len.getType() != InsnType.ARRAY_LENGTH) {
return null;
}
InsnArg arrayArg = len.getArg(0);
if (!arrayArg.equals(arrGetInsn.getArg(0))) {
return null;
}
// array for each loop confirmed
len.add(AFlag.SKIP);
arrGetInsn.add(AFlag.SKIP);
InstructionRemover.unbindInsn(mth, len);
// inline array variable
CodeShrinker.shrinkMethod(mth);
RegisterArg iterVar = arrGetInsn.getResult();
if (arrGetInsn.contains(AFlag.WRAPPED)) {
InsnArg wrapArg = BlockUtils.searchWrappedInsnParent(mth, arrGetInsn);
if (wrapArg != null) {
wrapArg.getParentInsn().replaceArg(wrapArg, iterVar);
} else {
LOG.debug(" Wrapped insn not found: {}, mth: {}", arrGetInsn, mth);
}
}
return new ForEachLoop(iterVar, len.getArg(0));
}
private static boolean usedOnlyInLoop(MethodNode mth, LoopRegion loopRegion, RegisterArg arg) {
List<RegisterArg> useList = arg.getSVar().getUseList();
for (RegisterArg useArg : useList) {
......
......@@ -78,6 +78,10 @@ public class ProcessTryCatchRegions extends AbstractRegionVisitor {
}
}
}
if (bs == null) {
LOG.debug(" Can't build try/catch dominators bitset, tb: {}, mth: {} ", tb, mth);
continue;
}
// intersect to get dominator of dominators
List<BlockNode> domBlocks = BlockUtils.bitSetToBlocks(mth, bs);
......
......@@ -166,7 +166,11 @@ public class TernaryMod {
if (!arg.isRegister()) {
continue;
}
int sourceLine = ((RegisterArg) arg).getAssignInsn().getSourceLine();
InsnNode assignInsn = ((RegisterArg) arg).getAssignInsn();
if (assignInsn == null) {
continue;
}
int sourceLine = assignInsn.getSourceLine();
if (sourceLine != 0) {
Integer count = map.get(sourceLine);
if (count != null) {
......
......@@ -130,7 +130,7 @@ public class BlockUtils {
private static BlockNode getBlockByWrappedInsn(MethodNode mth, InsnNode insn) {
for (BlockNode bn : mth.getBasicBlocks()) {
for (InsnNode bi : bn.getInstructions()) {
if (bi == insn || foundWrappedInsn(bi, insn)) {
if (bi == insn || foundWrappedInsn(bi, insn) != null) {
return bn;
}
}
......@@ -138,16 +138,35 @@ public class BlockUtils {
return null;
}
private static boolean foundWrappedInsn(InsnNode container, InsnNode insn) {
public static InsnArg searchWrappedInsnParent(MethodNode mth, InsnNode insn) {
if (!insn.contains(AFlag.WRAPPED)) {
return null;
}
for (BlockNode bn : mth.getBasicBlocks()) {
for (InsnNode bi : bn.getInstructions()) {
InsnArg res = foundWrappedInsn(bi, insn);
if (res != null) {
return res;
}
}
}
return null;
}
private static InsnArg foundWrappedInsn(InsnNode container, InsnNode insn) {
for (InsnArg arg : container.getArguments()) {
if (arg.isInsnWrap()) {
InsnNode wrapInsn = ((InsnWrapArg) arg).getWrapInsn();
if (wrapInsn == insn || foundWrappedInsn(wrapInsn, insn)) {
return true;
if (wrapInsn == insn) {
return arg;
}
InsnArg res = foundWrappedInsn(wrapInsn, insn);
if (res != null) {
return res;
}
}
return false;
}
return null;
}
public static BitSet blocksToBitSet(MethodNode mth, List<BlockNode> blocks) {
......
......@@ -14,7 +14,7 @@ public class TestInline2 extends InternalJadxTest {
public int test() throws InterruptedException {
int[] a = new int[]{1, 2, 4, 6, 8};
int b = 0;
for (int i = 0; i < a.length; i++) {
for (int i = 0; i < a.length; i+=2) {
b += a[i];
}
for (long i = b; i > 0; i--) {
......@@ -30,7 +30,8 @@ public class TestInline2 extends InternalJadxTest {
String code = cls.getCode().toString();
System.out.println(code);
assertThat(code, containsOne("for (int i = 0; i < a.length; i++) {"));
assertThat(code, containsOne("int[] a = new int[]{1, 2, 4, 6, 8};"));
assertThat(code, containsOne("for (int i = 0; i < a.length; i += 2) {"));
assertThat(code, containsOne("for (long i2 = (long) b; i2 > 0; i2--) {"));
}
}
package jadx.tests.internal.loops;
import jadx.api.InternalJadxTest;
import jadx.core.dex.nodes.ClassNode;
import org.junit.Test;
import static jadx.tests.utils.JadxMatchers.containsLines;
import static org.junit.Assert.assertThat;
public class TestArrayForEach extends InternalJadxTest {
public static class TestCls {
private int test(int[] a) {
int sum = 0;
for (int n : a) {
sum += n;
}
return sum;
}
}
@Test
public void test() {
ClassNode cls = getClassNode(TestCls.class);
String code = cls.getCode().toString();
System.out.println(code);
assertThat(code, containsLines(2,
"int sum = 0;",
"for (int n : a) {",
indent(1) + "sum += n;",
"}",
"return sum;"
));
}
}
package jadx.tests.internal.loops;
import jadx.api.InternalJadxTest;
import jadx.core.dex.nodes.ClassNode;
import org.junit.Test;
import static jadx.tests.utils.JadxMatchers.containsLines;
import static org.junit.Assert.assertThat;
public class TestArrayForEach2 extends InternalJadxTest {
public static class TestCls {
private void test(String str) {
for (String s : str.split("\n")) {
String t = s.trim();
if (t.length() > 0) {
System.out.println(t);
}
}
}
}
@Test
public void test() {
ClassNode cls = getClassNode(TestCls.class);
String code = cls.getCode().toString();
System.out.println(code);
assertThat(code, containsLines(2,
"for (String s : str.split(\"\\n\")) {",
indent(1) + "String t = s.trim();",
indent(1) + "if (t.length() > 0) {",
indent(2) + "System.out.println(t);",
indent(1) + "}",
"}"
));
}
}
package jadx.tests.internal.loops;
import jadx.api.InternalJadxTest;
import jadx.core.dex.nodes.ClassNode;
import org.junit.Test;
import static org.hamcrest.CoreMatchers.containsString;
import static org.hamcrest.CoreMatchers.not;
import static org.junit.Assert.assertThat;
public class TestArrayForEachNegative extends InternalJadxTest {
public static class TestCls {
private int test(int[] a, int[] b) {
int sum = 0;
for (int i = 0; i < a.length; i += 2) {
sum += a[i];
}
for (int i = 1; i < a.length; i++) {
sum += a[i];
}
for (int i = 0; i < a.length; i--) {
sum += a[i];
}
for (int i = 0; i <= a.length; i++) {
sum += a[i];
}
for (int i = 0; i + 1 < a.length; i++) {
sum += a[i];
}
for (int i = 0; i < a.length; i++) {
sum += a[i - 1];
}
for (int i = 0; i < b.length; i++) {
sum += a[i];
}
int j = 0;
for (int i = 0; i < a.length; j++) {
sum += a[j];
}
return sum;
}
}
@Test
public void test() {
ClassNode cls = getClassNode(TestCls.class);
String code = cls.getCode().toString();
System.out.println(code);
assertThat(code, not(containsString(":")));
}
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment