Commit 389caf18 authored by Skylot's avatar Skylot

fix: improve filled array detection

parent 5cee498e
...@@ -5,8 +5,6 @@ import java.util.Objects; ...@@ -5,8 +5,6 @@ import java.util.Objects;
import org.jetbrains.annotations.NotNull; import org.jetbrains.annotations.NotNull;
import org.jetbrains.annotations.Nullable; import org.jetbrains.annotations.Nullable;
import jadx.core.dex.instructions.InsnType;
import jadx.core.dex.instructions.PhiInsn;
import jadx.core.dex.nodes.DexNode; import jadx.core.dex.nodes.DexNode;
import jadx.core.dex.nodes.InsnNode; import jadx.core.dex.nodes.InsnNode;
import jadx.core.utils.InsnUtils; import jadx.core.utils.InsnUtils;
...@@ -115,6 +113,7 @@ public class RegisterArg extends InsnArg implements Named { ...@@ -115,6 +113,7 @@ public class RegisterArg extends InsnArg implements Named {
} }
} }
@Override
public RegisterArg duplicate() { public RegisterArg duplicate() {
return duplicate(getRegNum(), sVar); return duplicate(getRegNum(), sVar);
} }
...@@ -130,8 +129,6 @@ public class RegisterArg extends InsnArg implements Named { ...@@ -130,8 +129,6 @@ public class RegisterArg extends InsnArg implements Named {
/** /**
* Return constant value from register assign or null if not constant * Return constant value from register assign or null if not constant
*
* @return LiteralArg, String or ArgType
*/ */
public Object getConstValue(DexNode dex) { public Object getConstValue(DexNode dex) {
InsnNode parInsn = getAssignInsn(); InsnNode parInsn = getAssignInsn();
...@@ -149,22 +146,19 @@ public class RegisterArg extends InsnArg implements Named { ...@@ -149,22 +146,19 @@ public class RegisterArg extends InsnArg implements Named {
return sVar.getAssign().getParentInsn(); return sVar.getAssign().getParentInsn();
} }
public InsnNode getPhiAssignInsn() {
PhiInsn usePhi = sVar.getUsedInPhi();
if (usePhi != null) {
return usePhi;
}
InsnNode parent = sVar.getAssign().getParentInsn();
if (parent != null && parent.getType() == InsnType.PHI) {
return parent;
}
return null;
}
public boolean equalRegister(RegisterArg arg) { public boolean equalRegister(RegisterArg arg) {
return regNum == arg.regNum; return regNum == arg.regNum;
} }
public boolean sameRegAndSVar(InsnArg arg) {
if (!arg.isRegister()) {
return false;
}
RegisterArg reg = (RegisterArg) arg;
return regNum == reg.getRegNum()
&& Objects.equals(sVar, reg.getSVar());
}
public boolean equalRegisterAndType(RegisterArg arg) { public boolean equalRegisterAndType(RegisterArg arg) {
return regNum == arg.regNum && type.equals(arg.type); return regNum == arg.regNum && type.equals(arg.type);
} }
......
package jadx.core.dex.visitors; package jadx.core.dex.visitors;
import java.util.List; import java.util.List;
import java.util.stream.Collectors;
import org.jetbrains.annotations.Nullable; import org.jetbrains.annotations.Nullable;
import org.slf4j.Logger; import org.slf4j.Logger;
...@@ -21,13 +22,18 @@ import jadx.core.dex.instructions.args.ArgType; ...@@ -21,13 +22,18 @@ import jadx.core.dex.instructions.args.ArgType;
import jadx.core.dex.instructions.args.InsnArg; import jadx.core.dex.instructions.args.InsnArg;
import jadx.core.dex.instructions.args.InsnWrapArg; import jadx.core.dex.instructions.args.InsnWrapArg;
import jadx.core.dex.instructions.args.LiteralArg; import jadx.core.dex.instructions.args.LiteralArg;
import jadx.core.dex.instructions.args.RegisterArg;
import jadx.core.dex.instructions.args.SSAVar;
import jadx.core.dex.nodes.BlockNode; import jadx.core.dex.nodes.BlockNode;
import jadx.core.dex.nodes.ClassNode; import jadx.core.dex.nodes.ClassNode;
import jadx.core.dex.nodes.DexNode; import jadx.core.dex.nodes.DexNode;
import jadx.core.dex.nodes.FieldNode; import jadx.core.dex.nodes.FieldNode;
import jadx.core.dex.nodes.InsnNode; import jadx.core.dex.nodes.InsnNode;
import jadx.core.dex.nodes.MethodNode; import jadx.core.dex.nodes.MethodNode;
import jadx.core.utils.InsnList;
import jadx.core.utils.InsnUtils;
import jadx.core.utils.InstructionRemover; import jadx.core.utils.InstructionRemover;
import jadx.core.utils.Utils;
import jadx.core.utils.exceptions.JadxException; import jadx.core.utils.exceptions.JadxException;
@JadxVisitor( @JadxVisitor(
...@@ -56,65 +62,119 @@ public class ReSugarCode extends AbstractVisitor { ...@@ -56,65 +62,119 @@ public class ReSugarCode extends AbstractVisitor {
List<InsnNode> instructions = block.getInstructions(); List<InsnNode> instructions = block.getInstructions();
int size = instructions.size(); int size = instructions.size();
for (int i = 0; i < size; i++) { for (int i = 0; i < size; i++) {
InsnNode replacedInsn = process(mth, instructions, i, remover); process(mth, instructions, i, remover);
if (replacedInsn != null) {
instructions.set(i, replacedInsn);
}
} }
remover.perform(); remover.perform();
} }
} }
private static InsnNode process(MethodNode mth, List<InsnNode> instructions, int i, InstructionRemover remover) { private static void process(MethodNode mth, List<InsnNode> instructions, int i, InstructionRemover remover) {
InsnNode insn = instructions.get(i); InsnNode insn = instructions.get(i);
if (insn.contains(AFlag.REMOVE)) {
return;
}
switch (insn.getType()) { switch (insn.getType()) {
case NEW_ARRAY: case NEW_ARRAY:
return processNewArray(mth, instructions, i, remover); processNewArray(mth, instructions, i, remover);
break;
case SWITCH: case SWITCH:
processEnumSwitch(mth, (SwitchNode) insn); processEnumSwitch(mth, (SwitchNode) insn);
return null; break;
default: default:
return null; break;
} }
} }
/** /**
* Replace new array and sequence of array-put to new filled-array instruction. * Replace new array and sequence of array-put to new filled-array instruction.
*/ */
private static InsnNode processNewArray(MethodNode mth, List<InsnNode> instructions, int i, private static void processNewArray(MethodNode mth, List<InsnNode> instructions, int i,
InstructionRemover remover) { InstructionRemover remover) {
NewArrayNode newArrayInsn = (NewArrayNode) instructions.get(i); NewArrayNode newArrayInsn = (NewArrayNode) instructions.get(i);
InsnArg arg = newArrayInsn.getArg(0); InsnArg arrLenArg = newArrayInsn.getArg(0);
if (!arg.isLiteral()) { if (!arrLenArg.isLiteral()) {
return null; return;
} }
int len = (int) ((LiteralArg) arg).getLiteral(); int len = (int) ((LiteralArg) arrLenArg).getLiteral();
int size = instructions.size(); if (len == 0) {
if (len <= 0 return;
|| i + len >= size }
|| instructions.get(i + len).getType() != InsnType.APUT) { RegisterArg arrArg = newArrayInsn.getResult();
return null; SSAVar ssaVar = arrArg.getSVar();
List<RegisterArg> useList = ssaVar.getUseList();
if (useList.size() < len) {
return;
}
// check sequential array put with increasing index
int putIndex = 0;
for (RegisterArg useArg : useList) {
InsnNode insn = useArg.getParentInsn();
if (checkPutInsn(mth, insn, arrArg, putIndex)) {
putIndex++;
} else {
break;
}
} }
for (int j = 0; j < len; j++) { if (putIndex != len) {
InsnNode put = instructions.get(i + 1 + j); return;
if (put.getType() != InsnType.APUT) { }
LOG.debug("Not a APUT in expected new filled array: {}, method: {}", put, mth); List<InsnNode> arrPuts = useList.subList(0, len).stream().map(InsnArg::getParentInsn).collect(Collectors.toList());
return null; // check that all puts in current block
for (InsnNode arrPut : arrPuts) {
int index = InsnList.getIndex(instructions, arrPut);
if (index == -1) {
if (LOG.isDebugEnabled()) {
LOG.debug("TODO: APUT found in different block: {}, mth: {}", arrPut, mth);
}
return;
} }
} }
// checks complete, apply // checks complete, apply
ArgType arrType = newArrayInsn.getArrayType(); ArgType arrType = newArrayInsn.getArrayType();
InsnNode filledArr = new FilledNewArrayNode(arrType.getArrayElement(), len); InsnNode filledArr = new FilledNewArrayNode(arrType.getArrayElement(), len);
filledArr.setResult(newArrayInsn.getResult().duplicate()); filledArr.setResult(arrArg.duplicate());
for (int j = 0; j < len; j++) { for (InsnNode put : arrPuts) {
InsnNode put = instructions.get(i + 1 + j);
filledArr.addArg(put.getArg(2).duplicate()); filledArr.addArg(put.getArg(2).duplicate());
remover.add(put); remover.addAndUnbind(mth, put);
}
remover.addAndUnbind(mth, newArrayInsn);
int replaceIndex = InsnList.getIndex(instructions, Utils.last(arrPuts));
instructions.set(replaceIndex, filledArr);
}
private static boolean checkPutInsn(MethodNode mth, InsnNode insn, RegisterArg arrArg, int putIndex) {
if (insn == null || insn.getType() != InsnType.APUT) {
return false;
}
if (!arrArg.sameRegAndSVar(insn.getArg(0))) {
return false;
}
InsnArg indexArg = insn.getArg(1);
int index = -1;
if (indexArg.isLiteral()) {
index = (int) ((LiteralArg) indexArg).getLiteral();
} else if (indexArg.isRegister()) {
RegisterArg reg = (RegisterArg) indexArg;
index = getIntConst(reg.getConstValue(mth.dex()));
} else if (indexArg.isInsnWrap()) {
InsnNode constInsn = ((InsnWrapArg) indexArg).getWrapInsn();
index = getIntConst(InsnUtils.getConstValueByInsn(mth.dex(), constInsn));
}
return index == putIndex;
}
private static int getIntConst(Object value) {
if (value instanceof Integer) {
return (Integer) value;
}
if (value instanceof Long) {
return ((Long) value).intValue();
} }
return filledArr; return -1;
} }
private static void processEnumSwitch(MethodNode mth, SwitchNode insn) { private static void processEnumSwitch(MethodNode mth, SwitchNode insn) {
......
...@@ -44,6 +44,10 @@ public class InstructionRemover { ...@@ -44,6 +44,10 @@ public class InstructionRemover {
public void add(InsnNode insn) { public void add(InsnNode insn) {
toRemove.add(insn); toRemove.add(insn);
} }
public void addAndUnbind(MethodNode mth, InsnNode insn) {
toRemove.add(insn);
unbindInsn(mth, insn);
}
public void perform() { public void perform() {
if (toRemove.isEmpty()) { if (toRemove.isEmpty()) {
...@@ -65,7 +69,7 @@ public class InstructionRemover { ...@@ -65,7 +69,7 @@ public class InstructionRemover {
} }
} }
unbindResult(mth, insn); unbindResult(mth, insn);
insn.add(AFlag.INCONSISTENT_CODE); insn.add(AFlag.REMOVE);
} }
public static void fixUsedInPhiFlag(RegisterArg useReg) { public static void fixUsedInPhiFlag(RegisterArg useReg) {
......
...@@ -11,6 +11,8 @@ import java.util.List; ...@@ -11,6 +11,8 @@ import java.util.List;
import java.util.Map; import java.util.Map;
import java.util.function.Function; import java.util.function.Function;
import org.jetbrains.annotations.Nullable;
import jadx.api.JadxDecompiler; import jadx.api.JadxDecompiler;
import jadx.core.codegen.CodeWriter; import jadx.core.codegen.CodeWriter;
...@@ -174,4 +176,12 @@ public class Utils { ...@@ -174,4 +176,12 @@ public class Utils {
} }
return Collections.unmodifiableMap(result); return Collections.unmodifiableMap(result);
} }
@Nullable
public static <T> T last(List<T> list) {
if (list.isEmpty()) {
return null;
}
return list.get(list.size() - 1);
}
} }
package jadx.tests.integration.arrays;
import org.junit.Test;
import jadx.core.dex.nodes.ClassNode;
import jadx.tests.api.IntegrationTest;
import static org.hamcrest.CoreMatchers.containsString;
import static org.junit.Assert.assertThat;
public class TestMultiDimArrayFill extends IntegrationTest {
public static class TestCls {
public static Obj test(int a, int b) {
return new Obj(
new int[][]{
new int[]{1},
new int[]{2},
{3},
new int[]{4, 5},
new int[0]
},
new int[]{a, a, a, a, b}
);
}
private static class Obj {
public Obj(int[][] ints, int[] ints2) {}
}
}
@Test
public void test() {
ClassNode cls = getClassNode(TestCls.class);
String code = cls.getCode().toString();
assertThat(code, containsString("return new Obj("
+ "new int[][]{new int[]{1}, new int[]{2}, new int[]{3}, new int[]{4, 5}, new int[0]}, "
+ "new int[]{a, a, a, a, b});"));
}
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment