Commit f9f840fb authored by Skylot's avatar Skylot

refactor: remove redundant FieldArg and change arith one arg insn

parent 8e8a2faa
......@@ -36,7 +36,6 @@ import jadx.core.dex.instructions.NewArrayNode;
import jadx.core.dex.instructions.SwitchNode;
import jadx.core.dex.instructions.args.ArgType;
import jadx.core.dex.instructions.args.CodeVar;
import jadx.core.dex.instructions.args.FieldArg;
import jadx.core.dex.instructions.args.InsnArg;
import jadx.core.dex.instructions.args.InsnWrapArg;
import jadx.core.dex.instructions.args.LiteralArg;
......@@ -105,13 +104,6 @@ public class InsnGen {
makeInsn(((InsnWrapArg) arg).getWrapInsn(), code, flag);
} else if (arg.isNamed()) {
code.add(((Named) arg).getName());
} else if (arg.isField()) {
FieldArg f = (FieldArg) arg;
if (f.isStatic()) {
staticField(code, f.getField());
} else {
instanceField(code, f.getField(), f.getInstanceArg());
}
} else {
throw new CodegenException("Unknown arg type " + arg);
}
......@@ -174,6 +166,7 @@ public class InsnGen {
public static void makeStaticFieldAccess(CodeWriter code, FieldInfo field, ClassGen clsGen) {
ClassInfo declClass = field.getDeclClass();
// TODO
boolean fieldFromThisClass = clsGen.getClassNode().getClassInfo().equals(declClass);
if (!fieldFromThisClass) {
// Android specific resources class handler
......@@ -231,10 +224,10 @@ public class InsnGen {
code.add("// ");
}
}
if (insn.getResult() != null) {
SSAVar var = insn.getResult().getSVar();
if ((var == null || var.getUseCount() != 0 || insn.getType() != InsnType.CONSTRUCTOR)
&& !insn.contains(AFlag.ARITH_ONEARG)) {
RegisterArg resArg = insn.getResult();
if (resArg != null) {
SSAVar var = resArg.getSVar();
if (var == null || var.getUseCount() != 0 || insn.getType() != InsnType.CONSTRUCTOR) {
assignVar(code, insn);
code.add(" = ");
}
......@@ -981,19 +974,22 @@ public class InsnGen {
private void makeArithOneArg(ArithNode insn, CodeWriter code) throws CodegenException {
ArithOp op = insn.getOp();
InsnArg resArg = insn.getArg(0);
InsnArg arg = insn.getArg(1);
// "++" or "--"
if (arg.isLiteral() && (op == ArithOp.ADD || op == ArithOp.SUB)) {
LiteralArg lit = (LiteralArg) arg;
if (lit.isInteger() && lit.getLiteral() == 1) {
assignVar(code, insn);
if (lit.getLiteral() == 1 && lit.isInteger()) {
addArg(code, resArg, false);
String opSymbol = op.getSymbol();
code.add(opSymbol).add(opSymbol);
return;
}
}
// +=, -= ...
assignVar(code, insn);
// +=, -=, ...
addArg(code, resArg, false);
code.add(' ').add(op.getSymbol()).add("= ");
addArg(code, arg, false);
}
......
......@@ -51,10 +51,16 @@ public class ArithNode extends InsnNode {
addArg(b);
}
// TODO: remove result for one arg insn, this will simplify processing and allow to remove FieldArg
public ArithNode(ArithOp op, RegisterArg res, InsnArg a) {
this(op, res, res, a);
add(AFlag.ARITH_ONEARG);
/**
* Create one argument arithmetic instructions (a+=2).
* Result is not set (null).
*
* @param res argument to change
*/
public static ArithNode oneArgOp(ArithOp op, InsnArg res, InsnArg a) {
ArithNode insn = new ArithNode(op, null, res, a);
insn.add(AFlag.ARITH_ONEARG);
return insn;
}
public ArithOp getOp() {
......
package jadx.core.dex.instructions.args;
import java.util.Objects;
import org.jetbrains.annotations.Nullable;
import jadx.core.dex.info.FieldInfo;
import jadx.core.utils.exceptions.JadxRuntimeException;
// TODO: don't extend RegisterArg (now used as a result of instruction)
public final class FieldArg extends RegisterArg {
private final FieldInfo field;
// instArg equal 'null' for static fields
@Nullable
private final InsnArg instArg;
public FieldArg(FieldInfo field, @Nullable InsnArg reg) {
super(-1, field.getType());
this.instArg = reg;
this.field = field;
}
public FieldInfo getField() {
return field;
}
public InsnArg getInstanceArg() {
return instArg;
}
public boolean isStatic() {
return instArg == null;
}
@Override
public boolean isField() {
return true;
}
@Override
public boolean isRegister() {
return false;
}
@Override
public ArgType getType() {
return this.field.getType();
}
@Override
public ArgType getInitType() {
return this.field.getType();
}
@Override
public void setType(ArgType newType) {
throw new JadxRuntimeException("Can't set type for FieldArg");
}
@Override
public RegisterArg duplicate() {
return copyCommonParams(new FieldArg(field, instArg));
}
@Override
public boolean equals(Object obj) {
if (this == obj) {
return true;
}
if (!(obj instanceof FieldArg) || !super.equals(obj)) {
return false;
}
FieldArg fieldArg = (FieldArg) obj;
if (!field.equals(fieldArg.field)) {
return false;
}
return Objects.equals(instArg, fieldArg.instArg);
}
@Override
public int hashCode() {
int result = super.hashCode();
result = 31 * result + field.hashCode();
result = 31 * result + (instArg != null ? instArg.hashCode() : 0);
return result;
}
@Override
public String toString() {
return "(" + field + ')';
}
}
......@@ -84,10 +84,6 @@ public abstract class InsnArg extends Typed {
return false;
}
public boolean isField() {
return false;
}
@Nullable
public InsnNode getParentInsn() {
return parentInsn;
......
......@@ -206,6 +206,15 @@ public class InsnNode extends LineAttrNode {
}
return true;
}
for (InsnArg arg : getArguments()) {
if (arg.isInsnWrap()) {
InsnNode wrapInsn = ((InsnWrapArg) arg).getWrapInsn();
if (!wrapInsn.canReorder()) {
return false;
}
}
}
switch (getType()) {
case CONST:
case CONST_STR:
......
......@@ -41,6 +41,7 @@ import jadx.core.dex.nodes.MethodNode;
import jadx.core.dex.regions.conditions.IfCondition;
import jadx.core.dex.trycatch.ExcHandlerAttr;
import jadx.core.dex.trycatch.ExceptionHandler;
import jadx.core.dex.visitors.regions.variables.ProcessVariables;
import jadx.core.dex.visitors.shrink.CodeShrinkVisitor;
import jadx.core.utils.ErrorsCounter;
import jadx.core.utils.InsnRemover;
......@@ -56,17 +57,22 @@ import static jadx.core.utils.BlockUtils.replaceInsn;
@JadxVisitor(
name = "ModVisitor",
desc = "Modify method instructions",
runBefore = CodeShrinkVisitor.class
runBefore = {
CodeShrinkVisitor.class,
ProcessVariables.class
}
)
public class ModVisitor extends AbstractVisitor {
private static final Logger LOG = LoggerFactory.getLogger(ModVisitor.class);
private static final long DOUBLE_TO_BITS = Double.doubleToLongBits(1);
private static final long FLOAT_TO_BITS = Float.floatToIntBits(1);
@Override
public void visit(MethodNode mth) {
if (mth.isNoCode()) {
return;
}
InsnRemover remover = new InsnRemover(mth);
replaceStep(mth, remover);
removeStep(mth, remover);
......@@ -76,9 +82,10 @@ public class ModVisitor extends AbstractVisitor {
ClassNode parentClass = mth.getParentClass();
for (BlockNode block : mth.getBasicBlocks()) {
remover.setBlock(block);
int size = block.getInstructions().size();
List<InsnNode> insnsList = block.getInstructions();
int size = insnsList.size();
for (int i = 0; i < size; i++) {
InsnNode insn = block.getInstructions().get(i);
InsnNode insn = insnsList.get(i);
switch (insn.getType()) {
case CONSTRUCTOR:
processAnonymousConstructor(mth, ((ConstructorInsn) insn));
......@@ -86,33 +93,12 @@ public class ModVisitor extends AbstractVisitor {
case CONST:
case CONST_STR:
case CONST_CLASS: {
FieldNode f;
if (insn.getType() == InsnType.CONST_STR) {
String s = ((ConstStringNode) insn).getString();
f = parentClass.getConstField(s);
} else if (insn.getType() == InsnType.CONST_CLASS) {
ArgType t = ((ConstClassNode) insn).getClsType();
f = parentClass.getConstField(t);
} else {
f = parentClass.getConstFieldByLiteralArg((LiteralArg) insn.getArg(0));
}
if (f != null) {
InsnNode inode = new IndexInsnNode(InsnType.SGET, f.getFieldInfo(), 0);
inode.setResult(insn.getResult());
replaceInsn(block, i, inode);
}
case CONST_CLASS:
replaceConst(parentClass, block, i, insn);
break;
}
case SWITCH:
SwitchNode sn = (SwitchNode) insn;
for (int k = 0; k < sn.getCasesCount(); k++) {
FieldNode f = parentClass.getConstField(sn.getKeys()[k]);
if (f != null) {
sn.getKeys()[k] = f;
}
}
replaceConstKeys(parentClass, (SwitchNode) insn);
break;
case NEW_ARRAY:
......@@ -134,48 +120,15 @@ public class ModVisitor extends AbstractVisitor {
break;
case ARITH:
ArithNode arithNode = (ArithNode) insn;
if (arithNode.getArgsCount() == 2) {
InsnArg litArg = arithNode.getArg(1);
if (litArg.isLiteral()) {
FieldNode f = parentClass.getConstFieldByLiteralArg((LiteralArg) litArg);
if (f != null) {
InsnNode fGet = new IndexInsnNode(InsnType.SGET, f.getFieldInfo(), 0);
insn.replaceArg(litArg, InsnArg.wrapArg(fGet));
}
}
}
processArith(parentClass, (ArithNode) insn);
break;
case CHECK_CAST:
InsnArg castArg = insn.getArg(0);
ArgType castType = (ArgType) ((IndexInsnNode) insn).getIndex();
if (!ArgType.isCastNeeded(mth.dex(), castArg.getType(), castType)
|| isCastDuplicate((IndexInsnNode) insn)) {
InsnNode insnNode = new InsnNode(InsnType.MOVE, 1);
insnNode.setResult(insn.getResult());
insnNode.addArg(castArg);
replaceInsn(block, i, insnNode);
}
removeRedundantCast(mth, block, i, (IndexInsnNode) insn);
break;
case CAST:
// replace boolean to (byte/char/short/long/double/float) cast with ternary
if (insn.getArg(0).getType() == ArgType.BOOLEAN) {
ArgType type = insn.getResult().getType();
if (type.isPrimitive()) {
IfNode ifNode = new IfNode(IfOp.EQ, -1, insn.getArg(0), LiteralArg.TRUE);
IfCondition condition = IfCondition.fromIfNode(ifNode);
InsnArg zero = new LiteralArg(0, type);
InsnArg one = new LiteralArg(
type == ArgType.DOUBLE
? Double.doubleToLongBits(1)
: type == ArgType.FLOAT ? Float.floatToIntBits(1) : 1,
type);
TernaryInsn ternary = new TernaryInsn(condition, insn.getResult(), one, zero);
replaceInsn(block, i, ternary);
}
}
fixPrimitiveCast(block, i, insn);
break;
default:
......@@ -186,6 +139,68 @@ public class ModVisitor extends AbstractVisitor {
}
}
private static void replaceConstKeys(ClassNode parentClass, SwitchNode insn) {
for (int k = 0; k < insn.getCasesCount(); k++) {
FieldNode f = parentClass.getConstField(insn.getKeys()[k]);
if (f != null) {
insn.getKeys()[k] = f;
}
}
}
private static void fixPrimitiveCast(BlockNode block, int i, InsnNode insn) {
// replace boolean to (byte/char/short/long/double/float) cast with ternary
if (insn.getArg(0).getType() == ArgType.BOOLEAN) {
ArgType type = insn.getResult().getType();
if (type.isPrimitive()) {
IfNode ifNode = new IfNode(IfOp.EQ, -1, insn.getArg(0), LiteralArg.TRUE);
IfCondition condition = IfCondition.fromIfNode(ifNode);
InsnArg zero = new LiteralArg(0, type);
long litVal = 1;
if (type == ArgType.DOUBLE) {
litVal = DOUBLE_TO_BITS;
} else if (type == ArgType.FLOAT) {
litVal = FLOAT_TO_BITS;
}
InsnArg one = new LiteralArg(litVal, type);
TernaryInsn ternary = new TernaryInsn(condition, insn.getResult(), one, zero);
replaceInsn(block, i, ternary);
}
}
}
private static void replaceConst(ClassNode parentClass, BlockNode block, int i, InsnNode insn) {
FieldNode f;
if (insn.getType() == InsnType.CONST_STR) {
String s = ((ConstStringNode) insn).getString();
f = parentClass.getConstField(s);
} else if (insn.getType() == InsnType.CONST_CLASS) {
ArgType t = ((ConstClassNode) insn).getClsType();
f = parentClass.getConstField(t);
} else {
f = parentClass.getConstFieldByLiteralArg((LiteralArg) insn.getArg(0));
}
if (f != null) {
InsnNode inode = new IndexInsnNode(InsnType.SGET, f.getFieldInfo(), 0);
inode.setResult(insn.getResult());
replaceInsn(block, i, inode);
}
}
private static void processArith(ClassNode parentClass, ArithNode arithNode) {
if (arithNode.getArgsCount() != 2) {
throw new JadxRuntimeException("Invalid args count in insn: " + arithNode);
}
InsnArg litArg = arithNode.getArg(1);
if (litArg.isLiteral()) {
FieldNode f = parentClass.getConstFieldByLiteralArg((LiteralArg) litArg);
if (f != null) {
InsnNode fGet = new IndexInsnNode(InsnType.SGET, f.getFieldInfo(), 0);
arithNode.replaceArg(litArg, InsnArg.wrapArg(fGet));
}
}
}
private static boolean checkArrSizes(MethodNode mth, NewArrayNode newArrInsn, FillArrayNode fillArrInsn) {
int dataSize = fillArrInsn.getSize();
InsnArg arrSizeArg = newArrInsn.getArg(0);
......@@ -197,6 +212,18 @@ public class ModVisitor extends AbstractVisitor {
return false;
}
private static void removeRedundantCast(MethodNode mth, BlockNode block, int i, IndexInsnNode insn) {
InsnArg castArg = insn.getArg(0);
ArgType castType = (ArgType) insn.getIndex();
if (!ArgType.isCastNeeded(mth.dex(), castArg.getType(), castType)
|| isCastDuplicate(insn)) {
InsnNode insnNode = new InsnNode(InsnType.MOVE, 1);
insnNode.setResult(insn.getResult());
insnNode.addArg(castArg);
replaceInsn(block, i, insnNode);
}
}
private static boolean isCastDuplicate(IndexInsnNode castInsn) {
InsnArg arg = castInsn.getArg(0);
if (arg.isRegister()) {
......
......@@ -160,7 +160,8 @@ public class PrepareForCodeGen extends AbstractVisitor {
List<InsnNode> list = block.getInstructions();
for (InsnNode insn : list) {
if (insn.getType() == InsnType.ARITH
&& !insn.contains(AFlag.DECLARE_VAR)) { // TODO: move this modify before ProcessVariable
&& !insn.contains(AFlag.ARITH_ONEARG)
&& !insn.contains(AFlag.DECLARE_VAR)) {
RegisterArg res = insn.getResult();
InsnArg arg = insn.getArg(0);
boolean replace = false;
......@@ -171,6 +172,7 @@ public class PrepareForCodeGen extends AbstractVisitor {
replace = res.sameCodeVar(regArg);
}
if (replace) {
insn.setResult(null);
insn.add(AFlag.ARITH_ONEARG);
}
}
......
......@@ -4,7 +4,6 @@ import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import org.jetbrains.annotations.Nullable;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
......@@ -24,7 +23,6 @@ import jadx.core.dex.instructions.InsnType;
import jadx.core.dex.instructions.InvokeNode;
import jadx.core.dex.instructions.InvokeType;
import jadx.core.dex.instructions.args.ArgType;
import jadx.core.dex.instructions.args.FieldArg;
import jadx.core.dex.instructions.args.InsnArg;
import jadx.core.dex.instructions.args.InsnWrapArg;
import jadx.core.dex.instructions.args.LiteralArg;
......@@ -37,6 +35,7 @@ import jadx.core.dex.nodes.InsnNode;
import jadx.core.dex.nodes.MethodNode;
import jadx.core.dex.nodes.RootNode;
import jadx.core.dex.regions.conditions.IfCondition;
import jadx.core.dex.visitors.shrink.CodeShrinkVisitor;
import jadx.core.utils.BlockUtils;
import jadx.core.utils.InsnList;
import jadx.core.utils.InsnRemover;
......@@ -51,12 +50,19 @@ public class SimplifyVisitor extends AbstractVisitor {
if (mth.isNoCode()) {
return;
}
boolean changed = false;
for (BlockNode block : mth.getBasicBlocks()) {
simplifyBlock(mth, block);
if (simplifyBlock(mth, block)) {
changed = true;
}
}
if (changed) {
CodeShrinkVisitor.shrinkMethod(mth);
}
}
private static void simplifyBlock(MethodNode mth, BlockNode block) {
private static boolean simplifyBlock(MethodNode mth, BlockNode block) {
boolean changed = false;
List<InsnNode> list = block.getInstructions();
for (int i = 0; i < list.size(); i++) {
InsnNode insn = list.get(i);
......@@ -75,10 +81,12 @@ public class SimplifyVisitor extends AbstractVisitor {
if (list.size() < insnCount) {
// some insns removed => restart block processing
simplifyBlock(mth, block);
return;
return true;
}
changed = true;
}
}
return changed;
}
private static InsnNode simplifyInsn(MethodNode mth, BlockNode block, InsnNode insn) {
......@@ -109,7 +117,7 @@ public class SimplifyVisitor extends AbstractVisitor {
case IPUT:
case SPUT:
return convertFieldArith(mth, block, insn);
return convertFieldArith(mth, insn);
case CHECK_CAST:
return processCast(mth, insn);
......@@ -404,9 +412,9 @@ public class SimplifyVisitor extends AbstractVisitor {
/**
* Convert field arith operation to arith instruction
* (IPUT = ARITH (IGET, lit) -> ARITH (fieldArg <op>= lit))
* (IPUT (ARITH (IGET, lit)) -> ARITH ((IGET)) <op>= lit))
*/
private static ArithNode convertFieldArith(MethodNode mth, BlockNode block, InsnNode insn) {
private static ArithNode convertFieldArith(MethodNode mth, InsnNode insn) {
InsnArg arg = insn.getArg(0);
if (!arg.isInsnWrap()) {
return null;
......@@ -428,49 +436,30 @@ public class SimplifyVisitor extends AbstractVisitor {
return null;
}
try {
InsnArg reg = null;
if (getType == InsnType.IGET) {
reg = get.getArg(0);
if (getType == InsnType.IGET && insn.getType() == InsnType.IPUT) {
InsnArg reg = get.getArg(0);
InsnArg putReg = insn.getArg(1);
if (!reg.equals(putReg)) {
return null;
}
}
reg = inlineFieldGet(reg, block, get, insn);
FieldArg fArg = new FieldArg(field, reg);
InsnArg fArg = InsnArg.wrapArg(get);
if (insn.getType() == InsnType.IPUT) {
InsnRemover.unbindArgUsage(mth, insn.getArg(1));
}
if (wrapType == InsnType.ARITH) {
ArithNode ar = (ArithNode) wrap;
return new ArithNode(ar.getOp(), fArg, ar.getArg(1));
return ArithNode.oneArgOp(ar.getOp(), fArg, ar.getArg(1));
}
int argsCount = wrap.getArgsCount();
InsnNode concat = new InsnNode(InsnType.STR_CONCAT, argsCount - 1);
for (int i = 1; i < argsCount; i++) {
concat.addArg(wrap.getArg(i));
}
return new ArithNode(ArithOp.ADD, fArg, InsnArg.wrapArg(concat));
return ArithNode.oneArgOp(ArithOp.ADD, fArg, InsnArg.wrapArg(concat));
} catch (Exception e) {
LOG.debug("Can't convert field arith insn: {}, mth: {}", insn, mth, e);
}
return null;
}
private static InsnArg inlineFieldGet(@Nullable InsnArg arg, BlockNode block, InsnNode get, InsnNode insn) {
if (arg == null || !arg.isRegister()) {
return arg;
}
RegisterArg reg = (RegisterArg) arg;
InsnNode assignInsn = reg.getAssignInsn();
SSAVar ssaVar = reg.getSVar();
if (ssaVar.getUseCount() == 2 && !ssaVar.isUsedInPhi() && assignInsn != null) {
List<RegisterArg> useList = ssaVar.getUseList();
if (useList.get(0).getParentInsn() == get && useList.get(1).getParentInsn() == insn) {
InsnType assignInsnType = assignInsn.getType();
if (assignInsnType == InsnType.IGET || assignInsnType == InsnType.SGET) {
InsnList.remove(block, assignInsn);
return InsnArg.wrapArg(assignInsn);
}
}
}
return arg;
}
}
......@@ -118,7 +118,7 @@ final class ArgsInfo {
return false;
}
RegisterArg result = insn.getResult();
if (result == null || result.isField()) {
if (result == null) {
return false;
}
return args.get(result.getRegNum());
......
......@@ -18,9 +18,9 @@ public class TestArith3 extends IntegrationTest {
public void test(byte[] buffer) {
int n = ((buffer[3] & 255) + 4) + ((buffer[2] & 15) << 8);
while (n + 4 < buffer.length) {
int c = buffer[n] & 255;
int p = (buffer[n + 2] & 255) + ((buffer[n + 1] & 31) << 8);
int len = (buffer[n + 4] & 255) + ((buffer[n + 3] & 15) << 8);
int c = buffer[n] & 255;
switch (c) {
case 27:
this.vp = p;
......@@ -37,7 +37,7 @@ public class TestArith3 extends IntegrationTest {
String code = cls.getCode().toString();
assertThat(code, containsOne("while (n + 4 < buffer.length) {"));
assertThat(code, containsOne("n += len + 5;"));
assertThat(code, containsOne(indent() + "n += len + 5;"));
assertThat(code, not(containsString("; n += len + 5) {")));
assertThat(code, not(containsString("default:")));
}
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment