Commit a530371b authored by Skylot

fix: improve StringBuilder elimination (#704)

parent 0c5a83c0
...@@ -31,14 +31,13 @@ public class InsnNode extends LineAttrNode { ...@@ -31,14 +31,13 @@ public class InsnNode extends LineAttrNode {
protected int offset; protected int offset;
public InsnNode(InsnType type, int argsCount) { public InsnNode(InsnType type, int argsCount) {
this(type, argsCount == 0 ? Collections.emptyList() : new ArrayList<>(argsCount));
}
public InsnNode(InsnType type, List<InsnArg> args) {
this.insnType = type; this.insnType = type;
this.arguments = args;
this.offset = -1; this.offset = -1;
if (argsCount == 0) {
this.arguments = Collections.emptyList();
} else {
this.arguments = new ArrayList<>(argsCount);
}
} }
public static InsnNode wrapArg(InsnArg arg) { public static InsnNode wrapArg(InsnArg arg) {
......
...@@ -15,7 +15,6 @@ import jadx.core.dex.info.FieldInfo; ...@@ -15,7 +15,6 @@ import jadx.core.dex.info.FieldInfo;
import jadx.core.dex.info.MethodInfo; import jadx.core.dex.info.MethodInfo;
import jadx.core.dex.instructions.ArithNode; import jadx.core.dex.instructions.ArithNode;
import jadx.core.dex.instructions.ArithOp; import jadx.core.dex.instructions.ArithOp;
import jadx.core.dex.instructions.CallMthInterface;
import jadx.core.dex.instructions.ConstStringNode; import jadx.core.dex.instructions.ConstStringNode;
import jadx.core.dex.instructions.FilledNewArrayNode; import jadx.core.dex.instructions.FilledNewArrayNode;
import jadx.core.dex.instructions.IfNode; import jadx.core.dex.instructions.IfNode;
...@@ -28,6 +27,8 @@ import jadx.core.dex.instructions.args.FieldArg; ...@@ -28,6 +27,8 @@ import jadx.core.dex.instructions.args.FieldArg;
import jadx.core.dex.instructions.args.InsnArg; import jadx.core.dex.instructions.args.InsnArg;
import jadx.core.dex.instructions.args.InsnWrapArg; import jadx.core.dex.instructions.args.InsnWrapArg;
import jadx.core.dex.instructions.args.LiteralArg; import jadx.core.dex.instructions.args.LiteralArg;
import jadx.core.dex.instructions.args.RegisterArg;
import jadx.core.dex.instructions.args.SSAVar;
import jadx.core.dex.instructions.mods.ConstructorInsn; import jadx.core.dex.instructions.mods.ConstructorInsn;
import jadx.core.dex.instructions.mods.TernaryInsn; import jadx.core.dex.instructions.mods.TernaryInsn;
import jadx.core.dex.nodes.BlockNode; import jadx.core.dex.nodes.BlockNode;
...@@ -35,6 +36,9 @@ import jadx.core.dex.nodes.InsnNode; ...@@ -35,6 +36,9 @@ import jadx.core.dex.nodes.InsnNode;
import jadx.core.dex.nodes.MethodNode; import jadx.core.dex.nodes.MethodNode;
import jadx.core.dex.nodes.RootNode; import jadx.core.dex.nodes.RootNode;
import jadx.core.dex.regions.conditions.IfCondition; import jadx.core.dex.regions.conditions.IfCondition;
import jadx.core.utils.BlockUtils;
import jadx.core.utils.InsnList;
import jadx.core.utils.InsnRemover;
public class SimplifyVisitor extends AbstractVisitor { public class SimplifyVisitor extends AbstractVisitor {
...@@ -232,77 +236,136 @@ public class SimplifyVisitor extends AbstractVisitor { ...@@ -232,77 +236,136 @@ public class SimplifyVisitor extends AbstractVisitor {
* Those chains are usually automatically generated by the Java compiler when you create String * Those chains are usually automatically generated by the Java compiler when you create String
* concatenations like <code>"text " + 1 + " text"</code>. * concatenations like <code>"text " + 1 + " text"</code>.
*/ */
@SuppressWarnings("InnerAssignment") // TODO
private static InsnNode convertInvoke(MethodNode mth, InvokeNode insn) { private static InsnNode convertInvoke(MethodNode mth, InvokeNode insn) {
MethodInfo callMth = insn.getCallMth(); MethodInfo callMth = insn.getCallMth();
// If this is a 'new StringBuilder(xxx).append(yyy).append(zzz).toString(),
// convert it to STRING_CONCAT pseudo instruction.
if (callMth.getDeclClass().getFullName().equals(Consts.CLASS_STRING_BUILDER) if (callMth.getDeclClass().getFullName().equals(Consts.CLASS_STRING_BUILDER)
&& callMth.getShortId().equals(Consts.MTH_TOSTRING_SIGNATURE) && callMth.getShortId().equals(Consts.MTH_TOSTRING_SIGNATURE)) {
&& insn.getArg(0).isInsnWrap()) { InsnArg instanceArg = insn.getArg(0);
try { if (instanceArg.isInsnWrap()) {
List<InsnNode> chain = flattenInsnChain(insn); // Convert 'new StringBuilder(xxx).append(yyy).append(zzz).toString() to STRING_CONCAT insn
int constrIndex = -1; // RAF List<InsnNode> callChain = flattenInsnChainUntil(insn, InsnType.CONSTRUCTOR);
// Case where new StringBuilder() is called with NO args (the entire return convertStringBuilderChain(mth, insn, callChain);
// string is created using .append() calls: }
if (chain.size() > 1 && chain.get(0).getType() == InsnType.CONSTRUCTOR) { if (instanceArg.isRegister()) {
constrIndex = 0; // Convert 'StringBuilder sb = new StringBuilder(xxx); sb.append(yyy); String str = sb.toString();'
} else if (chain.size() > 2 && chain.get(1).getType() == InsnType.CONSTRUCTOR) { List<InsnNode> useChain = collectUseChain(mth, insn, (RegisterArg) instanceArg);
// RAF Case where the first string element is String arg to the return convertStringBuilderChain(mth, insn, useChain);
// new StringBuilder("xxx") constructor }
constrIndex = 1; }
} else if (chain.size() > 3 && chain.get(2).getType() == InsnType.CONSTRUCTOR) { return null;
// RAF Case where the first string element is String.valueOf() arg }
// to the new StringBuilder(String.valueOf(zzz)) constructor
constrIndex = 2; private static List<InsnNode> collectUseChain(MethodNode mth, InvokeNode insn, RegisterArg instanceArg) {
// Builds the list "assign instruction + every instruction using the same SSA variable"
// for the register passed as 'instanceArg', verifies that 'insn' is the last use and
// returns the list without 'insn'. An empty list is returned as soon as any check fails.
SSAVar sVar = instanceArg.getSVar();
// give up if the variable is used in a PHI or has no uses at all
if (sVar.isUsedInPhi() || sVar.getUseCount() == 0) {
return Collections.emptyList();
}
// capacity: one slot per use plus one for the assign instruction
List<InsnNode> useChain = new ArrayList<>(sVar.getUseCount() + 1);
InsnNode assignInsn = sVar.getAssign().getParentInsn();
if (assignInsn == null) {
return Collections.emptyList();
}
// the assign instruction is always element 0 of the chain
useChain.add(assignInsn);
for (RegisterArg reg : sVar.getUseList()) {
InsnNode parentInsn = reg.getParentInsn();
// a use without a parent instruction can't be placed in the chain
if (parentInsn == null) {
return Collections.emptyList();
}
useChain.add(parentInsn);
}
// 'insn' must be the very last element of the chain
// NOTE(review): presumably InsnList.getIndex returns -1 when 'insn' is not found — confirm
int toStrIdx = InsnList.getIndex(useChain, insn);
if (useChain.size() - 1 != toStrIdx) {
return Collections.emptyList();
}
// drop 'insn' itself: only the assign instruction and the other uses are returned
useChain.remove(toStrIdx);
// all insns must be in one block and sequential
BlockNode assignBlock = BlockUtils.getBlockByInsn(mth, assignInsn);
if (assignBlock == null) {
return Collections.emptyList();
}
List<InsnNode> blockInsns = assignBlock.getInstructions();
int assignIdx = InsnList.getIndex(blockInsns, assignInsn);
int chainSize = useChain.size();
// number of instructions from the assign instruction (inclusive) to the end of the block
int lastInsn = blockInsns.size() - assignIdx;
if (lastInsn < chainSize) {
return Collections.emptyList();
}
// index 0 is the assign instruction itself, so comparison starts at 1;
// each chain element must be the same object (reference comparison) as the block
// instruction at the matching offset after the assign instruction.
// 'insn' was removed from the chain above, so its own position is not compared here.
for (int i = 1; i < chainSize; i++) {
if (blockInsns.get(assignIdx + i) != useChain.get(i)) {
return Collections.emptyList();
}
}
return useChain;
}
private static InsnNode convertStringBuilderChain(MethodNode mth, InvokeNode toStrInsn, List<InsnNode> chain) {
try {
int chainSize = chain.size();
if (chainSize < 2) {
return null;
}
List<InsnArg> args = new ArrayList<>(chainSize);
InsnNode firstInsn = chain.get(0);
if (firstInsn.getType() != InsnType.CONSTRUCTOR) {
return null;
}
ConstructorInsn constrInsn = (ConstructorInsn) firstInsn;
if (constrInsn.getArgsCount() == 1) {
ArgType argType = constrInsn.getCallMth().getArgumentsTypes().get(0);
if (!argType.isObject()) {
return null;
} }
args.add(constrInsn.getArg(0));
}
for (int i = 1; i < chainSize; i++) {
InsnNode chainInsn = chain.get(i);
InsnArg arg = getArgFromAppend(chainInsn);
if (arg == null) {
return null;
}
args.add(arg);
}
InsnNode concatInsn = new InsnNode(InsnType.STR_CONCAT, args);
concatInsn.setResult(toStrInsn.getResult());
concatInsn.copyAttributesFrom(toStrInsn);
if (constrIndex != -1) { // If we found a CONSTRUCTOR, is it a StringBuilder? InsnRemover insnRemover = new InsnRemover(mth);
ConstructorInsn constr = (ConstructorInsn) chain.get(constrIndex); for (InsnNode insnNode : chain) {
if (constr.getClassType().getFullName().equals(Consts.CLASS_STRING_BUILDER)) { insnRemover.addAndUnbind(insnNode);
int len = chain.size(); }
int argInd = 1; insnRemover.perform();
InsnNode concatInsn = new InsnNode(InsnType.STR_CONCAT, len - 1);
InsnNode argInsn;
if (constrIndex > 0) { // There was an arg to the StringBuilder constr
InsnWrapArg iwa;
if (constrIndex == 2
&& (argInsn = chain.get(1)).getType() == InsnType.INVOKE
&& ((InvokeNode) argInsn).getCallMth().getName().compareTo("valueOf") == 0) {
// The argument of new StringBuilder() is a String.valueOf(chainElement0)
iwa = (InsnWrapArg) argInsn.getArg(0);
argInd = 3; // Cause for loop below to skip to after the constructor
} else {
InsnNode firstNode = chain.get(0);
if (firstNode instanceof ConstStringNode) {
ConstStringNode csn = (ConstStringNode) firstNode;
iwa = new InsnWrapArg(csn);
argInd = 2; // Cause for loop below to skip to after the constructor
} else {
return null;
}
}
concatInsn.addArg(iwa);
}
for (; argInd < len; argInd++) { // Add the .append(xxx) arg string to concat return concatInsn;
InsnNode node = chain.get(argInd); } catch (Exception e) {
MethodInfo method = ((CallMthInterface) node).getCallMth(); LOG.warn("Can't convert string concatenation: {} insn: {}", mth, toStrInsn, e);
if (!(node.getArgsCount() < 2 && method.isConstructor() || method.getName().equals("append"))) { }
// The chain contains other calls to StringBuilder methods than the constructor or append. return null;
// We can't simplify such chains, therefore we leave them as they are. }
return null;
} private static List<InsnNode> flattenInsnChainUntil(InsnNode insn, InsnType insnType) {
// process only constructor and append() calls List<InsnNode> chain = new ArrayList<>();
concatInsn.addArg(node.getArg(1)); InsnArg arg = insn.getArg(0);
} while (arg.isInsnWrap()) {
concatInsn.setResult(insn.getResult()); InsnNode wrapInsn = ((InsnWrapArg) arg).getWrapInsn();
return concatInsn; chain.add(wrapInsn);
} // end of if constructor is for StringBuilder if (wrapInsn.getType() == insnType
} // end of if we found a constructor early in the chain || wrapInsn.getArgsCount() == 0) {
} catch (Exception e) { break;
LOG.warn("Can't convert string concatenation: {} insn: {}", mth, insn, e); }
arg = wrapInsn.getArg(0);
}
Collections.reverse(chain);
return chain;
}
private static InsnArg getArgFromAppend(InsnNode chainInsn) {
if (chainInsn.getType() == InsnType.INVOKE && chainInsn.getArgsCount() == 2) {
MethodInfo callMth = ((InvokeNode) chainInsn).getCallMth();
if (callMth.getDeclClass().getFullName().equals(Consts.CLASS_STRING_BUILDER)
&& callMth.getName().equals("append")) {
return chainInsn.getArg(1);
} }
} }
return null; return null;
...@@ -392,20 +455,4 @@ public class SimplifyVisitor extends AbstractVisitor { ...@@ -392,20 +455,4 @@ public class SimplifyVisitor extends AbstractVisitor {
} }
return null; return null;
} }
/**
 * Follows the first argument of {@code insn} through nested wrapped instructions
 * and collects every wrapped instruction found on the way.
 * <p>
 * Descent stops when the current argument is not an instruction wrap, or right after
 * collecting an instruction that has no arguments.
 *
 * @param insn instruction whose first argument is the start of the chain
 * @return collected instructions ordered from the deepest wrapped one to the one
 *         directly wrapped by the first argument of {@code insn}
 */
private static List<InsnNode> flattenInsnChain(InsnNode insn) {
    List<InsnNode> nested = new ArrayList<>();
    for (InsnArg cursor = insn.getArg(0); cursor.isInsnWrap();) {
        InsnWrapArg wrapArg = (InsnWrapArg) cursor;
        InsnNode inner = wrapArg.getWrapInsn();
        nested.add(inner);
        if (inner.getArgsCount() == 0) {
            // nothing to descend into
            break;
        }
        cursor = inner.getArg(0);
    }
    // collected outermost-first, callers get innermost-first
    Collections.reverse(nested);
    return nested;
}
} }
...@@ -165,6 +165,7 @@ public class BlockUtils { ...@@ -165,6 +165,7 @@ public class BlockUtils {
return insns.get(insns.size() - 1); return insns.get(insns.size() - 1);
} }
@Nullable
public static BlockNode getBlockByInsn(MethodNode mth, InsnNode insn) { public static BlockNode getBlockByInsn(MethodNode mth, InsnNode insn) {
if (insn instanceof PhiInsn) { if (insn instanceof PhiInsn) {
return searchBlockWithPhi(mth, (PhiInsn) insn); return searchBlockWithPhi(mth, (PhiInsn) insn);
......
...@@ -4,6 +4,8 @@ import java.util.ArrayList; ...@@ -4,6 +4,8 @@ import java.util.ArrayList;
import java.util.Iterator; import java.util.Iterator;
import java.util.List; import java.util.List;
import org.jetbrains.annotations.Nullable;
import jadx.core.dex.attributes.AFlag; import jadx.core.dex.attributes.AFlag;
import jadx.core.dex.instructions.InsnType; import jadx.core.dex.instructions.InsnType;
import jadx.core.dex.instructions.args.InsnArg; import jadx.core.dex.instructions.args.InsnArg;
...@@ -22,6 +24,7 @@ public class InsnRemover { ...@@ -22,6 +24,7 @@ public class InsnRemover {
private final MethodNode mth; private final MethodNode mth;
private final List<InsnNode> toRemove; private final List<InsnNode> toRemove;
@Nullable
private List<InsnNode> instrList; private List<InsnNode> instrList;
public InsnRemover(MethodNode mth) { public InsnRemover(MethodNode mth) {
...@@ -53,7 +56,13 @@ public class InsnRemover { ...@@ -53,7 +56,13 @@ public class InsnRemover {
if (toRemove.isEmpty()) { if (toRemove.isEmpty()) {
return; return;
} }
removeAll(instrList, toRemove); if (instrList == null) {
for (InsnNode remInsn : toRemove) {
remove(mth, remInsn);
}
} else {
removeAll(instrList, toRemove);
}
toRemove.clear(); toRemove.clear();
} }
......
package jadx.tests.integration; package jadx.tests.integration.others;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
......
package jadx.tests.integration; package jadx.tests.integration.others;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
......
package jadx.tests.integration.others;
import org.junit.jupiter.api.Test;
import jadx.core.dex.nodes.ClassNode;
import jadx.tests.api.IntegrationTest;
import static org.hamcrest.CoreMatchers.containsString;
import static org.hamcrest.MatcherAssert.assertThat;
import static org.hamcrest.Matchers.is;
import static org.hamcrest.Matchers.not;
/**
 * Integration tests for StringBuilder elimination when the builder is kept in a local
 * variable and used through separate statements (not a single chained expression).
 */
public class TestStringBuilderElimination3 extends IntegrationTest {
/**
 * Fixture: the builder is created and only receives consecutive append() calls
 * before toString() is returned.
 */
public static class TestCls {
public static String test(String a) {
StringBuilder sb = new StringBuilder();
sb.append("result = ");
sb.append(a);
return sb.toString();
}
}
/**
 * Expects the decompiled code of {@link TestCls} to contain a plain string concatenation
 * and no {@code new StringBuilder()} call.
 */
@Test
public void test() {
ClassNode cls = getClassNode(TestCls.class);
String code = cls.getCode().toString();
assertThat(code, containsString("return \"result = \" + a;"));
assertThat(code, not(containsString("new StringBuilder()")));
}
/**
 * Fixture: a call to updateF(), which changes field {@code f}, sits between the
 * append() calls, so the two reads of {@code f} observe different values.
 */
public static class TestClsNegative {
private String f = "first";
public String test() {
StringBuilder sb = new StringBuilder();
sb.append("before = ");
sb.append(this.f);
// modifies 'f' between the two groups of append() calls
updateF();
sb.append(", after = ");
sb.append(this.f);
return sb.toString();
}
private void updateF() {
this.f = "second";
}
// NOTE(review): presumably invoked by the IntegrationTest harness to validate
// runtime behaviour of the compiled fixture — confirm
public void check() {
assertThat(test(), is("before = first, after = second"));
}
}
/**
 * Expects the decompiled code of {@link TestClsNegative} to keep both the
 * {@code new StringBuilder()} call and the {@code sb.toString()} return.
 */
@Test
public void testNegative() {
ClassNode cls = getClassNode(TestClsNegative.class);
String code = cls.getCode().toString();
assertThat(code, containsString("return sb.toString();"));
assertThat(code, containsString("new StringBuilder()"));
}
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment