Commit 0c2784bb authored by Skylot's avatar Skylot

refactor: inline fields in arithmetic operations

parent c555cd08
...@@ -51,6 +51,7 @@ public class ArithNode extends InsnNode { ...@@ -51,6 +51,7 @@ public class ArithNode extends InsnNode {
addArg(b); addArg(b);
} }
// TODO: remove result for one arg insn, this will simplify processing and allow to remove FieldArg
public ArithNode(ArithOp op, RegisterArg res, InsnArg a) { public ArithNode(ArithOp op, RegisterArg res, InsnArg a) {
this(op, res, res, a); this(op, res, res, a);
add(AFlag.ARITH_ONEARG); add(AFlag.ARITH_ONEARG);
......
...@@ -4,6 +4,7 @@ import java.util.ArrayList; ...@@ -4,6 +4,7 @@ import java.util.ArrayList;
import java.util.Collections; import java.util.Collections;
import java.util.List; import java.util.List;
import org.jetbrains.annotations.Nullable;
import org.slf4j.Logger; import org.slf4j.Logger;
import org.slf4j.LoggerFactory; import org.slf4j.LoggerFactory;
...@@ -39,6 +40,7 @@ import jadx.core.dex.regions.conditions.IfCondition; ...@@ -39,6 +40,7 @@ import jadx.core.dex.regions.conditions.IfCondition;
import jadx.core.utils.BlockUtils; import jadx.core.utils.BlockUtils;
import jadx.core.utils.InsnList; import jadx.core.utils.InsnList;
import jadx.core.utils.InsnRemover; import jadx.core.utils.InsnRemover;
import jadx.core.utils.exceptions.JadxRuntimeException;
public class SimplifyVisitor extends AbstractVisitor { public class SimplifyVisitor extends AbstractVisitor {
...@@ -59,21 +61,16 @@ public class SimplifyVisitor extends AbstractVisitor { ...@@ -59,21 +61,16 @@ public class SimplifyVisitor extends AbstractVisitor {
for (int i = 0; i < list.size(); i++) { for (int i = 0; i < list.size(); i++) {
InsnNode insn = list.get(i); InsnNode insn = list.get(i);
int insnCount = list.size(); int insnCount = list.size();
InsnNode modInsn = simplifyInsn(mth, insn); InsnNode modInsn = simplifyInsn(mth, block, insn);
if (modInsn != null) { if (modInsn != null) {
if (i != 0 && modInsn.contains(AFlag.ARITH_ONEARG)) {
InsnNode mergedNode = simplifyOneArgConsecutive(
list.get(i - 1), list.get(i), (ArithNode) modInsn);
if (mergedNode != null) {
list.remove(i - 1);
modInsn = mergedNode;
i--;
}
}
if (i < list.size() && list.get(i) == insn) { if (i < list.size() && list.get(i) == insn) {
list.set(i, modInsn); list.set(i, modInsn);
} else {
int idx = InsnList.getIndex(list, insn);
if (idx == -1) {
throw new JadxRuntimeException("Failed to replace insn");
} }
list.set(idx, modInsn);
} }
if (list.size() < insnCount) { if (list.size() < insnCount) {
// some insns removed => restart block processing // some insns removed => restart block processing
...@@ -82,27 +79,15 @@ public class SimplifyVisitor extends AbstractVisitor { ...@@ -82,27 +79,15 @@ public class SimplifyVisitor extends AbstractVisitor {
} }
} }
} }
private static InsnNode simplifyOneArgConsecutive(InsnNode insn1, InsnNode insn2, ArithNode modInsn) {
if (insn1.getType() == InsnType.IGET
&& insn2.getType() == InsnType.IPUT
&& insn1.getResult().getSVar().getUseCount() == 2
&& insn2.getArg(1).equals(insn1.getResult())) {
FieldInfo field = (FieldInfo) ((IndexInsnNode) insn2).getIndex();
FieldArg fArg = new FieldArg(field, new InsnWrapArg(insn1));
return new ArithNode(modInsn.getOp(), fArg, modInsn.getArg(1));
}
return null;
} }
private static InsnNode simplifyInsn(MethodNode mth, InsnNode insn) { private static InsnNode simplifyInsn(MethodNode mth, BlockNode block, InsnNode insn) {
if (insn.contains(AFlag.DONT_GENERATE)) { if (insn.contains(AFlag.DONT_GENERATE)) {
return null; return null;
} }
for (InsnArg arg : insn.getArguments()) { for (InsnArg arg : insn.getArguments()) {
if (arg.isInsnWrap()) { if (arg.isInsnWrap()) {
InsnNode ni = simplifyInsn(mth, ((InsnWrapArg) arg).getWrapInsn()); InsnNode ni = simplifyInsn(mth, block, ((InsnWrapArg) arg).getWrapInsn());
if (ni != null) { if (ni != null) {
arg.wrapInstruction(ni); arg.wrapInstruction(ni);
} }
...@@ -124,7 +109,7 @@ public class SimplifyVisitor extends AbstractVisitor { ...@@ -124,7 +109,7 @@ public class SimplifyVisitor extends AbstractVisitor {
case IPUT: case IPUT:
case SPUT: case SPUT:
return convertFieldArith(mth, insn); return convertFieldArith(mth, block, insn);
case CHECK_CAST: case CHECK_CAST:
return processCast(mth, insn); return processCast(mth, insn);
...@@ -421,7 +406,7 @@ public class SimplifyVisitor extends AbstractVisitor { ...@@ -421,7 +406,7 @@ public class SimplifyVisitor extends AbstractVisitor {
* Convert field arith operation to arith instruction * Convert field arith operation to arith instruction
* (IPUT = ARITH (IGET, lit) -> ARITH (fieldArg <op>= lit)) * (IPUT = ARITH (IGET, lit) -> ARITH (fieldArg <op>= lit))
*/ */
private static InsnNode convertFieldArith(MethodNode mth, InsnNode insn) { private static ArithNode convertFieldArith(MethodNode mth, BlockNode block, InsnNode insn) {
InsnArg arg = insn.getArg(0); InsnArg arg = insn.getArg(0);
if (!arg.isInsnWrap()) { if (!arg.isInsnWrap()) {
return null; return null;
...@@ -451,6 +436,7 @@ public class SimplifyVisitor extends AbstractVisitor { ...@@ -451,6 +436,7 @@ public class SimplifyVisitor extends AbstractVisitor {
return null; return null;
} }
} }
reg = inlineFieldGet(reg, block, get, insn);
FieldArg fArg = new FieldArg(field, reg); FieldArg fArg = new FieldArg(field, reg);
if (wrapType == InsnType.ARITH) { if (wrapType == InsnType.ARITH) {
ArithNode ar = (ArithNode) wrap; ArithNode ar = (ArithNode) wrap;
...@@ -467,4 +453,24 @@ public class SimplifyVisitor extends AbstractVisitor { ...@@ -467,4 +453,24 @@ public class SimplifyVisitor extends AbstractVisitor {
} }
return null; return null;
} }
private static InsnArg inlineFieldGet(@Nullable InsnArg arg, BlockNode block, InsnNode get, InsnNode insn) {
if (arg == null || !arg.isRegister()) {
return arg;
}
RegisterArg reg = (RegisterArg) arg;
InsnNode assignInsn = reg.getAssignInsn();
SSAVar ssaVar = reg.getSVar();
if (ssaVar.getUseCount() == 2 && !ssaVar.isUsedInPhi() && assignInsn != null) {
List<RegisterArg> useList = ssaVar.getUseList();
if (useList.get(0).getParentInsn() == get && useList.get(1).getParentInsn() == insn) {
InsnType assignInsnType = assignInsn.getType();
if (assignInsnType == InsnType.IGET || assignInsnType == InsnType.SGET) {
InsnList.remove(block, assignInsn);
return InsnArg.wrapArg(assignInsn);
}
}
}
return arg;
}
} }
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment