Commit fd7d08cb authored by Skylot's avatar Skylot

feat: initial deboxing implementation (#717)

parent 3ae83594
...@@ -14,6 +14,7 @@ import jadx.api.JadxArgs; ...@@ -14,6 +14,7 @@ import jadx.api.JadxArgs;
import jadx.core.dex.visitors.ClassModifier; import jadx.core.dex.visitors.ClassModifier;
import jadx.core.dex.visitors.ConstInlineVisitor; import jadx.core.dex.visitors.ConstInlineVisitor;
import jadx.core.dex.visitors.ConstructorVisitor; import jadx.core.dex.visitors.ConstructorVisitor;
import jadx.core.dex.visitors.DeboxingVisitor;
import jadx.core.dex.visitors.DependencyCollector; import jadx.core.dex.visitors.DependencyCollector;
import jadx.core.dex.visitors.DotGraphVisitor; import jadx.core.dex.visitors.DotGraphVisitor;
import jadx.core.dex.visitors.EnumVisitor; import jadx.core.dex.visitors.EnumVisitor;
...@@ -87,6 +88,7 @@ public class Jadx { ...@@ -87,6 +88,7 @@ public class Jadx {
passes.add(new DebugInfoApplyVisitor()); passes.add(new DebugInfoApplyVisitor());
} }
passes.add(new DeboxingVisitor());
passes.add(new ModVisitor()); passes.add(new ModVisitor());
passes.add(new CodeShrinkVisitor()); passes.add(new CodeShrinkVisitor());
passes.add(new ReSugarCode()); passes.add(new ReSugarCode());
......
...@@ -147,7 +147,7 @@ public class AnnotationGen { ...@@ -147,7 +147,7 @@ public class AnnotationGen {
if (val instanceof String) { if (val instanceof String) {
code.add(getStringUtils().unescapeString((String) val)); code.add(getStringUtils().unescapeString((String) val));
} else if (val instanceof Integer) { } else if (val instanceof Integer) {
code.add(TypeGen.formatInteger((Integer) val)); code.add(TypeGen.formatInteger((Integer) val, false));
} else if (val instanceof Character) { } else if (val instanceof Character) {
code.add(getStringUtils().unescapeChar((Character) val)); code.add(getStringUtils().unescapeChar((Character) val));
} else if (val instanceof Boolean) { } else if (val instanceof Boolean) {
...@@ -157,11 +157,11 @@ public class AnnotationGen { ...@@ -157,11 +157,11 @@ public class AnnotationGen {
} else if (val instanceof Double) { } else if (val instanceof Double) {
code.add(TypeGen.formatDouble((Double) val)); code.add(TypeGen.formatDouble((Double) val));
} else if (val instanceof Long) { } else if (val instanceof Long) {
code.add(TypeGen.formatLong((Long) val)); code.add(TypeGen.formatLong((Long) val, false));
} else if (val instanceof Short) { } else if (val instanceof Short) {
code.add(TypeGen.formatShort((Short) val)); code.add(TypeGen.formatShort((Short) val, false));
} else if (val instanceof Byte) { } else if (val instanceof Byte) {
code.add(TypeGen.formatByte((Byte) val)); code.add(TypeGen.formatByte((Byte) val, false));
} else if (val instanceof ArgType) { } else if (val instanceof ArgType) {
classGen.useType(code, (ArgType) val); classGen.useType(code, (ArgType) val);
code.add(".class"); code.add(".class");
......
...@@ -132,7 +132,7 @@ public class InsnGen { ...@@ -132,7 +132,7 @@ public class InsnGen {
} }
private String lit(LiteralArg arg) { private String lit(LiteralArg arg) {
return TypeGen.literalToString(arg.getLiteral(), arg.getType(), mth, fallback); return TypeGen.literalToString(arg, mth, fallback);
} }
private void instanceField(CodeWriter code, FieldInfo field, InsnArg arg) throws CodegenException { private void instanceField(CodeWriter code, FieldInfo field, InsnArg arg) throws CodegenException {
......
...@@ -4,7 +4,9 @@ import org.slf4j.Logger; ...@@ -4,7 +4,9 @@ import org.slf4j.Logger;
import org.slf4j.LoggerFactory; import org.slf4j.LoggerFactory;
import jadx.core.deobf.NameMapper; import jadx.core.deobf.NameMapper;
import jadx.core.dex.attributes.AFlag;
import jadx.core.dex.instructions.args.ArgType; import jadx.core.dex.instructions.args.ArgType;
import jadx.core.dex.instructions.args.LiteralArg;
import jadx.core.dex.instructions.args.PrimitiveType; import jadx.core.dex.instructions.args.PrimitiveType;
import jadx.core.dex.nodes.IDexNode; import jadx.core.dex.nodes.IDexNode;
import jadx.core.utils.StringUtils; import jadx.core.utils.StringUtils;
...@@ -29,15 +31,25 @@ public class TypeGen { ...@@ -29,15 +31,25 @@ public class TypeGen {
} }
/** /**
* Convert literal arg to string (preferred method)
*/
public static String literalToString(LiteralArg arg, IDexNode dexNode, boolean fallback) {
return literalToString(arg.getLiteral(), arg.getType(),
dexNode.root().getStringUtils(),
fallback,
arg.contains(AFlag.EXPLICIT_PRIMITIVE_TYPE));
}
/**
* Convert literal value to string according to value type * Convert literal value to string according to value type
* *
* @throws JadxRuntimeException for incorrect type or literal value * @throws JadxRuntimeException for incorrect type or literal value
*/ */
public static String literalToString(long lit, ArgType type, IDexNode dexNode, boolean fallback) { public static String literalToString(long lit, ArgType type, IDexNode dexNode, boolean fallback) {
return literalToString(lit, type, dexNode.root().getStringUtils(), fallback); return literalToString(lit, type, dexNode.root().getStringUtils(), fallback, false);
} }
public static String literalToString(long lit, ArgType type, StringUtils stringUtils, boolean fallback) { public static String literalToString(long lit, ArgType type, StringUtils stringUtils, boolean fallback, boolean cast) {
if (type == null || !type.isTypeKnown()) { if (type == null || !type.isTypeKnown()) {
String n = Long.toString(lit); String n = Long.toString(lit);
if (fallback && Math.abs(lit) > 100) { if (fallback && Math.abs(lit) > 100) {
...@@ -65,13 +77,13 @@ public class TypeGen { ...@@ -65,13 +77,13 @@ public class TypeGen {
} }
return stringUtils.unescapeChar(ch); return stringUtils.unescapeChar(ch);
case BYTE: case BYTE:
return formatByte(lit); return formatByte(lit, cast);
case SHORT: case SHORT:
return formatShort(lit); return formatShort(lit, cast);
case INT: case INT:
return formatInteger(lit); return formatInteger(lit, cast);
case LONG: case LONG:
return formatLong(lit); return formatLong(lit, cast);
case FLOAT: case FLOAT:
return formatFloat(Float.intBitsToFloat((int) lit)); return formatFloat(Float.intBitsToFloat((int) lit));
case DOUBLE: case DOUBLE:
...@@ -90,37 +102,40 @@ public class TypeGen { ...@@ -90,37 +102,40 @@ public class TypeGen {
} }
} }
public static String formatShort(long l) { public static String formatShort(long l, boolean cast) {
if (l == Short.MAX_VALUE) { if (l == Short.MAX_VALUE) {
return "Short.MAX_VALUE"; return "Short.MAX_VALUE";
} }
if (l == Short.MIN_VALUE) { if (l == Short.MIN_VALUE) {
return "Short.MIN_VALUE"; return "Short.MIN_VALUE";
} }
return Long.toString(l); String str = Long.toString(l);
return cast ? "(short) " + str : str;
} }
public static String formatByte(long l) { public static String formatByte(long l, boolean cast) {
if (l == Byte.MAX_VALUE) { if (l == Byte.MAX_VALUE) {
return "Byte.MAX_VALUE"; return "Byte.MAX_VALUE";
} }
if (l == Byte.MIN_VALUE) { if (l == Byte.MIN_VALUE) {
return "Byte.MIN_VALUE"; return "Byte.MIN_VALUE";
} }
return Long.toString(l); String str = Long.toString(l);
return cast ? "(byte) " + str : str;
} }
public static String formatInteger(long l) { public static String formatInteger(long l, boolean cast) {
if (l == Integer.MAX_VALUE) { if (l == Integer.MAX_VALUE) {
return "Integer.MAX_VALUE"; return "Integer.MAX_VALUE";
} }
if (l == Integer.MIN_VALUE) { if (l == Integer.MIN_VALUE) {
return "Integer.MIN_VALUE"; return "Integer.MIN_VALUE";
} }
return Long.toString(l); String str = Long.toString(l);
return cast ? "(int) " + str : str;
} }
public static String formatLong(long l) { public static String formatLong(long l, boolean cast) {
if (l == Long.MAX_VALUE) { if (l == Long.MAX_VALUE) {
return "Long.MAX_VALUE"; return "Long.MAX_VALUE";
} }
...@@ -128,8 +143,8 @@ public class TypeGen { ...@@ -128,8 +143,8 @@ public class TypeGen {
return "Long.MIN_VALUE"; return "Long.MIN_VALUE";
} }
String str = Long.toString(l); String str = Long.toString(l);
if (Math.abs(l) >= Integer.MAX_VALUE) { if (cast || Math.abs(l) >= Integer.MAX_VALUE) {
str += 'L'; return str + 'L';
} }
return str; return str;
} }
......
...@@ -53,5 +53,10 @@ public enum AFlag { ...@@ -53,5 +53,10 @@ public enum AFlag {
EXPLICIT_GENERICS, EXPLICIT_GENERICS,
/**
* Use constants with explicit type: cast '(byte) 1' or type letter '7L'
*/
EXPLICIT_PRIMITIVE_TYPE,
INCONSISTENT_CODE, // warning about incorrect decompilation INCONSISTENT_CODE, // warning about incorrect decompilation
} }
...@@ -35,13 +35,12 @@ public final class MethodInfo { ...@@ -35,13 +35,12 @@ public final class MethodInfo {
private MethodInfo(ClassInfo declClass, String name, List<ArgType> args, ArgType retType) { private MethodInfo(ClassInfo declClass, String name, List<ArgType> args, ArgType retType) {
this.name = name; this.name = name;
alias = name; this.alias = name;
aliasFromPreset = false; this.aliasFromPreset = false;
this.declClass = declClass; this.declClass = declClass;
this.args = args; this.args = args;
this.retType = retType; this.retType = retType;
shortId = makeSignature(true); this.shortId = makeSignature(true);
} }
public static MethodInfo externalMth(ClassInfo declClass, String name, List<ArgType> args, ArgType retType) { public static MethodInfo externalMth(ClassInfo declClass, String name, List<ArgType> args, ArgType retType) {
......
package jadx.core.dex.instructions; package jadx.core.dex.instructions;
import jadx.core.dex.info.MethodInfo; import jadx.core.dex.info.MethodInfo;
import jadx.core.dex.instructions.args.RegisterArg;
public interface CallMthInterface { public interface CallMthInterface {
MethodInfo getCallMth(); MethodInfo getCallMth();
RegisterArg getInstanceArg();
} }
package jadx.core.dex.instructions; package jadx.core.dex.instructions;
import org.jetbrains.annotations.Nullable;
import com.android.dx.io.instructions.DecodedInstruction; import com.android.dx.io.instructions.DecodedInstruction;
import jadx.core.dex.info.MethodInfo; import jadx.core.dex.info.MethodInfo;
import jadx.core.dex.instructions.args.ArgType; import jadx.core.dex.instructions.args.ArgType;
import jadx.core.dex.instructions.args.InsnArg; import jadx.core.dex.instructions.args.InsnArg;
import jadx.core.dex.instructions.args.RegisterArg;
import jadx.core.dex.nodes.InsnNode; import jadx.core.dex.nodes.InsnNode;
import jadx.core.utils.InsnUtils; import jadx.core.utils.InsnUtils;
import jadx.core.utils.Utils; import jadx.core.utils.Utils;
...@@ -52,6 +55,18 @@ public class InvokeNode extends InsnNode implements CallMthInterface { ...@@ -52,6 +55,18 @@ public class InvokeNode extends InsnNode implements CallMthInterface {
} }
@Override @Override
@Nullable
public RegisterArg getInstanceArg() {
if (type != InvokeType.STATIC && getArgsCount() > 0) {
InsnArg firstArg = getArg(0);
if (firstArg.isRegister()) {
return ((RegisterArg) firstArg);
}
}
return null;
}
@Override
public InsnNode copy() { public InsnNode copy() {
return copyCommonParams(new InvokeNode(mth, type, getArgsCount())); return copyCommonParams(new InvokeNode(mth, type, getArgsCount()));
} }
......
...@@ -77,7 +77,7 @@ public final class LiteralArg extends InsnArg { ...@@ -77,7 +77,7 @@ public final class LiteralArg extends InsnArg {
@Override @Override
public String toString() { public String toString() {
try { try {
String value = TypeGen.literalToString(literal, getType(), DEF_STRING_UTILS, true); String value = TypeGen.literalToString(literal, getType(), DEF_STRING_UTILS, true, false);
if (getType().equals(ArgType.BOOLEAN) && (value.equals("true") || value.equals("false"))) { if (getType().equals(ArgType.BOOLEAN) && (value.equals("true") || value.equals("false"))) {
return value; return value;
} }
......
...@@ -55,8 +55,7 @@ public class RegisterArg extends InsnArg implements Named { ...@@ -55,8 +55,7 @@ public class RegisterArg extends InsnArg implements Named {
if (sVar != null) { if (sVar != null) {
return sVar.getTypeInfo().getType(); return sVar.getTypeInfo().getType();
} }
LOG.warn("Register type unknown, SSA variable not initialized: r{}", regNum); return ArgType.UNKNOWN;
return type;
} }
public ArgType getInitType() { public ArgType getInitType() {
......
...@@ -63,6 +63,7 @@ public class ConstructorInsn extends InsnNode implements CallMthInterface { ...@@ -63,6 +63,7 @@ public class ConstructorInsn extends InsnNode implements CallMthInterface {
return callMth; return callMth;
} }
@Override
public RegisterArg getInstanceArg() { public RegisterArg getInstanceArg() {
return instanceArg; return instanceArg;
} }
......
...@@ -40,6 +40,10 @@ public class ConstInlineVisitor extends AbstractVisitor { ...@@ -40,6 +40,10 @@ public class ConstInlineVisitor extends AbstractVisitor {
if (mth.isNoCode()) { if (mth.isNoCode()) {
return; return;
} }
process(mth);
}
public static void process(MethodNode mth) {
List<InsnNode> toRemove = new ArrayList<>(); List<InsnNode> toRemove = new ArrayList<>();
for (BlockNode block : mth.getBasicBlocks()) { for (BlockNode block : mth.getBasicBlocks()) {
toRemove.clear(); toRemove.clear();
...@@ -175,17 +179,19 @@ public class ConstInlineVisitor extends AbstractVisitor { ...@@ -175,17 +179,19 @@ public class ConstInlineVisitor extends AbstractVisitor {
if (constArg.isLiteral()) { if (constArg.isLiteral()) {
long literal = ((LiteralArg) constArg).getLiteral(); long literal = ((LiteralArg) constArg).getLiteral();
ArgType argType = arg.getInitType(); ArgType argType = arg.getType();
if (argType == ArgType.UNKNOWN) {
argType = arg.getInitType();
}
if (argType.isObject() && literal != 0) { if (argType.isObject() && literal != 0) {
argType = ArgType.NARROW_NUMBERS; argType = ArgType.NARROW_NUMBERS;
} }
LiteralArg litArg = InsnArg.lit(literal, argType); LiteralArg litArg = InsnArg.lit(literal, argType);
litArg.copyAttributesFrom(constArg);
if (!useInsn.replaceArg(arg, litArg)) { if (!useInsn.replaceArg(arg, litArg)) {
return false; return false;
} }
// arg replaced, made some optimizations // arg replaced, made some optimizations
litArg.setType(arg.getInitType());
FieldNode fieldNode = null; FieldNode fieldNode = null;
ArgType litArgType = litArg.getType(); ArgType litArgType = litArg.getType();
if (litArgType.isTypeKnown()) { if (litArgType.isTypeKnown()) {
......
package jadx.core.dex.visitors;
import java.util.Collections;
import java.util.HashSet;
import java.util.List;
import java.util.Set;
import jadx.core.dex.attributes.AFlag;
import jadx.core.dex.info.ClassInfo;
import jadx.core.dex.info.MethodInfo;
import jadx.core.dex.instructions.InsnType;
import jadx.core.dex.instructions.InvokeNode;
import jadx.core.dex.instructions.InvokeType;
import jadx.core.dex.instructions.args.ArgType;
import jadx.core.dex.instructions.args.InsnArg;
import jadx.core.dex.instructions.args.RegisterArg;
import jadx.core.dex.nodes.BlockNode;
import jadx.core.dex.nodes.InsnNode;
import jadx.core.dex.nodes.MethodNode;
import jadx.core.dex.nodes.RootNode;
import jadx.core.dex.visitors.regions.variables.ProcessVariables;
import jadx.core.dex.visitors.shrink.CodeShrinkVisitor;
import jadx.core.utils.BlockUtils;
import jadx.core.utils.exceptions.JadxException;
/**
* Remove primitives boxing
* i.e convert 'Integer.valueOf(1)' to '1'
*/
@JadxVisitor(
name = "DeboxingVisitor",
desc = "Remove primitives boxing",
runBefore = {
CodeShrinkVisitor.class,
ProcessVariables.class
}
)
public class DeboxingVisitor extends AbstractVisitor {
private Set<MethodInfo> valueOfMths;
@Override
public void init(RootNode root) {
valueOfMths = new HashSet<>();
valueOfMths.add(valueOfMth(root, ArgType.INT, "java.lang.Integer"));
valueOfMths.add(valueOfMth(root, ArgType.BOOLEAN, "java.lang.Boolean"));
valueOfMths.add(valueOfMth(root, ArgType.BYTE, "java.lang.Byte"));
valueOfMths.add(valueOfMth(root, ArgType.SHORT, "java.lang.Short"));
valueOfMths.add(valueOfMth(root, ArgType.CHAR, "java.lang.Character"));
valueOfMths.add(valueOfMth(root, ArgType.LONG, "java.lang.Long"));
}
private static MethodInfo valueOfMth(RootNode root, ArgType argType, String clsName) {
ArgType boxType = ArgType.object(clsName);
ClassInfo boxCls = ClassInfo.fromType(root, boxType);
return MethodInfo.externalMth(boxCls, "valueOf", Collections.singletonList(argType), boxType);
}
@Override
public void visit(MethodNode mth) throws JadxException {
if (mth.isNoCode()) {
return;
}
boolean replaced = false;
for (BlockNode blockNode : mth.getBasicBlocks()) {
List<InsnNode> insnList = blockNode.getInstructions();
int count = insnList.size();
for (int i = 0; i < count; i++) {
InsnNode insnNode = insnList.get(i);
if (insnNode.getType() == InsnType.INVOKE) {
InsnNode replaceInsn = checkForReplace(((InvokeNode) insnNode));
if (replaceInsn != null) {
BlockUtils.replaceInsn(blockNode, i, replaceInsn);
replaced = true;
}
}
}
}
if (replaced) {
ConstInlineVisitor.process(mth);
}
}
private InsnNode checkForReplace(InvokeNode insnNode) {
if (insnNode.getInvokeType() != InvokeType.STATIC
|| insnNode.getResult() == null) {
return null;
}
MethodInfo callMth = insnNode.getCallMth();
if (valueOfMths.contains(callMth)) {
RegisterArg resArg = insnNode.getResult();
InsnArg arg = insnNode.getArg(0);
if (arg.isLiteral() && checkArgUsage(resArg)) {
ArgType primitiveType = callMth.getArgumentsTypes().get(0);
ArgType boxType = callMth.getReturnType();
if (isNeedExplicitCast(resArg, primitiveType, boxType)) {
arg.add(AFlag.EXPLICIT_PRIMITIVE_TYPE);
}
resArg.setType(primitiveType);
arg.setType(primitiveType);
InsnNode constInsn = new InsnNode(InsnType.CONST, 1);
constInsn.addArg(arg);
constInsn.setResult(resArg);
return constInsn;
}
}
return null;
}
private boolean isNeedExplicitCast(RegisterArg resArg, ArgType primitiveType, ArgType boxType) {
if (primitiveType == ArgType.LONG) {
return true;
}
if (primitiveType != ArgType.INT) {
Set<ArgType> useTypes = collectUseTypes(resArg);
useTypes.add(resArg.getType());
useTypes.remove(boxType);
useTypes.remove(primitiveType);
return !useTypes.isEmpty();
}
return false;
}
private boolean checkArgUsage(RegisterArg arg) {
for (RegisterArg useArg : arg.getSVar().getUseList()) {
InsnNode parentInsn = useArg.getParentInsn();
if (parentInsn == null) {
return false;
}
if (parentInsn.getType() == InsnType.INVOKE) {
InvokeNode invokeNode = (InvokeNode) parentInsn;
if (useArg.equals(invokeNode.getInstanceArg())) {
return false;
}
}
}
return true;
}
private Set<ArgType> collectUseTypes(RegisterArg arg) {
Set<ArgType> types = new HashSet<>();
for (RegisterArg useArg : arg.getSVar().getUseList()) {
types.add(useArg.getType());
types.add(useArg.getInitType());
}
return types;
}
}
...@@ -37,7 +37,7 @@ public class TestVarArg extends IntegrationTest { ...@@ -37,7 +37,7 @@ public class TestVarArg extends IntegrationTest {
assertThat(code, containsString("void test2(int i, Object... a) {")); assertThat(code, containsString("void test2(int i, Object... a) {"));
assertThat(code, containsString("test1(1, 2);")); assertThat(code, containsString("test1(1, 2);"));
assertThat(code, containsString("test2(3, \"1\", Integer.valueOf(7));")); assertThat(code, containsString("test2(3, \"1\", 7);"));
// negative case // negative case
assertThat(code, containsString("void test3(int[] a) {")); assertThat(code, containsString("void test3(int[] a) {"));
......
package jadx.tests.integration.others;
import org.junit.jupiter.api.Test;
import jadx.core.dex.nodes.ClassNode;
import jadx.tests.api.IntegrationTest;
import static jadx.tests.api.utils.JadxMatchers.containsOne;
import static jadx.tests.api.utils.JadxMatchers.countString;
import static org.hamcrest.MatcherAssert.assertThat;
import static org.hamcrest.Matchers.is;
public class TestDeboxing extends IntegrationTest {
public static class TestCls {
public Object testInt() {
return 1;
}
public Object testBoolean() {
return true;
}
public Object testByte() {
return (byte) 2;
}
public Short testShort() {
return 3;
}
public Character testChar() {
return 'c';
}
public Long testLong() {
return 4L;
}
public void testConstInline() {
Boolean v = true;
use(v);
use(v);
}
private void use(Boolean v) {
}
public void check() {
// don't mind weird comparisons
// need to get primitive without using boxing or literal
// otherwise will get same result after decompilation
assertThat(testInt(), is(Integer.sum(0, 1)));
assertThat(testBoolean(), is(Boolean.TRUE));
assertThat(testByte(), is(Byte.parseByte("2")));
assertThat(testShort(), is(Short.parseShort("3")));
assertThat(testChar(), is("c".charAt(0)));
assertThat(testLong(), is(Long.valueOf("4")));
testConstInline();
}
}
@Test
public void test() {
noDebugInfo();
ClassNode cls = getClassNode(TestCls.class);
String code = cls.getCode().toString();
assertThat(code, containsOne("return 1;"));
assertThat(code, containsOne("return true;"));
assertThat(code, containsOne("return (byte) 2;"));
assertThat(code, containsOne("return 3;"));
assertThat(code, containsOne("return 'c';"));
assertThat(code, containsOne("return 4L;"));
assertThat(code, countString(2, "use(true);"));
}
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment