Commit cf79a519 authored by Skylot's avatar Skylot

refactor: move code shrink visitor to separate package and extract inner classes

parent d0699286
......@@ -11,7 +11,7 @@ import org.slf4j.LoggerFactory;
import jadx.api.JadxArgs;
import jadx.core.dex.visitors.ClassModifier;
import jadx.core.dex.visitors.CodeShrinker;
import jadx.core.dex.visitors.shrink.CodeShrinkVisitor;
import jadx.core.dex.visitors.ConstInlineVisitor;
import jadx.core.dex.visitors.ConstructorVisitor;
import jadx.core.dex.visitors.DependencyCollector;
......@@ -81,7 +81,7 @@ public class Jadx {
passes.add(new DebugInfoApplyVisitor());
passes.add(new ModVisitor());
passes.add(new CodeShrinker());
passes.add(new CodeShrinkVisitor());
passes.add(new ReSugarCode());
if (args.isCfgOutput()) {
passes.add(DotGraphVisitor.dump());
......@@ -92,7 +92,7 @@ public class Jadx {
passes.add(new ReturnVisitor());
passes.add(new CleanRegions());
passes.add(new CodeShrinker());
passes.add(new CodeShrinkVisitor());
passes.add(new SimplifyVisitor());
passes.add(new CheckRegions());
......
......@@ -26,6 +26,7 @@ import jadx.core.dex.nodes.DexNode;
import jadx.core.dex.nodes.FieldNode;
import jadx.core.dex.nodes.InsnNode;
import jadx.core.dex.nodes.MethodNode;
import jadx.core.dex.visitors.shrink.CodeShrinkVisitor;
import jadx.core.utils.ErrorsCounter;
import jadx.core.utils.InsnUtils;
import jadx.core.utils.exceptions.JadxException;
......@@ -33,7 +34,7 @@ import jadx.core.utils.exceptions.JadxException;
@JadxVisitor(
name = "EnumVisitor",
desc = "Restore enum classes",
runAfter = {CodeShrinker.class, ModVisitor.class}
runAfter = {CodeShrinkVisitor.class, ModVisitor.class}
)
public class EnumVisitor extends AbstractVisitor {
......
......@@ -21,6 +21,7 @@ import jadx.core.dex.nodes.BlockNode;
import jadx.core.dex.nodes.FieldNode;
import jadx.core.dex.nodes.InsnNode;
import jadx.core.dex.nodes.MethodNode;
import jadx.core.dex.visitors.shrink.CodeShrinkVisitor;
import jadx.core.utils.exceptions.JadxException;
@JadxVisitor(
......@@ -77,7 +78,7 @@ public class MethodInlineVisitor extends AbstractVisitor {
&& get.getResult().equalRegisterAndType((RegisterArg) retArg)) {
RegisterArg retReg = (RegisterArg) retArg;
retReg.getSVar().removeUse(retReg);
CodeShrinker.shrinkMethod(mth);
CodeShrinkVisitor.shrinkMethod(mth);
insnList = firstBlock.getInstructions();
if (insnList.size() == 1) {
......
......@@ -37,6 +37,7 @@ import jadx.core.dex.nodes.InsnNode;
import jadx.core.dex.nodes.MethodNode;
import jadx.core.dex.trycatch.ExcHandlerAttr;
import jadx.core.dex.trycatch.ExceptionHandler;
import jadx.core.dex.visitors.shrink.CodeShrinkVisitor;
import jadx.core.utils.ErrorsCounter;
import jadx.core.utils.InsnUtils;
import jadx.core.utils.InstructionRemover;
......@@ -51,7 +52,7 @@ import static jadx.core.utils.BlockUtils.replaceInsn;
@JadxVisitor(
name = "ModVisitor",
desc = "Modify method instructions",
runBefore = CodeShrinker.class
runBefore = CodeShrinkVisitor.class
)
public class ModVisitor extends AbstractVisitor {
private static final Logger LOG = LoggerFactory.getLogger(ModVisitor.class);
......
......@@ -15,6 +15,7 @@ import jadx.core.dex.nodes.BlockNode;
import jadx.core.dex.nodes.InsnNode;
import jadx.core.dex.nodes.MethodNode;
import jadx.core.dex.visitors.regions.variables.ProcessVariables;
import jadx.core.dex.visitors.shrink.CodeShrinkVisitor;
import jadx.core.utils.exceptions.JadxException;
/**
......@@ -25,7 +26,7 @@ import jadx.core.utils.exceptions.JadxException;
@JadxVisitor(
name = "PrepareForCodeGen",
desc = "Prepare instructions for code generation pass",
runAfter = {CodeShrinker.class, ClassModifier.class, ProcessVariables.class}
runAfter = {CodeShrinkVisitor.class, ClassModifier.class, ProcessVariables.class}
)
public class PrepareForCodeGen extends AbstractVisitor {
......
......@@ -30,6 +30,7 @@ import jadx.core.dex.nodes.DexNode;
import jadx.core.dex.nodes.FieldNode;
import jadx.core.dex.nodes.InsnNode;
import jadx.core.dex.nodes.MethodNode;
import jadx.core.dex.visitors.shrink.CodeShrinkVisitor;
import jadx.core.utils.InsnList;
import jadx.core.utils.InsnUtils;
import jadx.core.utils.InstructionRemover;
......@@ -39,7 +40,7 @@ import jadx.core.utils.exceptions.JadxException;
@JadxVisitor(
name = "ReSugarCode",
desc = "Simplify synthetic or verbose code",
runAfter = CodeShrinker.class
runAfter = CodeShrinkVisitor.class
)
public class ReSugarCode extends AbstractVisitor {
......
......@@ -33,7 +33,7 @@ import jadx.core.dex.regions.loops.ForLoop;
import jadx.core.dex.regions.loops.LoopRegion;
import jadx.core.dex.regions.loops.LoopType;
import jadx.core.dex.visitors.AbstractVisitor;
import jadx.core.dex.visitors.CodeShrinker;
import jadx.core.dex.visitors.shrink.CodeShrinkVisitor;
import jadx.core.dex.visitors.JadxVisitor;
import jadx.core.dex.visitors.regions.variables.ProcessVariables;
import jadx.core.utils.BlockUtils;
......@@ -206,7 +206,7 @@ public class LoopRegionVisitor extends AbstractVisitor implements IRegionVisitor
if (arrayArg.isRegister()) {
((RegisterArg) arrayArg).getSVar().removeUse((RegisterArg) arrGetInsn.getArg(0));
}
CodeShrinker.shrinkMethod(mth);
CodeShrinkVisitor.shrinkMethod(mth);
len.add(AFlag.DONT_GENERATE);
if (arrGetInsn.contains(AFlag.WRAPPED)) {
......
......@@ -17,7 +17,7 @@ import jadx.core.dex.nodes.InsnNode;
import jadx.core.dex.nodes.MethodNode;
import jadx.core.dex.regions.Region;
import jadx.core.dex.regions.conditions.IfRegion;
import jadx.core.dex.visitors.CodeShrinker;
import jadx.core.dex.visitors.shrink.CodeShrinkVisitor;
import jadx.core.utils.InsnList;
public class TernaryMod {
......@@ -89,7 +89,7 @@ public class TernaryMod {
header.getInstructions().add(ternInsn);
// shrink method again
CodeShrinker.shrinkMethod(mth);
CodeShrinkVisitor.shrinkMethod(mth);
return true;
}
......@@ -120,7 +120,7 @@ public class TernaryMod {
header.getInstructions().add(retInsn);
header.add(AFlag.RETURN);
CodeShrinker.shrinkMethod(mth);
CodeShrinkVisitor.shrinkMethod(mth);
return true;
}
return false;
......
package jadx.core.dex.visitors.shrink;
import java.util.BitSet;
import java.util.LinkedList;
import java.util.List;
import jadx.core.dex.instructions.InsnType;
import jadx.core.dex.instructions.args.InsnArg;
import jadx.core.dex.instructions.args.InsnWrapArg;
import jadx.core.dex.instructions.args.RegisterArg;
import jadx.core.dex.instructions.mods.ConstructorInsn;
import jadx.core.dex.instructions.mods.TernaryInsn;
import jadx.core.dex.nodes.InsnNode;
import jadx.core.utils.EmptyBitSet;
import jadx.core.utils.exceptions.JadxRuntimeException;
final class ArgsInfo {
private final InsnNode insn;
private final List<ArgsInfo> argsList;
private final List<RegisterArg> args;
private final int pos;
private int inlineBorder;
private ArgsInfo inlinedInsn;
public ArgsInfo(InsnNode insn, List<ArgsInfo> argsList, int pos) {
this.insn = insn;
this.argsList = argsList;
this.pos = pos;
this.inlineBorder = pos;
this.args = getArgs(insn);
}
public static List<RegisterArg> getArgs(InsnNode insn) {
List<RegisterArg> args = new LinkedList<>();
addArgs(insn, args);
return args;
}
private static void addArgs(InsnNode insn, List<RegisterArg> args) {
if (insn.getType() == InsnType.CONSTRUCTOR) {
args.add(((ConstructorInsn) insn).getInstanceArg());
} else if (insn.getType() == InsnType.TERNARY) {
args.addAll(((TernaryInsn) insn).getCondition().getRegisterArgs());
}
for (InsnArg arg : insn.getArguments()) {
if (arg.isRegister()) {
args.add((RegisterArg) arg);
}
}
for (InsnArg arg : insn.getArguments()) {
if (arg.isInsnWrap()) {
addArgs(((InsnWrapArg) arg).getWrapInsn(), args);
}
}
}
public InsnNode getInsn() {
return insn;
}
List<RegisterArg> getArgs() {
return args;
}
public WrapInfo checkInline(int assignPos, RegisterArg arg) {
if (assignPos >= inlineBorder || !canMove(assignPos, inlineBorder)) {
return null;
}
inlineBorder = assignPos;
return inline(assignPos, arg);
}
private boolean canMove(int from, int to) {
ArgsInfo startInfo = argsList.get(from);
List<RegisterArg> movedArgs = startInfo.getArgs();
int start = from + 1;
if (start == to) {
// previous instruction or on edge of inline border
return true;
}
if (start > to) {
throw new JadxRuntimeException("Invalid inline insn positions: " + start + " - " + to);
}
BitSet movedSet;
if (movedArgs.isEmpty()) {
if (startInfo.insn.isConstInsn()) {
return true;
}
movedSet = EmptyBitSet.EMPTY;
} else {
movedSet = new BitSet();
for (RegisterArg arg : movedArgs) {
movedSet.set(arg.getRegNum());
}
}
for (int i = start; i < to; i++) {
ArgsInfo argsInfo = argsList.get(i);
if (argsInfo.getInlinedInsn() == this) {
continue;
}
InsnNode curInsn = argsInfo.insn;
if (!curInsn.canReorder() || usedArgAssign(curInsn, movedSet)) {
return false;
}
}
return true;
}
static boolean usedArgAssign(InsnNode insn, BitSet args) {
if (args.isEmpty()) {
return false;
}
RegisterArg result = insn.getResult();
if (result == null || result.isField()) {
return false;
}
return args.get(result.getRegNum());
}
WrapInfo inline(int assignInsnPos, RegisterArg arg) {
ArgsInfo argsInfo = argsList.get(assignInsnPos);
argsInfo.inlinedInsn = this;
return new WrapInfo(argsInfo.insn, arg);
}
ArgsInfo getInlinedInsn() {
if (inlinedInsn != null) {
ArgsInfo parent = inlinedInsn.getInlinedInsn();
if (parent != null) {
inlinedInsn = parent;
}
}
return inlinedInsn;
}
@Override
public String toString() {
return "ArgsInfo: |" + inlineBorder
+ " ->" + (inlinedInsn == null ? "-" : inlinedInsn.pos)
+ " " + args + " : " + insn;
}
}
package jadx.core.dex.visitors;
package jadx.core.dex.visitors.shrink;
import java.util.ArrayList;
import java.util.BitSet;
import java.util.LinkedList;
import java.util.List;
import java.util.ListIterator;
import java.util.Set;
......@@ -13,17 +12,22 @@ import jadx.core.dex.instructions.args.InsnArg;
import jadx.core.dex.instructions.args.InsnWrapArg;
import jadx.core.dex.instructions.args.RegisterArg;
import jadx.core.dex.instructions.args.SSAVar;
import jadx.core.dex.instructions.mods.ConstructorInsn;
import jadx.core.dex.instructions.mods.TernaryInsn;
import jadx.core.dex.nodes.BlockNode;
import jadx.core.dex.nodes.InsnNode;
import jadx.core.dex.nodes.MethodNode;
import jadx.core.dex.visitors.AbstractVisitor;
import jadx.core.dex.visitors.JadxVisitor;
import jadx.core.dex.visitors.ModVisitor;
import jadx.core.utils.BlockUtils;
import jadx.core.utils.EmptyBitSet;
import jadx.core.utils.InsnList;
import jadx.core.utils.exceptions.JadxRuntimeException;
public class CodeShrinker extends AbstractVisitor {
@JadxVisitor(
name = "CodeShrinkVisitor",
desc = "Inline variables for make code smaller",
runAfter = {ModVisitor.class}
)
public class CodeShrinkVisitor extends AbstractVisitor {
@Override
public void visit(MethodNode mth) {
......@@ -40,156 +44,6 @@ public class CodeShrinker extends AbstractVisitor {
}
}
private static final class ArgsInfo {
private final InsnNode insn;
private final List<ArgsInfo> argsList;
private final List<RegisterArg> args;
private final int pos;
private int inlineBorder;
private ArgsInfo inlinedInsn;
public ArgsInfo(InsnNode insn, List<ArgsInfo> argsList, int pos) {
this.insn = insn;
this.argsList = argsList;
this.pos = pos;
this.inlineBorder = pos;
this.args = getArgs(insn);
}
public static List<RegisterArg> getArgs(InsnNode insn) {
List<RegisterArg> args = new LinkedList<>();
addArgs(insn, args);
return args;
}
private static void addArgs(InsnNode insn, List<RegisterArg> args) {
if (insn.getType() == InsnType.CONSTRUCTOR) {
args.add(((ConstructorInsn) insn).getInstanceArg());
} else if (insn.getType() == InsnType.TERNARY) {
args.addAll(((TernaryInsn) insn).getCondition().getRegisterArgs());
}
for (InsnArg arg : insn.getArguments()) {
if (arg.isRegister()) {
args.add((RegisterArg) arg);
}
}
for (InsnArg arg : insn.getArguments()) {
if (arg.isInsnWrap()) {
addArgs(((InsnWrapArg) arg).getWrapInsn(), args);
}
}
}
public InsnNode getInsn() {
return insn;
}
private List<RegisterArg> getArgs() {
return args;
}
public WrapInfo checkInline(int assignPos, RegisterArg arg) {
if (assignPos >= inlineBorder || !canMove(assignPos, inlineBorder)) {
return null;
}
inlineBorder = assignPos;
return inline(assignPos, arg);
}
private boolean canMove(int from, int to) {
ArgsInfo startInfo = argsList.get(from);
List<RegisterArg> movedArgs = startInfo.getArgs();
int start = from + 1;
if (start == to) {
// previous instruction or on edge of inline border
return true;
}
if (start > to) {
throw new JadxRuntimeException("Invalid inline insn positions: " + start + " - " + to);
}
BitSet movedSet;
if (movedArgs.isEmpty()) {
if (startInfo.insn.isConstInsn()) {
return true;
}
movedSet = EmptyBitSet.EMPTY;
} else {
movedSet = new BitSet();
for (RegisterArg arg : movedArgs) {
movedSet.set(arg.getRegNum());
}
}
for (int i = start; i < to; i++) {
ArgsInfo argsInfo = argsList.get(i);
if (argsInfo.getInlinedInsn() == this) {
continue;
}
InsnNode curInsn = argsInfo.insn;
if (!curInsn.canReorder() || usedArgAssign(curInsn, movedSet)) {
return false;
}
}
return true;
}
private static boolean usedArgAssign(InsnNode insn, BitSet args) {
if (args.isEmpty()) {
return false;
}
RegisterArg result = insn.getResult();
if (result == null || result.isField()) {
return false;
}
return args.get(result.getRegNum());
}
public WrapInfo inline(int assignInsnPos, RegisterArg arg) {
ArgsInfo argsInfo = argsList.get(assignInsnPos);
argsInfo.inlinedInsn = this;
return new WrapInfo(argsInfo.insn, arg);
}
public ArgsInfo getInlinedInsn() {
if (inlinedInsn != null) {
ArgsInfo parent = inlinedInsn.getInlinedInsn();
if (parent != null) {
inlinedInsn = parent;
}
}
return inlinedInsn;
}
@Override
public String toString() {
return "ArgsInfo: |" + inlineBorder
+ " ->" + (inlinedInsn == null ? "-" : inlinedInsn.pos)
+ " " + args + " : " + insn;
}
}
private static final class WrapInfo {
private final InsnNode insn;
private final RegisterArg arg;
public WrapInfo(InsnNode assignInsn, RegisterArg arg) {
this.insn = assignInsn;
this.arg = arg;
}
private InsnNode getInsn() {
return insn;
}
private RegisterArg getArg() {
return arg;
}
@Override
public String toString() {
return "WrapInfo: " + arg + " -> " + insn;
}
}
private static void shrinkBlock(MethodNode mth, BlockNode block) {
if (block.getInstructions().isEmpty()) {
return;
......
package jadx.core.dex.visitors.shrink;
import jadx.core.dex.instructions.args.RegisterArg;
import jadx.core.dex.nodes.InsnNode;
final class WrapInfo {
private final InsnNode insn;
private final RegisterArg arg;
WrapInfo(InsnNode assignInsn, RegisterArg arg) {
this.insn = assignInsn;
this.arg = arg;
}
InsnNode getInsn() {
return insn;
}
RegisterArg getArg() {
return arg;
}
@Override
public String toString() {
return "WrapInfo: " + arg + " -> " + insn;
}
}
......@@ -40,16 +40,19 @@ public class JadxVisitorsOrderTest {
List<String> errors = new ArrayList<>();
Set<String> names = new HashSet<>();
Set<Class> passClsSet = new HashSet<>();
for (int i = 0; i < passes.size(); i++) {
IDexTreeVisitor pass = passes.get(i);
JadxVisitor info = pass.getClass().getAnnotation(JadxVisitor.class);
Class<? extends IDexTreeVisitor> passClass = pass.getClass();
JadxVisitor info = passClass.getAnnotation(JadxVisitor.class);
if (info == null) {
LOG.warn("No JadxVisitor annotation for visitor: {}", pass.getClass().getName());
LOG.warn("No JadxVisitor annotation for visitor: {}", passClass.getName());
continue;
}
String passName = pass.getClass().getSimpleName();
if (!names.add(passName)) {
errors.add("Visitor name conflict: " + passName + ", class: " + pass.getClass().getName());
boolean firstOccurrence = passClsSet.add(passClass);
String passName = passClass.getSimpleName();
if (firstOccurrence && !names.add(passName)) {
errors.add("Visitor name conflict: " + passName + ", class: " + passClass.getName());
}
for (Class<? extends IDexTreeVisitor> cls : info.runBefore()) {
if (classList.indexOf(cls) < i) {
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment