Commit 0c2784bb authored by Skylot's avatar Skylot

refactor: inline fields in arithmetic operations

parent c555cd08
......@@ -51,6 +51,7 @@ public class ArithNode extends InsnNode {
addArg(b);
}
// TODO: remove result for one arg insn, this will simplify processing and allow to remove FieldArg
public ArithNode(ArithOp op, RegisterArg res, InsnArg a) {
this(op, res, res, a);
add(AFlag.ARITH_ONEARG);
......
......@@ -4,6 +4,7 @@ import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import org.jetbrains.annotations.Nullable;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
......@@ -39,6 +40,7 @@ import jadx.core.dex.regions.conditions.IfCondition;
import jadx.core.utils.BlockUtils;
import jadx.core.utils.InsnList;
import jadx.core.utils.InsnRemover;
import jadx.core.utils.exceptions.JadxRuntimeException;
public class SimplifyVisitor extends AbstractVisitor {
......@@ -59,50 +61,33 @@ public class SimplifyVisitor extends AbstractVisitor {
for (int i = 0; i < list.size(); i++) {
InsnNode insn = list.get(i);
int insnCount = list.size();
InsnNode modInsn = simplifyInsn(mth, insn);
InsnNode modInsn = simplifyInsn(mth, block, insn);
if (modInsn != null) {
if (i != 0 && modInsn.contains(AFlag.ARITH_ONEARG)) {
InsnNode mergedNode = simplifyOneArgConsecutive(
list.get(i - 1), list.get(i), (ArithNode) modInsn);
if (mergedNode != null) {
list.remove(i - 1);
modInsn = mergedNode;
i--;
}
}
if (i < list.size() && list.get(i) == insn) {
list.set(i, modInsn);
} else {
int idx = InsnList.getIndex(list, insn);
if (idx == -1) {
throw new JadxRuntimeException("Failed to replace insn");
}
list.set(idx, modInsn);
}
if (list.size() < insnCount) {
// some insns removed => restart block processing
simplifyBlock(mth, block);
return;
}
}
if (list.size() < insnCount) {
// some insns removed => restart block processing
simplifyBlock(mth, block);
return;
}
}
}
private static InsnNode simplifyOneArgConsecutive(InsnNode insn1, InsnNode insn2, ArithNode modInsn) {
if (insn1.getType() == InsnType.IGET
&& insn2.getType() == InsnType.IPUT
&& insn1.getResult().getSVar().getUseCount() == 2
&& insn2.getArg(1).equals(insn1.getResult())) {
FieldInfo field = (FieldInfo) ((IndexInsnNode) insn2).getIndex();
FieldArg fArg = new FieldArg(field, new InsnWrapArg(insn1));
return new ArithNode(modInsn.getOp(), fArg, modInsn.getArg(1));
}
return null;
}
private static InsnNode simplifyInsn(MethodNode mth, InsnNode insn) {
private static InsnNode simplifyInsn(MethodNode mth, BlockNode block, InsnNode insn) {
if (insn.contains(AFlag.DONT_GENERATE)) {
return null;
}
for (InsnArg arg : insn.getArguments()) {
if (arg.isInsnWrap()) {
InsnNode ni = simplifyInsn(mth, ((InsnWrapArg) arg).getWrapInsn());
InsnNode ni = simplifyInsn(mth, block, ((InsnWrapArg) arg).getWrapInsn());
if (ni != null) {
arg.wrapInstruction(ni);
}
......@@ -124,7 +109,7 @@ public class SimplifyVisitor extends AbstractVisitor {
case IPUT:
case SPUT:
return convertFieldArith(mth, insn);
return convertFieldArith(mth, block, insn);
case CHECK_CAST:
return processCast(mth, insn);
......@@ -421,7 +406,7 @@ public class SimplifyVisitor extends AbstractVisitor {
* Convert field arith operation to arith instruction
* (IPUT = ARITH (IGET, lit) -> ARITH (fieldArg <op>= lit))
*/
private static InsnNode convertFieldArith(MethodNode mth, InsnNode insn) {
private static ArithNode convertFieldArith(MethodNode mth, BlockNode block, InsnNode insn) {
InsnArg arg = insn.getArg(0);
if (!arg.isInsnWrap()) {
return null;
......@@ -451,6 +436,7 @@ public class SimplifyVisitor extends AbstractVisitor {
return null;
}
}
reg = inlineFieldGet(reg, block, get, insn);
FieldArg fArg = new FieldArg(field, reg);
if (wrapType == InsnType.ARITH) {
ArithNode ar = (ArithNode) wrap;
......@@ -467,4 +453,24 @@ public class SimplifyVisitor extends AbstractVisitor {
}
return null;
}
private static InsnArg inlineFieldGet(@Nullable InsnArg arg, BlockNode block, InsnNode get, InsnNode insn) {
if (arg == null || !arg.isRegister()) {
return arg;
}
RegisterArg reg = (RegisterArg) arg;
InsnNode assignInsn = reg.getAssignInsn();
SSAVar ssaVar = reg.getSVar();
if (ssaVar.getUseCount() == 2 && !ssaVar.isUsedInPhi() && assignInsn != null) {
List<RegisterArg> useList = ssaVar.getUseList();
if (useList.get(0).getParentInsn() == get && useList.get(1).getParentInsn() == insn) {
InsnType assignInsnType = assignInsn.getType();
if (assignInsnType == InsnType.IGET || assignInsnType == InsnType.SGET) {
InsnList.remove(block, assignInsn);
return InsnArg.wrapArg(assignInsn);
}
}
}
return arg;
}
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment