Commit 010ae99c authored by Skylot's avatar Skylot

core: restore simple for-each loop over iterable object

parent a4632d6e
......@@ -20,7 +20,7 @@ import jadx.core.dex.regions.SynchronizedRegion;
import jadx.core.dex.regions.conditions.IfCondition;
import jadx.core.dex.regions.conditions.IfRegion;
import jadx.core.dex.regions.loops.ForEachLoop;
import jadx.core.dex.regions.loops.IndexLoop;
import jadx.core.dex.regions.loops.ForLoop;
import jadx.core.dex.regions.loops.LoopRegion;
import jadx.core.dex.regions.loops.LoopType;
import jadx.core.dex.trycatch.CatchAttr;
......@@ -175,14 +175,14 @@ public class RegionGen extends InsnGen {
ConditionGen conditionGen = new ConditionGen(this);
LoopType type = region.getType();
if (type != null) {
if (type instanceof IndexLoop) {
IndexLoop indexLoop = (IndexLoop) type;
if (type instanceof ForLoop) {
ForLoop forLoop = (ForLoop) type;
code.startLine("for (");
makeInsn(indexLoop.getInitInsn(), code, Flags.INLINE);
makeInsn(forLoop.getInitInsn(), code, Flags.INLINE);
code.add("; ");
conditionGen.add(code, condition);
code.add("; ");
makeInsn(indexLoop.getIncrInsn(), code, Flags.INLINE);
makeInsn(forLoop.getIncrInsn(), code, Flags.INLINE);
code.add(") {");
makeRegionIndent(code, region.getBody());
code.startLine('}');
......
......@@ -543,6 +543,16 @@ public abstract class ArgType {
return true;
}
public static boolean isInstanceOf(ArgType type, ArgType of) {
if (type.equals(of)) {
return true;
}
if (!type.isObject() || !of.isObject()) {
return false;
}
return clsp.isImplements(type.getObject(), of.getObject());
}
public static ArgType parse(String type) {
char f = type.charAt(0);
switch (f) {
......
......@@ -9,7 +9,9 @@ import jadx.core.dex.regions.conditions.IfCondition;
import jadx.core.utils.InsnUtils;
import jadx.core.utils.Utils;
public class TernaryInsn extends InsnNode {
import java.util.List;
public final class TernaryInsn extends InsnNode {
private IfCondition condition;
......@@ -52,6 +54,12 @@ public class TernaryInsn extends InsnNode {
}
@Override
public void getRegisterArgs(List<RegisterArg> list) {
super.getRegisterArgs(list);
list.addAll(condition.getRegisterArgs());
}
@Override
public boolean equals(Object obj) {
if (this == obj) {
return true;
......
......@@ -3,7 +3,6 @@ package jadx.core.dex.regions.conditions;
import jadx.core.dex.instructions.IfNode;
import jadx.core.dex.instructions.IfOp;
import jadx.core.dex.instructions.InsnType;
import jadx.core.dex.instructions.args.InsnArg;
import jadx.core.dex.instructions.args.InsnWrapArg;
import jadx.core.dex.instructions.args.LiteralArg;
import jadx.core.dex.instructions.args.RegisterArg;
......@@ -209,14 +208,7 @@ public final class IfCondition {
public List<RegisterArg> getRegisterArgs() {
List<RegisterArg> list = new LinkedList<RegisterArg>();
if (mode == Mode.COMPARE) {
InsnArg a = compare.getA();
if (a.isRegister()) {
list.add((RegisterArg) a);
}
InsnArg b = compare.getB();
if (b.isRegister()) {
list.add((RegisterArg) b);
}
compare.getInsn().getRegisterArgs(list);
} else {
for (IfCondition arg : args) {
list.addAll(arg.getRegisterArgs());
......
......@@ -3,7 +3,7 @@ package jadx.core.dex.regions.loops;
import jadx.core.dex.instructions.args.InsnArg;
import jadx.core.dex.instructions.args.RegisterArg;
public class ForEachLoop extends LoopType {
public final class ForEachLoop extends LoopType {
private final RegisterArg varArg;
private final InsnArg iterableArg;
......
......@@ -2,12 +2,12 @@ package jadx.core.dex.regions.loops;
import jadx.core.dex.nodes.InsnNode;
public class IndexLoop extends LoopType {
public final class ForLoop extends LoopType {
private final InsnNode initInsn;
private final InsnNode incrInsn;
public IndexLoop(InsnNode initInsn, InsnNode incrInsn) {
public ForLoop(InsnNode initInsn, InsnNode incrInsn) {
this.initInsn = initInsn;
this.incrInsn = incrInsn;
}
......
......@@ -123,23 +123,20 @@ public class PrepareForCodeGen extends AbstractVisitor {
*/
private static void modifyArith(BlockNode block) {
List<InsnNode> list = block.getInstructions();
for (int i = 0; i < list.size(); i++) {
InsnNode insn = list.get(i);
if (insn.getType() != InsnType.ARITH) {
continue;
}
ArithNode arith = (ArithNode) insn;
RegisterArg res = arith.getResult();
InsnArg arg = arith.getArg(0);
boolean replace = false;
if (res.equals(arg)) {
replace = true;
} else if (arg.isRegister()) {
RegisterArg regArg = (RegisterArg) arg;
replace = res.equalRegisterAndType(regArg);
}
if (replace) {
arith.add(AFlag.ARITH_ONEARG);
for (InsnNode insn : list) {
if (insn.getType() == InsnType.ARITH) {
RegisterArg res = insn.getResult();
InsnArg arg = insn.getArg(0);
boolean replace = false;
if (res.equals(arg)) {
replace = true;
} else if (arg.isRegister()) {
RegisterArg regArg = (RegisterArg) arg;
replace = res.equalRegisterAndType(regArg);
}
if (replace) {
insn.add(AFlag.ARITH_ONEARG);
}
}
}
}
......
package jadx.core.dex.visitors.regions;
import jadx.core.dex.attributes.AFlag;
import jadx.core.dex.info.MethodInfo;
import jadx.core.dex.instructions.ArithNode;
import jadx.core.dex.instructions.ArithOp;
import jadx.core.dex.instructions.IfOp;
import jadx.core.dex.instructions.InsnType;
import jadx.core.dex.instructions.InvokeNode;
import jadx.core.dex.instructions.InvokeType;
import jadx.core.dex.instructions.PhiInsn;
import jadx.core.dex.instructions.args.ArgType;
import jadx.core.dex.instructions.args.InsnArg;
import jadx.core.dex.instructions.args.InsnWrapArg;
import jadx.core.dex.instructions.args.LiteralArg;
......@@ -19,7 +23,7 @@ import jadx.core.dex.nodes.MethodNode;
import jadx.core.dex.regions.conditions.Compare;
import jadx.core.dex.regions.conditions.IfCondition;
import jadx.core.dex.regions.loops.ForEachLoop;
import jadx.core.dex.regions.loops.IndexLoop;
import jadx.core.dex.regions.loops.ForLoop;
import jadx.core.dex.regions.loops.LoopRegion;
import jadx.core.dex.regions.loops.LoopType;
import jadx.core.dex.visitors.AbstractVisitor;
......@@ -28,6 +32,7 @@ import jadx.core.utils.BlockUtils;
import jadx.core.utils.InstructionRemover;
import jadx.core.utils.RegionUtils;
import java.util.LinkedList;
import java.util.List;
import org.slf4j.Logger;
......@@ -59,6 +64,9 @@ public class LoopRegionVisitor extends AbstractVisitor implements IRegionVisitor
if (checkForIndexedLoop(mth, loopRegion, condition)) {
return;
}
if (checkIterableForEach(mth, loopRegion, condition)) {
return;
}
}
/**
......@@ -103,7 +111,7 @@ public class LoopRegionVisitor extends AbstractVisitor implements IRegionVisitor
loopRegion.setType(arrForEach);
return true;
}
loopRegion.setType(new IndexLoop(initInsn, incrInsn));
loopRegion.setType(new ForLoop(initInsn, incrInsn));
return true;
}
......@@ -184,6 +192,107 @@ public class LoopRegionVisitor extends AbstractVisitor implements IRegionVisitor
return new ForEachLoop(iterVar, len.getArg(0));
}
private static boolean checkIterableForEach(MethodNode mth, LoopRegion loopRegion, IfCondition condition) {
List<RegisterArg> condArgs = condition.getRegisterArgs();
if (condArgs.size() != 1) {
return false;
}
RegisterArg iteratorArg = condArgs.get(0);
SSAVar sVar = iteratorArg.getSVar();
if (sVar == null || sVar.isUsedInPhi()) {
return false;
}
List<RegisterArg> useList = sVar.getUseList();
InsnNode assignInsn = iteratorArg.getAssignInsn();
if (useList.size() != 2
|| assignInsn == null
|| !checkInvoke(assignInsn, null, "iterator()Ljava/util/Iterator;", 0)) {
return false;
}
InsnArg iterableArg = assignInsn.getArg(0);
InsnNode hasNextCall = useList.get(0).getParentInsn();
InsnNode nextCall = useList.get(1).getParentInsn();
if (!checkInvoke(hasNextCall, "java.util.Iterator", "hasNext()Z", 0)
|| !checkInvoke(nextCall, "java.util.Iterator", "next()Ljava/lang/Object;", 0)) {
return false;
}
List<InsnNode> toSkip = new LinkedList<InsnNode>();
RegisterArg iterVar = nextCall.getResult();
if (nextCall.contains(AFlag.WRAPPED)) {
InsnArg wrapArg = BlockUtils.searchWrappedInsnParent(mth, nextCall);
if (wrapArg != null) {
InsnNode parentInsn = wrapArg.getParentInsn();
if (parentInsn.getType() != InsnType.CHECK_CAST) {
parentInsn.replaceArg(wrapArg, iterVar);
} else {
iterVar = parentInsn.getResult();
InsnArg castArg = BlockUtils.searchWrappedInsnParent(mth, parentInsn);
if (castArg != null) {
castArg.getParentInsn().replaceArg(castArg, iterVar);
} else {
// cast not inlined
toSkip.add(parentInsn);
}
}
} else {
LOG.warn(" Wrapped insn not found: {}, mth: {}", nextCall, mth);
return false;
}
} else {
toSkip.add(nextCall);
}
if (!fixIterableType(iterableArg, iterVar)) {
return false;
}
assignInsn.add(AFlag.SKIP);
for (InsnNode insnNode : toSkip) {
insnNode.add(AFlag.SKIP);
}
loopRegion.setType(new ForEachLoop(iterVar, iterableArg));
return true;
}
private static boolean fixIterableType(InsnArg iterableArg, RegisterArg iterVar) {
ArgType type = iterableArg.getType();
if (type.isGeneric()) {
ArgType[] genericTypes = type.getGenericTypes();
if (genericTypes != null && genericTypes.length == 1) {
ArgType gType = genericTypes[0];
if (ArgType.isInstanceOf(gType, iterVar.getType())) {
return true;
} else {
LOG.warn("Generic type differs: {} and {}", type, iterVar.getType());
}
}
} else {
if (!iterableArg.isRegister()) {
return true;
}
// TODO: add checks
type = ArgType.generic(type.getObject(), new ArgType[]{iterVar.getType()});
iterableArg.setType(type);
return true;
}
return false;
}
/**
* Check if instruction is a interface invoke with corresponding parameters.
*/
private static boolean checkInvoke(InsnNode insn, String declClsFullName, String mthId, int argsCount) {
if (insn.getType() == InsnType.INVOKE) {
InvokeNode inv = (InvokeNode) insn;
MethodInfo callMth = inv.getCallMth();
if (callMth.getArgsCount() == argsCount
&& callMth.getShortId().equals(mthId)
&& inv.getInvokeType() == InvokeType.INTERFACE) {
return declClsFullName == null || callMth.getDeclClass().getFullName().equals(declClsFullName);
}
}
return false;
}
private static boolean usedOnlyInLoop(MethodNode mth, LoopRegion loopRegion, RegisterArg arg) {
List<RegisterArg> useList = arg.getSVar().getUseList();
for (RegisterArg useArg : useList) {
......
......@@ -2,12 +2,15 @@ package jadx.core.utils;
import jadx.core.dex.attributes.AFlag;
import jadx.core.dex.attributes.AType;
import jadx.core.dex.instructions.IfNode;
import jadx.core.dex.instructions.InsnType;
import jadx.core.dex.instructions.args.InsnArg;
import jadx.core.dex.instructions.args.InsnWrapArg;
import jadx.core.dex.instructions.mods.TernaryInsn;
import jadx.core.dex.nodes.BlockNode;
import jadx.core.dex.nodes.InsnNode;
import jadx.core.dex.nodes.MethodNode;
import jadx.core.dex.regions.conditions.IfCondition;
import jadx.core.utils.exceptions.JadxRuntimeException;
import java.util.ArrayList;
......@@ -138,6 +141,14 @@ public class BlockUtils {
return null;
}
public static InsnNode searchInsnParent(MethodNode mth, InsnNode insn) {
InsnArg insnArg = searchWrappedInsnParent(mth, insn);
if (insnArg == null) {
return null;
}
return insnArg.getParentInsn();
}
public static InsnArg searchWrappedInsnParent(MethodNode mth, InsnNode insn) {
if (!insn.contains(AFlag.WRAPPED)) {
return null;
......@@ -166,6 +177,23 @@ public class BlockUtils {
}
}
}
if (container instanceof TernaryInsn) {
return foundWrappedInsnInCondition(((TernaryInsn) container).getCondition(), insn);
}
return null;
}
private static InsnArg foundWrappedInsnInCondition(IfCondition cond, InsnNode insn) {
if (cond.isCompare()) {
IfNode cmpInsn = cond.getCompare().getInsn();
return foundWrappedInsn(cmpInsn, insn);
}
for (IfCondition nestedCond : cond.getArgs()) {
InsnArg res = foundWrappedInsnInCondition(nestedCond, insn);
if (res != null) {
return res;
}
}
return null;
}
......
package jadx.tests.internal.loops;
import jadx.api.InternalJadxTest;
import jadx.core.dex.nodes.ClassNode;
import org.junit.Test;
import static jadx.tests.utils.JadxMatchers.containsLines;
import static org.junit.Assert.assertThat;
public class TestIterableForEach extends InternalJadxTest {
public static class TestCls {
private String test(Iterable<String> a) {
StringBuilder sb = new StringBuilder();
for (String s : a) {
sb.append(s);
}
return sb.toString();
}
}
@Test
public void test() {
ClassNode cls = getClassNode(TestCls.class);
String code = cls.getCode().toString();
System.out.println(code);
assertThat(code, containsLines(2,
"StringBuilder sb = new StringBuilder();",
"for (String s : a) {",
indent(1) + "sb.append(s);",
"}",
"return sb.toString();"
));
}
}
package jadx.tests.internal.loops;
import jadx.api.InternalJadxTest;
import jadx.core.dex.nodes.ClassNode;
import java.io.IOException;
import java.util.List;
import org.junit.Test;
import static jadx.tests.utils.JadxMatchers.containsOne;
import static org.junit.Assert.assertThat;
public class TestIterableForEach2 extends InternalJadxTest {
public static class TestCls {
public static String test(final Service service) throws IOException {
for (Authorization auth : service.getAuthorizations()) {
if (isValid(auth)) {
return auth.getToken();
}
}
return null;
}
private static boolean isValid(Authorization auth) {
return false;
}
private static class Service {
public List<Authorization> getAuthorizations() {
return null;
}
}
private static class Authorization {
public String getToken() {
return "";
}
}
}
@Test
public void test() {
ClassNode cls = getClassNode(TestCls.class);
String code = cls.getCode().toString();
System.out.println(code);
assertThat(code, containsOne("for (Authorization auth : service.getAuthorizations()) {"));
assertThat(code, containsOne("if (isValid(auth)) {"));
assertThat(code, containsOne("return auth.getToken();"));
}
}
......@@ -3,12 +3,11 @@ package jadx.tests.internal.loops;
import jadx.api.InternalJadxTest;
import jadx.core.dex.nodes.ClassNode;
import java.util.Iterator;
import java.util.List;
import org.junit.Test;
import static org.hamcrest.CoreMatchers.containsString;
import static jadx.tests.utils.JadxMatchers.containsOne;
import static org.junit.Assert.assertThat;
public class TestNestedLoops extends InternalJadxTest {
......@@ -16,12 +15,8 @@ public class TestNestedLoops extends InternalJadxTest {
public static class TestCls {
private void test(List<String> l1, List<String> l2) {
Iterator<String> it1 = l1.iterator();
while (it1.hasNext()) {
String s1 = it1.next();
Iterator<String> it2 = l2.iterator();
while (it2.hasNext()) {
String s2 = it2.next();
for (String s1 : l1) {
for (String s2 : l2) {
if (s1.equals(s2)) {
if (s1.length() == 5) {
l2.add(s1);
......@@ -43,9 +38,10 @@ public class TestNestedLoops extends InternalJadxTest {
String code = cls.getCode().toString();
System.out.println(code);
assertThat(code, containsString("while (it1.hasNext()) {"));
assertThat(code, containsString("while (it2.hasNext()) {"));
assertThat(code, containsString("if (s1.equals(s2)) {"));
assertThat(code, containsString("l2.add(s1);"));
assertThat(code, containsOne("for (String s1 : l1) {"));
assertThat(code, containsOne("for (String s2 : l2) {"));
assertThat(code, containsOne("if (s1.equals(s2)) {"));
assertThat(code, containsOne("l2.add(s1);"));
assertThat(code, containsOne("l1.remove(s2);"));
}
}
......@@ -5,12 +5,12 @@ import jadx.core.dex.nodes.ClassNode;
import jadx.core.dex.visitors.DepthTraversal;
import jadx.core.dex.visitors.IDexTreeVisitor;
import java.util.Iterator;
import java.util.List;
import org.junit.Test;
import org.slf4j.Logger;
import static jadx.tests.utils.JadxMatchers.containsOne;
import static org.hamcrest.CoreMatchers.containsString;
import static org.hamcrest.CoreMatchers.not;
import static org.junit.Assert.assertThat;
......@@ -25,9 +25,8 @@ public class TestVariablesDefinitions extends InternalJadxTest {
public void run() {
try {
cls.load();
Iterator<IDexTreeVisitor> iterator = passes.iterator();
while (iterator.hasNext()) {
DepthTraversal.visit(iterator.next(), cls);
for (IDexTreeVisitor pass : this.passes) {
DepthTraversal.visit(pass, cls);
}
} catch (Exception e) {
LOG.error("Decode exception: " + cls, e);
......@@ -41,8 +40,7 @@ public class TestVariablesDefinitions extends InternalJadxTest {
String code = cls.getCode().toString();
System.out.println(code);
// 'iterator' variable must be declared inside 'try' block
assertThat(code, containsString(indent(3) + "Iterator<IDexTreeVisitor> iterator = "));
assertThat(code, containsOne(indent(3) + "for (IDexTreeVisitor pass : this.passes) {"));
assertThat(code, not(containsString("iterator;")));
}
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment