Commit b61daaed authored by Skylot's avatar Skylot

core: fix synchronized block processing

parent c6f0c89c
......@@ -71,6 +71,7 @@ public class Jadx {
passes.add(new RegionMakerVisitor());
passes.add(new TernaryVisitor());
passes.add(new CodeShrinker());
passes.add(new SimplifyVisitor());
passes.add(new ProcessVariables());
passes.add(new CheckRegions());
......
......@@ -184,7 +184,7 @@ public class RegionGen extends InsnGen {
private void makeSynchronizedRegion(SynchronizedRegion cont, CodeWriter code) throws CodegenException {
code.startLine("synchronized (");
addArg(code, cont.getInsn().getArg(0));
addArg(code, cont.getEnterInsn().getArg(0));
code.add(") {");
makeRegionIndent(code, cont.getRegion());
code.startLine('}');
......
......@@ -69,6 +69,9 @@ public abstract class InsnArg extends Typed {
public InsnArg wrapInstruction(InsnNode insn) {
InsnNode parent = parentInsn;
if (parent == null) {
return null;
}
assert parent != insn : "Can't wrap instruction info itself";
int count = parent.getArgsCount();
for (int i = 0; i < count; i++) {
......
package jadx.core.dex.instructions.args;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
public class TypedVar {
......@@ -38,10 +39,6 @@ public class TypedVar {
}
}
public List<InsnArg> getUseList() {
return useList;
}
public String getName() {
return name;
}
......@@ -50,6 +47,20 @@ public class TypedVar {
this.name = name;
}
public List<InsnArg> getUseList() {
return useList;
}
public void removeUse(InsnArg arg) {
Iterator<InsnArg> it = useList.iterator();
while (it.hasNext()) {
InsnArg use = it.next();
if (use == arg) {
it.remove();
}
}
}
public void mergeName(TypedVar arg) {
String argName = arg.getName();
if (argName != null) {
......
......@@ -18,6 +18,7 @@ import jadx.core.dex.instructions.args.ArgType;
import jadx.core.dex.instructions.args.InsnArg;
import jadx.core.dex.instructions.args.RegisterArg;
import jadx.core.dex.nodes.parser.DebugInfoParser;
import jadx.core.dex.regions.Region;
import jadx.core.dex.trycatch.ExcHandlerAttr;
import jadx.core.dex.trycatch.ExceptionHandler;
import jadx.core.dex.trycatch.TryCatchBlock;
......@@ -60,7 +61,7 @@ public class MethodNode extends LineAttrNode implements ILoadable {
private BlockNode enterBlock;
private List<BlockNode> exitBlocks;
private IContainer region;
private Region region;
private List<ExceptionHandler> exceptionHandlers;
private List<LoopAttr> loops = Collections.emptyList();
......@@ -505,11 +506,11 @@ public class MethodNode extends LineAttrNode implements ILoadable {
return accFlags;
}
public IContainer getRegion() {
public Region getRegion() {
return region;
}
public void setRegion(IContainer region) {
public void setRegion(Region region) {
this.region = region;
}
......
......@@ -4,21 +4,27 @@ import jadx.core.dex.nodes.IContainer;
import jadx.core.dex.nodes.IRegion;
import jadx.core.dex.nodes.InsnNode;
import java.util.LinkedList;
import java.util.List;
public final class SynchronizedRegion extends AbstractRegion {
private final InsnNode insn;
private final InsnNode enterInsn;
private final List<InsnNode> exitInsns = new LinkedList<InsnNode>();
private final Region region;
public SynchronizedRegion(IRegion parent, InsnNode insn) {
super(parent);
this.insn = insn;
this.enterInsn = insn;
this.region = new Region(this);
}
public InsnNode getInsn() {
return insn;
public InsnNode getEnterInsn() {
return enterInsn;
}
public List<InsnNode> getExitInsns() {
return exitInsns;
}
public Region getRegion() {
......
......@@ -102,7 +102,7 @@ public class CodeShrinker extends AbstractVisitor {
if (from > to) {
throw new JadxRuntimeException("Invalid inline insn positions: " + from + " - " + to);
}
for (int i = from; i < to - 1; i++) {
for (int i = from; i < to; i++) {
ArgsInfo argsInfo = argsList.get(i);
if (argsInfo.getInlinedInsn() == this) {
continue;
......
......@@ -328,7 +328,11 @@ public class RegionMaker {
Set<BlockNode> exits = new HashSet<BlockNode>();
cacheSet.clear();
traverseMonitorExits(insn.getArg(0), block, exits, cacheSet);
traverseMonitorExits(synchRegion, insn.getArg(0), block, exits, cacheSet);
for (InsnNode exitInsn : synchRegion.getExitInsns()) {
InstructionRemover.unbindInsn(exitInsn);
}
block = BlockUtils.getNextBlock(block);
BlockNode exit;
......@@ -337,7 +341,6 @@ public class RegionMaker {
} else {
cacheSet.clear();
exit = traverseMonitorExitsCross(block, exits, cacheSet);
// LOG.debug("synchronized exits: " + exits + ", cross: " + exit);
}
stack.push(synchRegion);
......@@ -350,19 +353,20 @@ public class RegionMaker {
/**
* Traverse from monitor-enter thru successors and collect blocks contains monitor-exit
*/
private static void traverseMonitorExits(InsnArg arg, BlockNode block, Set<BlockNode> exits, Set<BlockNode> visited) {
private static void traverseMonitorExits(SynchronizedRegion region, InsnArg arg, BlockNode block,
Set<BlockNode> exits, Set<BlockNode> visited) {
visited.add(block);
for (InsnNode insn : block.getInstructions()) {
if (insn.getType() == InsnType.MONITOR_EXIT
&& insn.getArg(0).equals(arg)) {
exits.add(block);
InstructionRemover.remove(block, insn);
region.getExitInsns().add(insn);
return;
}
}
for (BlockNode node : block.getCleanSuccessors()) {
if (!visited.contains(node)) {
traverseMonitorExits(arg, node, exits, visited);
traverseMonitorExits(region, arg, node, exits, visited);
}
}
}
......
......@@ -4,20 +4,27 @@ import jadx.core.dex.attributes.AttributeFlag;
import jadx.core.dex.instructions.args.ArgType;
import jadx.core.dex.nodes.IContainer;
import jadx.core.dex.nodes.IRegion;
import jadx.core.dex.nodes.InsnNode;
import jadx.core.dex.nodes.MethodNode;
import jadx.core.dex.regions.IfRegion;
import jadx.core.dex.regions.LoopRegion;
import jadx.core.dex.regions.Region;
import jadx.core.dex.regions.SynchronizedRegion;
import jadx.core.dex.trycatch.ExceptionHandler;
import jadx.core.dex.visitors.AbstractVisitor;
import jadx.core.utils.InstructionRemover;
import jadx.core.utils.exceptions.JadxException;
import java.util.List;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
/**
* Pack blocks into regions for code generation
*/
public class RegionMakerVisitor extends AbstractVisitor {
private static final Logger LOG = LoggerFactory.getLogger(RegionMakerVisitor.class);
@Override
public void visit(MethodNode mth) throws JadxException {
......@@ -73,6 +80,11 @@ public class RegionMakerVisitor extends AbstractVisitor {
if (mth.getReturnType().equals(ArgType.VOID)) {
DepthRegionTraverser.traverseAll(mth, new ProcessReturnInsns());
}
if (mth.getAccessFlags().isSynchronized()) {
removeSynchronized(mth);
}
}
private static void processIfRegion(IfRegion ifRegion) {
......@@ -94,4 +106,28 @@ public class RegionMakerVisitor extends AbstractVisitor {
}
}
}
private static void removeSynchronized(MethodNode mth) {
Region startRegion = mth.getRegion();
List<IContainer> subBlocks = startRegion.getSubBlocks();
if (!subBlocks.isEmpty() && subBlocks.get(0) instanceof SynchronizedRegion) {
SynchronizedRegion synchRegion = (SynchronizedRegion) subBlocks.get(0);
InsnNode synchInsn = synchRegion.getEnterInsn();
if (!synchInsn.getArg(0).isThis()) {
LOG.warn("In synchronized method {}, top region not synchronized by 'this' {}", mth, synchInsn);
return;
}
// replace synchronized block with inner region
startRegion.getSubBlocks().set(0, synchRegion.getRegion());
// remove 'monitor-enter' instruction
InstructionRemover.remove(mth, synchInsn);
// remove 'monitor-exit' instruction
for (InsnNode exit : synchRegion.getExitInsns()) {
InstructionRemover.remove(mth, exit);
}
// run region cleaner again
CleanRegions.process(mth);
// assume that CodeShrinker will be run after this
}
}
}
......@@ -3,6 +3,7 @@ package jadx.core.utils;
import jadx.core.dex.instructions.args.InsnArg;
import jadx.core.dex.nodes.BlockNode;
import jadx.core.dex.nodes.InsnNode;
import jadx.core.dex.nodes.MethodNode;
import java.util.ArrayList;
import java.util.Iterator;
......@@ -40,11 +41,11 @@ public class InstructionRemover {
public static void unbindInsn(InsnNode insn) {
if (insn.getResult() != null) {
InsnArg res = insn.getResult();
res.getTypedVar().getUseList().remove(res);
res.getTypedVar().removeUse(res);
}
for (InsnArg arg : insn.getArguments()) {
if (arg.isRegister()) {
arg.getTypedVar().getUseList().remove(arg);
arg.getTypedVar().removeUse(arg);
}
}
}
......@@ -75,10 +76,18 @@ public class InstructionRemover {
}
}
public static void remove(MethodNode mth, InsnNode insn) {
BlockNode block = BlockUtils.getBlockByInsn(mth, insn);
if (block != null) {
remove(block, insn);
}
}
public static void remove(BlockNode block, InsnNode insn) {
unbindInsn(insn);
// remove by pointer (don't use equals)
for (Iterator<InsnNode> it = block.getInstructions().iterator(); it.hasNext(); ) {
Iterator<InsnNode> it = block.getInstructions().iterator();
while (it.hasNext()) {
InsnNode ir = it.next();
if (ir == insn) {
it.remove();
......
package jadx.tests.internal;
import jadx.api.InternalJadxTest;
import jadx.core.dex.nodes.ClassNode;
import org.junit.Test;
import static org.hamcrest.CoreMatchers.containsString;
import static org.hamcrest.CoreMatchers.not;
import static org.junit.Assert.assertThat;
public class TestSynchronized extends InternalJadxTest {
public static class TestCls {
public boolean f = false;
public final Object o = new Object();
public int i = 7;
public synchronized boolean test1() {
return this.f;
}
public int test2() {
synchronized (this.o) {
return i;
}
}
}
@Test
public void test() {
ClassNode cls = getClassNode(TestCls.class);
String code = cls.getCode().toString();
System.out.println(code);
assertThat(code, not(containsString("synchronized (this) {")));
assertThat(code, containsString("public synchronized boolean test1() {"));
assertThat(code, containsString("return this.f"));
assertThat(code, containsString("synchronized (this.o) {"));
}
}
package jadx.tests.internal.inline;
import jadx.api.InternalJadxTest;
import jadx.core.dex.nodes.ClassNode;
import org.junit.Test;
import static org.hamcrest.CoreMatchers.containsString;
import static org.hamcrest.CoreMatchers.not;
import static org.junit.Assert.assertThat;
public class TestInline6 extends InternalJadxTest {
public static class TestCls {
public void f() {
}
public void test(int a, int b) {
long start = System.nanoTime();
f();
System.out.println(System.nanoTime() - start);
}
}
@Test
public void test() {
setOutputCFG();
ClassNode cls = getClassNode(TestCls.class);
String code = cls.getCode().toString();
System.out.println(code);
assertThat(code, containsString("System.out.println(System.nanoTime() - start);"));
assertThat(code, not(containsString("System.out.println(System.nanoTime() - System.nanoTime());")));
}
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment