Commit fd7d08cb authored by Skylot's avatar Skylot

feat: initial deboxing implementation (#717)

parent 3ae83594
......@@ -14,6 +14,7 @@ import jadx.api.JadxArgs;
import jadx.core.dex.visitors.ClassModifier;
import jadx.core.dex.visitors.ConstInlineVisitor;
import jadx.core.dex.visitors.ConstructorVisitor;
import jadx.core.dex.visitors.DeboxingVisitor;
import jadx.core.dex.visitors.DependencyCollector;
import jadx.core.dex.visitors.DotGraphVisitor;
import jadx.core.dex.visitors.EnumVisitor;
......@@ -87,6 +88,7 @@ public class Jadx {
passes.add(new DebugInfoApplyVisitor());
}
passes.add(new DeboxingVisitor());
passes.add(new ModVisitor());
passes.add(new CodeShrinkVisitor());
passes.add(new ReSugarCode());
......
......@@ -147,7 +147,7 @@ public class AnnotationGen {
if (val instanceof String) {
code.add(getStringUtils().unescapeString((String) val));
} else if (val instanceof Integer) {
code.add(TypeGen.formatInteger((Integer) val));
code.add(TypeGen.formatInteger((Integer) val, false));
} else if (val instanceof Character) {
code.add(getStringUtils().unescapeChar((Character) val));
} else if (val instanceof Boolean) {
......@@ -157,11 +157,11 @@ public class AnnotationGen {
} else if (val instanceof Double) {
code.add(TypeGen.formatDouble((Double) val));
} else if (val instanceof Long) {
code.add(TypeGen.formatLong((Long) val));
code.add(TypeGen.formatLong((Long) val, false));
} else if (val instanceof Short) {
code.add(TypeGen.formatShort((Short) val));
code.add(TypeGen.formatShort((Short) val, false));
} else if (val instanceof Byte) {
code.add(TypeGen.formatByte((Byte) val));
code.add(TypeGen.formatByte((Byte) val, false));
} else if (val instanceof ArgType) {
classGen.useType(code, (ArgType) val);
code.add(".class");
......
......@@ -132,7 +132,7 @@ public class InsnGen {
}
private String lit(LiteralArg arg) {
return TypeGen.literalToString(arg.getLiteral(), arg.getType(), mth, fallback);
return TypeGen.literalToString(arg, mth, fallback);
}
private void instanceField(CodeWriter code, FieldInfo field, InsnArg arg) throws CodegenException {
......
......@@ -4,7 +4,9 @@ import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import jadx.core.deobf.NameMapper;
import jadx.core.dex.attributes.AFlag;
import jadx.core.dex.instructions.args.ArgType;
import jadx.core.dex.instructions.args.LiteralArg;
import jadx.core.dex.instructions.args.PrimitiveType;
import jadx.core.dex.nodes.IDexNode;
import jadx.core.utils.StringUtils;
......@@ -29,15 +31,25 @@ public class TypeGen {
}
/**
* Convert literal arg to string (preferred method)
*/
public static String literalToString(LiteralArg arg, IDexNode dexNode, boolean fallback) {
return literalToString(arg.getLiteral(), arg.getType(),
dexNode.root().getStringUtils(),
fallback,
arg.contains(AFlag.EXPLICIT_PRIMITIVE_TYPE));
}
/**
* Convert literal value to string according to value type
*
* @throws JadxRuntimeException for incorrect type or literal value
*/
public static String literalToString(long lit, ArgType type, IDexNode dexNode, boolean fallback) {
return literalToString(lit, type, dexNode.root().getStringUtils(), fallback);
return literalToString(lit, type, dexNode.root().getStringUtils(), fallback, false);
}
public static String literalToString(long lit, ArgType type, StringUtils stringUtils, boolean fallback) {
public static String literalToString(long lit, ArgType type, StringUtils stringUtils, boolean fallback, boolean cast) {
if (type == null || !type.isTypeKnown()) {
String n = Long.toString(lit);
if (fallback && Math.abs(lit) > 100) {
......@@ -65,13 +77,13 @@ public class TypeGen {
}
return stringUtils.unescapeChar(ch);
case BYTE:
return formatByte(lit);
return formatByte(lit, cast);
case SHORT:
return formatShort(lit);
return formatShort(lit, cast);
case INT:
return formatInteger(lit);
return formatInteger(lit, cast);
case LONG:
return formatLong(lit);
return formatLong(lit, cast);
case FLOAT:
return formatFloat(Float.intBitsToFloat((int) lit));
case DOUBLE:
......@@ -90,37 +102,40 @@ public class TypeGen {
}
}
public static String formatShort(long l) {
public static String formatShort(long l, boolean cast) {
if (l == Short.MAX_VALUE) {
return "Short.MAX_VALUE";
}
if (l == Short.MIN_VALUE) {
return "Short.MIN_VALUE";
}
return Long.toString(l);
String str = Long.toString(l);
return cast ? "(short) " + str : str;
}
public static String formatByte(long l) {
public static String formatByte(long l, boolean cast) {
if (l == Byte.MAX_VALUE) {
return "Byte.MAX_VALUE";
}
if (l == Byte.MIN_VALUE) {
return "Byte.MIN_VALUE";
}
return Long.toString(l);
String str = Long.toString(l);
return cast ? "(byte) " + str : str;
}
public static String formatInteger(long l) {
public static String formatInteger(long l, boolean cast) {
if (l == Integer.MAX_VALUE) {
return "Integer.MAX_VALUE";
}
if (l == Integer.MIN_VALUE) {
return "Integer.MIN_VALUE";
}
return Long.toString(l);
String str = Long.toString(l);
return cast ? "(int) " + str : str;
}
public static String formatLong(long l) {
public static String formatLong(long l, boolean cast) {
if (l == Long.MAX_VALUE) {
return "Long.MAX_VALUE";
}
......@@ -128,8 +143,8 @@ public class TypeGen {
return "Long.MIN_VALUE";
}
String str = Long.toString(l);
if (Math.abs(l) >= Integer.MAX_VALUE) {
str += 'L';
if (cast || Math.abs(l) >= Integer.MAX_VALUE) {
return str + 'L';
}
return str;
}
......
......@@ -53,5 +53,10 @@ public enum AFlag {
EXPLICIT_GENERICS,
/**
* Use constants with explicit type: cast '(byte) 1' or type letter '7L'
*/
EXPLICIT_PRIMITIVE_TYPE,
INCONSISTENT_CODE, // warning about incorrect decompilation
}
......@@ -35,13 +35,12 @@ public final class MethodInfo {
private MethodInfo(ClassInfo declClass, String name, List<ArgType> args, ArgType retType) {
this.name = name;
alias = name;
aliasFromPreset = false;
this.alias = name;
this.aliasFromPreset = false;
this.declClass = declClass;
this.args = args;
this.retType = retType;
shortId = makeSignature(true);
this.shortId = makeSignature(true);
}
public static MethodInfo externalMth(ClassInfo declClass, String name, List<ArgType> args, ArgType retType) {
......
package jadx.core.dex.instructions;
import jadx.core.dex.info.MethodInfo;
import jadx.core.dex.instructions.args.RegisterArg;
public interface CallMthInterface {
MethodInfo getCallMth();
RegisterArg getInstanceArg();
}
package jadx.core.dex.instructions;
import org.jetbrains.annotations.Nullable;
import com.android.dx.io.instructions.DecodedInstruction;
import jadx.core.dex.info.MethodInfo;
import jadx.core.dex.instructions.args.ArgType;
import jadx.core.dex.instructions.args.InsnArg;
import jadx.core.dex.instructions.args.RegisterArg;
import jadx.core.dex.nodes.InsnNode;
import jadx.core.utils.InsnUtils;
import jadx.core.utils.Utils;
......@@ -52,6 +55,18 @@ public class InvokeNode extends InsnNode implements CallMthInterface {
}
@Override
@Nullable
public RegisterArg getInstanceArg() {
if (type != InvokeType.STATIC && getArgsCount() > 0) {
InsnArg firstArg = getArg(0);
if (firstArg.isRegister()) {
return ((RegisterArg) firstArg);
}
}
return null;
}
@Override
public InsnNode copy() {
return copyCommonParams(new InvokeNode(mth, type, getArgsCount()));
}
......
......@@ -77,7 +77,7 @@ public final class LiteralArg extends InsnArg {
@Override
public String toString() {
try {
String value = TypeGen.literalToString(literal, getType(), DEF_STRING_UTILS, true);
String value = TypeGen.literalToString(literal, getType(), DEF_STRING_UTILS, true, false);
if (getType().equals(ArgType.BOOLEAN) && (value.equals("true") || value.equals("false"))) {
return value;
}
......
......@@ -55,8 +55,7 @@ public class RegisterArg extends InsnArg implements Named {
if (sVar != null) {
return sVar.getTypeInfo().getType();
}
LOG.warn("Register type unknown, SSA variable not initialized: r{}", regNum);
return type;
return ArgType.UNKNOWN;
}
public ArgType getInitType() {
......
......@@ -63,6 +63,7 @@ public class ConstructorInsn extends InsnNode implements CallMthInterface {
return callMth;
}
@Override
public RegisterArg getInstanceArg() {
return instanceArg;
}
......
......@@ -40,6 +40,10 @@ public class ConstInlineVisitor extends AbstractVisitor {
if (mth.isNoCode()) {
return;
}
process(mth);
}
public static void process(MethodNode mth) {
List<InsnNode> toRemove = new ArrayList<>();
for (BlockNode block : mth.getBasicBlocks()) {
toRemove.clear();
......@@ -175,17 +179,19 @@ public class ConstInlineVisitor extends AbstractVisitor {
if (constArg.isLiteral()) {
long literal = ((LiteralArg) constArg).getLiteral();
ArgType argType = arg.getInitType();
ArgType argType = arg.getType();
if (argType == ArgType.UNKNOWN) {
argType = arg.getInitType();
}
if (argType.isObject() && literal != 0) {
argType = ArgType.NARROW_NUMBERS;
}
LiteralArg litArg = InsnArg.lit(literal, argType);
litArg.copyAttributesFrom(constArg);
if (!useInsn.replaceArg(arg, litArg)) {
return false;
}
// arg replaced, made some optimizations
litArg.setType(arg.getInitType());
FieldNode fieldNode = null;
ArgType litArgType = litArg.getType();
if (litArgType.isTypeKnown()) {
......
package jadx.core.dex.visitors;
import java.util.Collections;
import java.util.HashSet;
import java.util.List;
import java.util.Set;
import jadx.core.dex.attributes.AFlag;
import jadx.core.dex.info.ClassInfo;
import jadx.core.dex.info.MethodInfo;
import jadx.core.dex.instructions.InsnType;
import jadx.core.dex.instructions.InvokeNode;
import jadx.core.dex.instructions.InvokeType;
import jadx.core.dex.instructions.args.ArgType;
import jadx.core.dex.instructions.args.InsnArg;
import jadx.core.dex.instructions.args.RegisterArg;
import jadx.core.dex.nodes.BlockNode;
import jadx.core.dex.nodes.InsnNode;
import jadx.core.dex.nodes.MethodNode;
import jadx.core.dex.nodes.RootNode;
import jadx.core.dex.visitors.regions.variables.ProcessVariables;
import jadx.core.dex.visitors.shrink.CodeShrinkVisitor;
import jadx.core.utils.BlockUtils;
import jadx.core.utils.exceptions.JadxException;
/**
* Remove primitives boxing
* i.e convert 'Integer.valueOf(1)' to '1'
*/
@JadxVisitor(
name = "DeboxingVisitor",
desc = "Remove primitives boxing",
runBefore = {
CodeShrinkVisitor.class,
ProcessVariables.class
}
)
public class DeboxingVisitor extends AbstractVisitor {
private Set<MethodInfo> valueOfMths;
@Override
public void init(RootNode root) {
valueOfMths = new HashSet<>();
valueOfMths.add(valueOfMth(root, ArgType.INT, "java.lang.Integer"));
valueOfMths.add(valueOfMth(root, ArgType.BOOLEAN, "java.lang.Boolean"));
valueOfMths.add(valueOfMth(root, ArgType.BYTE, "java.lang.Byte"));
valueOfMths.add(valueOfMth(root, ArgType.SHORT, "java.lang.Short"));
valueOfMths.add(valueOfMth(root, ArgType.CHAR, "java.lang.Character"));
valueOfMths.add(valueOfMth(root, ArgType.LONG, "java.lang.Long"));
}
private static MethodInfo valueOfMth(RootNode root, ArgType argType, String clsName) {
ArgType boxType = ArgType.object(clsName);
ClassInfo boxCls = ClassInfo.fromType(root, boxType);
return MethodInfo.externalMth(boxCls, "valueOf", Collections.singletonList(argType), boxType);
}
@Override
public void visit(MethodNode mth) throws JadxException {
if (mth.isNoCode()) {
return;
}
boolean replaced = false;
for (BlockNode blockNode : mth.getBasicBlocks()) {
List<InsnNode> insnList = blockNode.getInstructions();
int count = insnList.size();
for (int i = 0; i < count; i++) {
InsnNode insnNode = insnList.get(i);
if (insnNode.getType() == InsnType.INVOKE) {
InsnNode replaceInsn = checkForReplace(((InvokeNode) insnNode));
if (replaceInsn != null) {
BlockUtils.replaceInsn(blockNode, i, replaceInsn);
replaced = true;
}
}
}
}
if (replaced) {
ConstInlineVisitor.process(mth);
}
}
private InsnNode checkForReplace(InvokeNode insnNode) {
if (insnNode.getInvokeType() != InvokeType.STATIC
|| insnNode.getResult() == null) {
return null;
}
MethodInfo callMth = insnNode.getCallMth();
if (valueOfMths.contains(callMth)) {
RegisterArg resArg = insnNode.getResult();
InsnArg arg = insnNode.getArg(0);
if (arg.isLiteral() && checkArgUsage(resArg)) {
ArgType primitiveType = callMth.getArgumentsTypes().get(0);
ArgType boxType = callMth.getReturnType();
if (isNeedExplicitCast(resArg, primitiveType, boxType)) {
arg.add(AFlag.EXPLICIT_PRIMITIVE_TYPE);
}
resArg.setType(primitiveType);
arg.setType(primitiveType);
InsnNode constInsn = new InsnNode(InsnType.CONST, 1);
constInsn.addArg(arg);
constInsn.setResult(resArg);
return constInsn;
}
}
return null;
}
private boolean isNeedExplicitCast(RegisterArg resArg, ArgType primitiveType, ArgType boxType) {
if (primitiveType == ArgType.LONG) {
return true;
}
if (primitiveType != ArgType.INT) {
Set<ArgType> useTypes = collectUseTypes(resArg);
useTypes.add(resArg.getType());
useTypes.remove(boxType);
useTypes.remove(primitiveType);
return !useTypes.isEmpty();
}
return false;
}
private boolean checkArgUsage(RegisterArg arg) {
for (RegisterArg useArg : arg.getSVar().getUseList()) {
InsnNode parentInsn = useArg.getParentInsn();
if (parentInsn == null) {
return false;
}
if (parentInsn.getType() == InsnType.INVOKE) {
InvokeNode invokeNode = (InvokeNode) parentInsn;
if (useArg.equals(invokeNode.getInstanceArg())) {
return false;
}
}
}
return true;
}
private Set<ArgType> collectUseTypes(RegisterArg arg) {
Set<ArgType> types = new HashSet<>();
for (RegisterArg useArg : arg.getSVar().getUseList()) {
types.add(useArg.getType());
types.add(useArg.getInitType());
}
return types;
}
}
......@@ -37,7 +37,7 @@ public class TestVarArg extends IntegrationTest {
assertThat(code, containsString("void test2(int i, Object... a) {"));
assertThat(code, containsString("test1(1, 2);"));
assertThat(code, containsString("test2(3, \"1\", Integer.valueOf(7));"));
assertThat(code, containsString("test2(3, \"1\", 7);"));
// negative case
assertThat(code, containsString("void test3(int[] a) {"));
......
package jadx.tests.integration.others;
import org.junit.jupiter.api.Test;
import jadx.core.dex.nodes.ClassNode;
import jadx.tests.api.IntegrationTest;
import static jadx.tests.api.utils.JadxMatchers.containsOne;
import static jadx.tests.api.utils.JadxMatchers.countString;
import static org.hamcrest.MatcherAssert.assertThat;
import static org.hamcrest.Matchers.is;
public class TestDeboxing extends IntegrationTest {
public static class TestCls {
public Object testInt() {
return 1;
}
public Object testBoolean() {
return true;
}
public Object testByte() {
return (byte) 2;
}
public Short testShort() {
return 3;
}
public Character testChar() {
return 'c';
}
public Long testLong() {
return 4L;
}
public void testConstInline() {
Boolean v = true;
use(v);
use(v);
}
private void use(Boolean v) {
}
public void check() {
// don't mind weird comparisons
// need to get primitive without using boxing or literal
// otherwise will get same result after decompilation
assertThat(testInt(), is(Integer.sum(0, 1)));
assertThat(testBoolean(), is(Boolean.TRUE));
assertThat(testByte(), is(Byte.parseByte("2")));
assertThat(testShort(), is(Short.parseShort("3")));
assertThat(testChar(), is("c".charAt(0)));
assertThat(testLong(), is(Long.valueOf("4")));
testConstInline();
}
}
@Test
public void test() {
noDebugInfo();
ClassNode cls = getClassNode(TestCls.class);
String code = cls.getCode().toString();
assertThat(code, containsOne("return 1;"));
assertThat(code, containsOne("return true;"));
assertThat(code, containsOne("return (byte) 2;"));
assertThat(code, containsOne("return 3;"));
assertThat(code, containsOne("return 'c';"));
assertThat(code, containsOne("return 4L;"));
assertThat(code, countString(2, "use(true);"));
}
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment