Commit a530371b authored by Skylot

fix: improve StringBuilder elimination (#704)

parent 0c5a83c0
......@@ -31,14 +31,13 @@ public class InsnNode extends LineAttrNode {
protected int offset;
/**
 * Creates an instruction of the given type with room for {@code argsCount} arguments.
 * <p>
 * Delegates to the list-based constructor: when {@code argsCount} is zero the argument list is
 * {@link Collections#emptyList()} (an immutable list), otherwise a new {@link ArrayList}
 * created with {@code argsCount} as its initial capacity.
 *
 * @param type instruction type
 * @param argsCount expected number of arguments
 */
public InsnNode(InsnType type, int argsCount) {
this(type, argsCount == 0 ? Collections.emptyList() : new ArrayList<>(argsCount));
}
public InsnNode(InsnType type, List<InsnArg> args) {
this.insnType = type;
this.arguments = args;
this.offset = -1;
if (argsCount == 0) {
this.arguments = Collections.emptyList();
} else {
this.arguments = new ArrayList<>(argsCount);
}
}
public static InsnNode wrapArg(InsnArg arg) {
......
......@@ -15,7 +15,6 @@ import jadx.core.dex.info.FieldInfo;
import jadx.core.dex.info.MethodInfo;
import jadx.core.dex.instructions.ArithNode;
import jadx.core.dex.instructions.ArithOp;
import jadx.core.dex.instructions.CallMthInterface;
import jadx.core.dex.instructions.ConstStringNode;
import jadx.core.dex.instructions.FilledNewArrayNode;
import jadx.core.dex.instructions.IfNode;
......@@ -28,6 +27,8 @@ import jadx.core.dex.instructions.args.FieldArg;
import jadx.core.dex.instructions.args.InsnArg;
import jadx.core.dex.instructions.args.InsnWrapArg;
import jadx.core.dex.instructions.args.LiteralArg;
import jadx.core.dex.instructions.args.RegisterArg;
import jadx.core.dex.instructions.args.SSAVar;
import jadx.core.dex.instructions.mods.ConstructorInsn;
import jadx.core.dex.instructions.mods.TernaryInsn;
import jadx.core.dex.nodes.BlockNode;
......@@ -35,6 +36,9 @@ import jadx.core.dex.nodes.InsnNode;
import jadx.core.dex.nodes.MethodNode;
import jadx.core.dex.nodes.RootNode;
import jadx.core.dex.regions.conditions.IfCondition;
import jadx.core.utils.BlockUtils;
import jadx.core.utils.InsnList;
import jadx.core.utils.InsnRemover;
public class SimplifyVisitor extends AbstractVisitor {
......@@ -232,77 +236,136 @@ public class SimplifyVisitor extends AbstractVisitor {
* Those chains are usually automatically generated by the Java compiler when you create String
* concatenations like <code>"text " + 1 + " text"</code>.
*/
@SuppressWarnings("InnerAssignment") // TODO
private static InsnNode convertInvoke(MethodNode mth, InvokeNode insn) {
MethodInfo callMth = insn.getCallMth();
// If this is a 'new StringBuilder(xxx).append(yyy).append(zzz).toString(),
// convert it to STRING_CONCAT pseudo instruction.
if (callMth.getDeclClass().getFullName().equals(Consts.CLASS_STRING_BUILDER)
&& callMth.getShortId().equals(Consts.MTH_TOSTRING_SIGNATURE)
&& insn.getArg(0).isInsnWrap()) {
try {
List<InsnNode> chain = flattenInsnChain(insn);
int constrIndex = -1; // RAF
// Case where new StringBuilder() is called with NO args (the entire
// string is created using .append() calls:
if (chain.size() > 1 && chain.get(0).getType() == InsnType.CONSTRUCTOR) {
constrIndex = 0;
} else if (chain.size() > 2 && chain.get(1).getType() == InsnType.CONSTRUCTOR) {
// RAF Case where the first string element is String arg to the
// new StringBuilder("xxx") constructor
constrIndex = 1;
} else if (chain.size() > 3 && chain.get(2).getType() == InsnType.CONSTRUCTOR) {
// RAF Case where the first string element is String.valueOf() arg
// to the new StringBuilder(String.valueOf(zzz)) constructor
constrIndex = 2;
&& callMth.getShortId().equals(Consts.MTH_TOSTRING_SIGNATURE)) {
InsnArg instanceArg = insn.getArg(0);
if (instanceArg.isInsnWrap()) {
// Convert 'new StringBuilder(xxx).append(yyy).append(zzz).toString() to STRING_CONCAT insn
List<InsnNode> callChain = flattenInsnChainUntil(insn, InsnType.CONSTRUCTOR);
return convertStringBuilderChain(mth, insn, callChain);
}
if (instanceArg.isRegister()) {
// Convert 'StringBuilder sb = new StringBuilder(xxx); sb.append(yyy); String str = sb.toString();'
List<InsnNode> useChain = collectUseChain(mth, insn, (RegisterArg) instanceArg);
return convertStringBuilderChain(mth, insn, useChain);
}
}
return null;
}
/**
 * Builds the list of instructions that touch the register the given {@code toString()} call is
 * invoked on: the instruction assigning the register first, followed by the parent instruction
 * of every use, with the {@code toString()} call itself stripped from the end.
 * <p>
 * An empty list is returned when such a chain can't be established:
 * <ul>
 * <li>the variable is used in a PHI or has no uses;</li>
 * <li>the assign or any use has no parent instruction;</li>
 * <li>{@code insn} is not the last entry of the collected uses;</li>
 * <li>the block containing the assign instruction can't be found;</li>
 * <li>the collected instructions are not laid out one after another in that block,
 * starting at the assign instruction.</li>
 * </ul>
 * NOTE(review): only the assign instruction and the uses before {@code insn} are checked for
 * adjacency; the position of {@code insn} itself relative to the last use is not verified
 * here - confirm this is intended.
 *
 * @param mth method containing the instructions
 * @param insn the {@code toString()} invoke which must terminate the chain
 * @param instanceArg register the {@code toString()} is invoked on
 * @return assign instruction followed by the intermediate uses, or an empty list
 */
private static List<InsnNode> collectUseChain(MethodNode mth, InvokeNode insn, RegisterArg instanceArg) {
    SSAVar ssaVar = instanceArg.getSVar();
    if (ssaVar.isUsedInPhi()) {
        return Collections.emptyList();
    }
    int useCount = ssaVar.getUseCount();
    if (useCount == 0) {
        return Collections.emptyList();
    }
    InsnNode ctorInsn = ssaVar.getAssign().getParentInsn();
    if (ctorInsn == null) {
        return Collections.emptyList();
    }

    // assign instruction goes first, then every instruction using the register
    List<InsnNode> chain = new ArrayList<>(useCount + 1);
    chain.add(ctorInsn);
    for (RegisterArg useArg : ssaVar.getUseList()) {
        InsnNode useInsn = useArg.getParentInsn();
        if (useInsn == null) {
            return Collections.emptyList();
        }
        chain.add(useInsn);
    }

    // the toString() call has to be the final use; drop it from the chain
    int tailIdx = chain.size() - 1;
    if (InsnList.getIndex(chain, insn) != tailIdx) {
        return Collections.emptyList();
    }
    chain.remove(tailIdx);

    // remaining instructions must sit in a single block, one right after another
    BlockNode block = BlockUtils.getBlockByInsn(mth, ctorInsn);
    if (block == null) {
        return Collections.emptyList();
    }
    List<InsnNode> blockInsns = block.getInstructions();
    int startIdx = InsnList.getIndex(blockInsns, ctorInsn);
    int chainLen = chain.size();
    if (blockInsns.size() - startIdx < chainLen) {
        // not enough instructions left in the block to hold the whole chain
        return Collections.emptyList();
    }
    for (int offset = 1; offset < chainLen; offset++) {
        // identity comparison: the very same instruction object is expected at each position
        if (blockInsns.get(startIdx + offset) != chain.get(offset)) {
            return Collections.emptyList();
        }
    }
    return chain;
}
private static InsnNode convertStringBuilderChain(MethodNode mth, InvokeNode toStrInsn, List<InsnNode> chain) {
try {
int chainSize = chain.size();
if (chainSize < 2) {
return null;
}
List<InsnArg> args = new ArrayList<>(chainSize);
InsnNode firstInsn = chain.get(0);
if (firstInsn.getType() != InsnType.CONSTRUCTOR) {
return null;
}
ConstructorInsn constrInsn = (ConstructorInsn) firstInsn;
if (constrInsn.getArgsCount() == 1) {
ArgType argType = constrInsn.getCallMth().getArgumentsTypes().get(0);
if (!argType.isObject()) {
return null;
}
args.add(constrInsn.getArg(0));
}
for (int i = 1; i < chainSize; i++) {
InsnNode chainInsn = chain.get(i);
InsnArg arg = getArgFromAppend(chainInsn);
if (arg == null) {
return null;
}
args.add(arg);
}
InsnNode concatInsn = new InsnNode(InsnType.STR_CONCAT, args);
concatInsn.setResult(toStrInsn.getResult());
concatInsn.copyAttributesFrom(toStrInsn);
if (constrIndex != -1) { // If we found a CONSTRUCTOR, is it a StringBuilder?
ConstructorInsn constr = (ConstructorInsn) chain.get(constrIndex);
if (constr.getClassType().getFullName().equals(Consts.CLASS_STRING_BUILDER)) {
int len = chain.size();
int argInd = 1;
InsnNode concatInsn = new InsnNode(InsnType.STR_CONCAT, len - 1);
InsnNode argInsn;
if (constrIndex > 0) { // There was an arg to the StringBuilder constr
InsnWrapArg iwa;
if (constrIndex == 2
&& (argInsn = chain.get(1)).getType() == InsnType.INVOKE
&& ((InvokeNode) argInsn).getCallMth().getName().compareTo("valueOf") == 0) {
// The argument of new StringBuilder() is a String.valueOf(chainElement0)
iwa = (InsnWrapArg) argInsn.getArg(0);
argInd = 3; // Cause for loop below to skip to after the constructor
} else {
InsnNode firstNode = chain.get(0);
if (firstNode instanceof ConstStringNode) {
ConstStringNode csn = (ConstStringNode) firstNode;
iwa = new InsnWrapArg(csn);
argInd = 2; // Cause for loop below to skip to after the constructor
} else {
return null;
}
}
concatInsn.addArg(iwa);
}
InsnRemover insnRemover = new InsnRemover(mth);
for (InsnNode insnNode : chain) {
insnRemover.addAndUnbind(insnNode);
}
insnRemover.perform();
for (; argInd < len; argInd++) { // Add the .append(xxx) arg string to concat
InsnNode node = chain.get(argInd);
MethodInfo method = ((CallMthInterface) node).getCallMth();
if (!(node.getArgsCount() < 2 && method.isConstructor() || method.getName().equals("append"))) {
// The chain contains other calls to StringBuilder methods than the constructor or append.
// We can't simplify such chains, therefore we leave them as they are.
return null;
}
// process only constructor and append() calls
concatInsn.addArg(node.getArg(1));
}
concatInsn.setResult(insn.getResult());
return concatInsn;
} // end of if constructor is for StringBuilder
} // end of if we found a constructor early in the chain
} catch (Exception e) {
LOG.warn("Can't convert string concatenation: {} insn: {}", mth, insn, e);
return concatInsn;
} catch (Exception e) {
LOG.warn("Can't convert string concatenation: {} insn: {}", mth, toStrInsn, e);
}
return null;
}
/**
 * Unrolls the instructions nested in the first argument of {@code insn}: while the current
 * first argument wraps an instruction, that instruction is collected and the walk continues
 * with its own first argument.
 * <p>
 * The walk ends after collecting an instruction of type {@code insnType} or one without
 * arguments, or as soon as the current argument doesn't wrap an instruction.
 *
 * @param insn instruction whose first argument starts the chain (not included in the result)
 * @param insnType instruction type at which the walk stops (that instruction is included)
 * @return collected instructions, innermost first
 */
private static List<InsnNode> flattenInsnChainUntil(InsnNode insn, InsnType insnType) {
    List<InsnNode> collected = new ArrayList<>();
    InsnArg cursor = insn.getArg(0);
    for (;;) {
        if (!cursor.isInsnWrap()) {
            break;
        }
        InsnNode inner = ((InsnWrapArg) cursor).getWrapInsn();
        collected.add(inner);
        if (inner.getType() == insnType) {
            break;
        }
        if (inner.getArgsCount() == 0) {
            break;
        }
        cursor = inner.getArg(0);
    }
    // instructions were gathered outermost first, callers get them innermost first
    Collections.reverse(collected);
    return collected;
}
private static InsnArg getArgFromAppend(InsnNode chainInsn) {
if (chainInsn.getType() == InsnType.INVOKE && chainInsn.getArgsCount() == 2) {
MethodInfo callMth = ((InvokeNode) chainInsn).getCallMth();
if (callMth.getDeclClass().getFullName().equals(Consts.CLASS_STRING_BUILDER)
&& callMth.getName().equals("append")) {
return chainInsn.getArg(1);
}
}
return null;
......@@ -392,20 +455,4 @@ public class SimplifyVisitor extends AbstractVisitor {
}
return null;
}
/**
 * Unrolls the instructions nested in the first argument of {@code insn}: while the current
 * first argument wraps an instruction, that instruction is collected and the walk continues
 * with its own first argument. The walk ends after collecting an instruction without
 * arguments, or as soon as the current argument doesn't wrap an instruction.
 *
 * @param insn instruction whose first argument starts the chain (not included in the result)
 * @return collected instructions, innermost first
 */
private static List<InsnNode> flattenInsnChain(InsnNode insn) {
    List<InsnNode> result = new ArrayList<>();
    InsnArg current = insn.getArg(0);
    boolean descend = current.isInsnWrap();
    while (descend) {
        InsnNode nested = ((InsnWrapArg) current).getWrapInsn();
        result.add(nested);
        if (nested.getArgsCount() == 0) {
            // nothing further to follow
            descend = false;
        } else {
            current = nested.getArg(0);
            descend = current.isInsnWrap();
        }
    }
    // gathered outermost first, returned innermost first
    Collections.reverse(result);
    return result;
}
}
......@@ -165,6 +165,7 @@ public class BlockUtils {
return insns.get(insns.size() - 1);
}
@Nullable
public static BlockNode getBlockByInsn(MethodNode mth, InsnNode insn) {
if (insn instanceof PhiInsn) {
return searchBlockWithPhi(mth, (PhiInsn) insn);
......
......@@ -4,6 +4,8 @@ import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import org.jetbrains.annotations.Nullable;
import jadx.core.dex.attributes.AFlag;
import jadx.core.dex.instructions.InsnType;
import jadx.core.dex.instructions.args.InsnArg;
......@@ -22,6 +24,7 @@ public class InsnRemover {
private final MethodNode mth;
private final List<InsnNode> toRemove;
@Nullable
private List<InsnNode> instrList;
public InsnRemover(MethodNode mth) {
......@@ -53,7 +56,13 @@ public class InsnRemover {
if (toRemove.isEmpty()) {
return;
}
removeAll(instrList, toRemove);
if (instrList == null) {
for (InsnNode remInsn : toRemove) {
remove(mth, remInsn);
}
} else {
removeAll(instrList, toRemove);
}
toRemove.clear();
}
......
package jadx.tests.integration.others;
import org.junit.jupiter.api.Test;
import jadx.core.dex.nodes.ClassNode;
import jadx.tests.api.IntegrationTest;
import static org.hamcrest.CoreMatchers.containsString;
import static org.hamcrest.MatcherAssert.assertThat;
import static org.hamcrest.Matchers.is;
import static org.hamcrest.Matchers.not;
/**
 * Checks StringBuilder elimination for a builder that is kept in a local variable and filled
 * through separate {@code append} statements instead of one fluent call chain.
 * <p>
 * The assertions match the generated code as text, so the local names and string literals in
 * the sample classes below must stay exactly as written.
 */
public class TestStringBuilderElimination3 extends IntegrationTest {
/**
 * Positive sample: constructor, appends and {@code toString()} follow each other with nothing
 * in between.
 */
public static class TestCls {
public static String test(String a) {
StringBuilder sb = new StringBuilder();
sb.append("result = ");
sb.append(a);
return sb.toString();
}
}
/**
 * Expects the builder sequence of {@link TestCls} to be rendered as a plain string
 * concatenation, with no {@code new StringBuilder()} left in the generated code.
 */
@Test
public void test() {
ClassNode cls = getClassNode(TestCls.class);
String code = cls.getCode().toString();
assertThat(code, containsString("return \"result = \" + a;"));
assertThat(code, not(containsString("new StringBuilder()")));
}
/**
 * Negative sample: {@code updateF()} is called between the appends and changes the field that
 * is appended both before and after it, so the append sequence is interrupted by another call.
 */
public static class TestClsNegative {
private String f = "first";
public String test() {
StringBuilder sb = new StringBuilder();
sb.append("before = ");
sb.append(this.f);
// call placed between the appends: modifies 'f' before it is appended the second time
updateF();
sb.append(", after = ");
sb.append(this.f);
return sb.toString();
}
private void updateF() {
this.f = "second";
}
// Runtime expectation for test(): the first append sees "first", the second one "second".
// NOTE(review): presumably invoked by the IntegrationTest harness - confirm.
public void check() {
assertThat(test(), is("before = first, after = second"));
}
}
/**
 * Expects the builder in {@link TestClsNegative} to be kept as is: the generated code must
 * still contain both the {@code new StringBuilder()} call and {@code return sb.toString();}.
 */
@Test
public void testNegative() {
ClassNode cls = getClassNode(TestClsNegative.class);
String code = cls.getCode().toString();
assertThat(code, containsString("return sb.toString();"));
assertThat(code, containsString("new StringBuilder()"));
}
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment