Rewrites some if statements into boolean expressions.
if (<cond>) { <expr_stmt>; } => <cond> && <expr_stmt>
if (<cond>) { <then_expr_stmt>; } else { <else_expr_stmt>; } =>
<cond> ? <then_expr_stmt> : <else_expr_stmt>
Also supports lifting return statement:
if (<cond>) { return <expr1>; } else { return <expr2>; } =>
return <cond>?<expr1>:<expr2>;
Reduces code size by a tiny amout (approx. 0.5%) mostly because this optimization create a lot of opportunities for inliner to kick in.
Patch by: mike.aizatsky
Review by: me, spoon
git-svn-id: https://google-web-toolkit.googlecode.com/svn/trunk@6533 8db76d5a-ed1c-0410-87a9-c151d255dfc7
diff --git a/dev/core/src/com/google/gwt/dev/jjs/impl/DeadCodeElimination.java b/dev/core/src/com/google/gwt/dev/jjs/impl/DeadCodeElimination.java
index ee3e874..90505e7 100644
--- a/dev/core/src/com/google/gwt/dev/jjs/impl/DeadCodeElimination.java
+++ b/dev/core/src/com/google/gwt/dev/jjs/impl/DeadCodeElimination.java
@@ -89,6 +89,8 @@
*/
public class DeadCodeVisitor extends JModVisitor {
+ private JMethod currentMethod = null;
+
/**
* Expressions whose result does not matter. A parent node should add any
* children whose result does not matter to this set during the parent's
@@ -358,7 +360,7 @@
@Override
public void endVisit(JIfStatement x, Context ctx) {
JStatement updated = simplifier.ifStatement(x, x.getSourceInfo(),
- x.getIfExpr(), x.getThenStmt(), x.getElseStmt());
+ x.getIfExpr(), x.getThenStmt(), x.getElseStmt(), currentMethod);
if (updated != x) {
ctx.replaceMe(updated);
}
@@ -373,6 +375,11 @@
}
}
+ @Override
+ public void endVisit(JMethod x, Context ctx) {
+ currentMethod = null;
+ }
+
/**
* Resolve method calls that can be computed statically.
*/
@@ -596,6 +603,12 @@
}
@Override
+ public boolean visit(JMethod x, Context ctx) {
+ currentMethod = x;
+ return true;
+ }
+
+ @Override
public boolean visit(JMethodCall x, Context ctx) {
JMethod target = x.getTarget();
if (target.isStatic() && x.getInstance() != null) {
diff --git a/dev/core/src/com/google/gwt/dev/jjs/impl/LongCastNormalizer.java b/dev/core/src/com/google/gwt/dev/jjs/impl/LongCastNormalizer.java
index 8e4f00e..d5929ac 100644
--- a/dev/core/src/com/google/gwt/dev/jjs/impl/LongCastNormalizer.java
+++ b/dev/core/src/com/google/gwt/dev/jjs/impl/LongCastNormalizer.java
@@ -17,6 +17,7 @@
import com.google.gwt.dev.jjs.ast.Context;
import com.google.gwt.dev.jjs.ast.JBinaryOperation;
+import com.google.gwt.dev.jjs.ast.JBinaryOperator;
import com.google.gwt.dev.jjs.ast.JConditional;
import com.google.gwt.dev.jjs.ast.JDeclarationStatement;
import com.google.gwt.dev.jjs.ast.JExpression;
@@ -56,14 +57,21 @@
JType lhsType = x.getLhs().getType();
JType rhsType = x.getRhs().getType();
JType resultType = x.getType();
+ JBinaryOperator op = x.getOp();
if (resultType == program.getTypeJavaLangString()) {
// Don't mess with concat.
return;
}
+ if (lhsType == JPrimitiveType.BOOLEAN
+ && (op == JBinaryOperator.AND || op == JBinaryOperator.OR)) {
+ // Don't mess with if rewriter.
+ return;
+ }
+
// Special case: shift operators always coerce a long RHS to int.
- if (x.getOp().isShiftOperator()) {
+ if (op.isShiftOperator()) {
if (rhsType == longType) {
rhsType = program.getTypePrimitiveInt();
}
@@ -80,7 +88,7 @@
if ((lhsType == floatType || lhsType == doubleType)) {
coerceTo = lhsType;
}
- if (x.getOp().isAssignment()) {
+ if (op.isAssignment()) {
// In an assignment, the lhs must coerce the rhs
coerceTo = lhsType;
} else if ((rhsType == floatType || rhsType == doubleType)) {
@@ -93,7 +101,7 @@
JExpression newRhs = checkAndReplace(x.getRhs(), rhsType);
if (newLhs != x.getLhs() || newRhs != x.getRhs()) {
JBinaryOperation binOp = new JBinaryOperation(x.getSourceInfo(),
- resultType, x.getOp(), newLhs, newRhs);
+ resultType, op, newLhs, newRhs);
ctx.replaceMe(binOp);
}
}
diff --git a/dev/core/src/com/google/gwt/dev/jjs/impl/LongEmulationNormalizer.java b/dev/core/src/com/google/gwt/dev/jjs/impl/LongEmulationNormalizer.java
index 5019920..d35dbc9 100644
--- a/dev/core/src/com/google/gwt/dev/jjs/impl/LongEmulationNormalizer.java
+++ b/dev/core/src/com/google/gwt/dev/jjs/impl/LongEmulationNormalizer.java
@@ -56,7 +56,6 @@
JType lhsType = x.getLhs().getType();
JType rhsType = x.getRhs().getType();
if (lhsType != longType) {
- assert (rhsType != longType);
return;
}
diff --git a/dev/core/src/com/google/gwt/dev/jjs/impl/Simplifier.java b/dev/core/src/com/google/gwt/dev/jjs/impl/Simplifier.java
index e376f3b..de4e38b 100644
--- a/dev/core/src/com/google/gwt/dev/jjs/impl/Simplifier.java
+++ b/dev/core/src/com/google/gwt/dev/jjs/impl/Simplifier.java
@@ -23,10 +23,13 @@
import com.google.gwt.dev.jjs.ast.JCastOperation;
import com.google.gwt.dev.jjs.ast.JConditional;
import com.google.gwt.dev.jjs.ast.JExpression;
+import com.google.gwt.dev.jjs.ast.JExpressionStatement;
import com.google.gwt.dev.jjs.ast.JIfStatement;
+import com.google.gwt.dev.jjs.ast.JMethod;
import com.google.gwt.dev.jjs.ast.JPrefixOperation;
import com.google.gwt.dev.jjs.ast.JPrimitiveType;
import com.google.gwt.dev.jjs.ast.JProgram;
+import com.google.gwt.dev.jjs.ast.JReturnStatement;
import com.google.gwt.dev.jjs.ast.JStatement;
import com.google.gwt.dev.jjs.ast.JType;
import com.google.gwt.dev.jjs.ast.JUnaryOperation;
@@ -186,7 +189,8 @@
}
public JStatement ifStatement(JIfStatement original, SourceInfo sourceInfo,
- JExpression condExpr, JStatement thenStmt, JStatement elseStmt) {
+ JExpression condExpr, JStatement thenStmt, JStatement elseStmt,
+ JMethod currentMethod) {
if (condExpr instanceof JMultiExpression) {
// if(a,b,c) d else e -> {a; b; if(c) d else e; }
JMultiExpression condMulti = (JMultiExpression) condExpr;
@@ -195,7 +199,7 @@
newBlock.addStmt(expr.makeStatement());
}
newBlock.addStmt(ifStatement(null, sourceInfo, last(condMulti.exprs),
- thenStmt, elseStmt));
+ thenStmt, elseStmt, currentMethod));
// TODO(spoon): immediately simplify the resulting block
return newBlock;
}
@@ -227,10 +231,17 @@
// TODO: this goes away when we normalize the Java AST properly.
thenStmt = ensureBlock(thenStmt);
elseStmt = ensureBlock(elseStmt);
- return ifStatement(null, sourceInfo, unflipped, elseStmt, thenStmt);
+ return ifStatement(null, sourceInfo, unflipped, elseStmt, thenStmt,
+ currentMethod);
}
}
+ JStatement rewritenStatement = rewriteIfIntoBoolean(sourceInfo, condExpr,
+ thenStmt, elseStmt, currentMethod);
+ if (rewritenStatement != null) {
+ return rewritenStatement;
+ }
+
// no simplification made
if (original != null) {
return original;
@@ -304,4 +315,84 @@
}
return stmt;
}
+
+ private JExpression extractExpression(JStatement stmt) {
+ if (stmt instanceof JExpressionStatement) {
+ JExpressionStatement statement = (JExpressionStatement) stmt;
+ return statement.getExpr();
+ }
+
+ return null;
+ }
+
+ private JStatement extractSingleStatement(JStatement stmt) {
+ if (stmt instanceof JBlock) {
+ JBlock block = (JBlock) stmt;
+ if (block.getStatements().size() == 1) {
+ return extractSingleStatement(block.getStatements().get(0));
+ }
+ }
+
+ return stmt;
+ }
+
+ private JStatement rewriteIfIntoBoolean(SourceInfo sourceInfo,
+ JExpression condExpr, JStatement thenStmt, JStatement elseStmt,
+ JMethod currentMethod) {
+ thenStmt = extractSingleStatement(thenStmt);
+ elseStmt = extractSingleStatement(elseStmt);
+
+ if (thenStmt instanceof JReturnStatement
+ && elseStmt instanceof JReturnStatement && currentMethod != null) {
+ // Special case
+ // if () { return ..; } else { return ..; } =>
+ // return ... ? ... : ...;
+ JExpression thenExpression = ((JReturnStatement) thenStmt).getExpr();
+ JExpression elseExpression = ((JReturnStatement) elseStmt).getExpr();
+ if (thenExpression == null || elseExpression == null) {
+ // empty returns are not supported.
+ return null;
+ }
+
+ JConditional conditional = new JConditional(sourceInfo,
+ currentMethod.getType(), condExpr, thenExpression, elseExpression);
+
+ JReturnStatement returnStatement = new JReturnStatement(sourceInfo,
+ conditional);
+ return returnStatement;
+ }
+
+ if (elseStmt != null) {
+ // if () { } else { } -> ... ? ... : ... ;
+ JExpression thenExpression = extractExpression(thenStmt);
+ JExpression elseExpression = extractExpression(elseStmt);
+
+ if (thenExpression != null && elseExpression != null) {
+ JConditional conditional = new JConditional(sourceInfo,
+ JPrimitiveType.VOID, condExpr, thenExpression, elseExpression);
+
+ return conditional.makeStatement();
+ }
+ } else {
+ // if () { } -> ... && ...;
+ JExpression thenExpression = extractExpression(thenStmt);
+
+ if (thenExpression != null) {
+ JBinaryOperator binaryOperator = JBinaryOperator.AND;
+
+ JExpression unflipExpression = maybeUnflipBoolean(condExpr);
+ if (unflipExpression != null) {
+ condExpr = unflipExpression;
+ binaryOperator = JBinaryOperator.OR;
+ }
+
+ JBinaryOperation binaryOperation = new JBinaryOperation(sourceInfo,
+ program.getTypeVoid(), binaryOperator, condExpr, thenExpression);
+
+ return binaryOperation.makeStatement();
+ }
+ }
+
+ return null;
+ }
}
diff --git a/dev/core/test/com/google/gwt/dev/jjs/impl/DeadCodeEliminationTest.java b/dev/core/test/com/google/gwt/dev/jjs/impl/DeadCodeEliminationTest.java
index b26f401..02b1f2c 100644
--- a/dev/core/test/com/google/gwt/dev/jjs/impl/DeadCodeEliminationTest.java
+++ b/dev/core/test/com/google/gwt/dev/jjs/impl/DeadCodeEliminationTest.java
@@ -39,26 +39,38 @@
expected = getMainMethodSource(program);
assertEquals(userCode, expected, optimized);
}
+
+ /**
+ * Compare without compiling expected, needed when optimizations produce
+ * incorrect java code (e.g. "a" || "b" is incorrect in java).
+ */
+ public void intoString(String expected) {
+ String actual = optimized;
+ assertTrue(actual.startsWith("{"));
+ assertTrue(actual.endsWith("}"));
+ actual = actual.substring(1, actual.length() - 2).trim();
+ // Join lines in actual code and remove indentations
+ actual = actual.replaceAll(" +", " ").replaceAll("\n", "");
+ assertEquals(userCode, expected, actual);
+ }
+ }
+
+ @Override
+ public void setUp() throws Exception {
+ addSnippetClassDecl("static volatile boolean b;");
+ addSnippetClassDecl("static volatile boolean b1;");
+ addSnippetClassDecl("static volatile int i;");
+ addSnippetClassDecl("static volatile long l;");
}
public void testConditionalOptimizations() throws Exception {
optimize("int", "return true ? 3 : 4;").into("return 3;");
optimize("int", "return false ? 3 : 4;").into("return 4;");
- addSnippetClassDecl("static volatile boolean TEST");
- addSnippetClassDecl("static volatile boolean RESULT");
-
- optimize("boolean", "return TEST ? true : RESULT;").into(
- "return TEST || RESULT;");
-
- optimize("boolean", "return TEST ? false : RESULT;").into(
- "return !TEST && RESULT;");
-
- optimize("boolean", "return TEST ? RESULT : true;").into(
- "return !TEST || RESULT;");
-
- optimize("boolean", "return TEST ? RESULT : false;").into(
- "return TEST && RESULT;");
+ optimize("boolean", "return b ? true : b1;").into("return b || b1;");
+ optimize("boolean", "return b ? false : b1;").into("return !b && b1;");
+ optimize("boolean", "return b ? b1 : true;").into("return !b || b1;");
+ optimize("boolean", "return b ? b1 : false;").into("return b && b1;");
}
public void testIfOptimizations() throws Exception {
@@ -69,12 +81,48 @@
optimize("int", "if (true) {} else return 4; return 0;").into("return 0;");
- addSnippetClassDecl("static volatile boolean TEST");
- addSnippetClassDecl("static boolean test() { return TEST; }");
+ addSnippetClassDecl("static boolean test() { return b; }");
optimize("int", "if (test()) {} else {}; return 0;").into(
"test(); return 0;");
}
+ public void testIfStatementToBoolean_NotOptimization() throws Exception {
+ optimize("void", "if (!b) i = 1;").intoString(
+ "EntryPoint.b || (EntryPoint.i = 1);");
+ optimize("void", "if (!b) i = 1; else i = 2;").intoString(
+ "EntryPoint.b ? (EntryPoint.i = 2) : (EntryPoint.i = 1);");
+ optimize("int", "if (!b) { return 1;} else {return 2;}").into(
+ "return b ? 2 : 1;");
+ }
+
+ public void testIfStatementToBoolean_ReturnLifting() throws Exception {
+ optimize("int", "if (b) return 1; return 2;").into(
+ "if (b) return 1; return 2;");
+ optimize("int", "if (b) { return 1; } return 2;").into(
+ "if (b) { return 1; } return 2;");
+ optimize("int", "if (b) { return 1;} else {return 2;}").into(
+ "return b ? 1 : 2;");
+ optimize("int", "if (b) return 1; else {return 2;}").into(
+ "return b ? 1 : 2;");
+ optimize("int", "if (b) return 1; else return 2;").into("return b ? 1 : 2;");
+ optimize("void", "if (b) return; else return;").into(
+ "if (b) return; else return;");
+ }
+
+ public void testIfStatementToBoolean_ThenElseOptimization() throws Exception {
+ optimize("void", "if (b) i = 1; else i = 2;").intoString(
+ "EntryPoint.b ? (EntryPoint.i = 1) : (EntryPoint.i = 2);");
+ optimize("void", "if (b) {i = 1;} else {i = 2;}").intoString(
+ "EntryPoint.b ? (EntryPoint.i = 1) : (EntryPoint.i = 2);");
+ }
+
+ public void testIfStatementToBoolean_ThenOptimization() throws Exception {
+ optimize("void", "if (b) i = 1;").intoString(
+ "EntryPoint.b && (EntryPoint.i = 1);");
+ optimize("void", "if (b) {i = 1;}").intoString(
+ "EntryPoint.b && (EntryPoint.i = 1);");
+ }
+
private Result optimize(final String returnType, final String codeSnippet)
throws UnableToCompleteException {
JProgram program = compileSnippet(returnType, codeSnippet);