Rewrites some if statements into boolean expressions.

if (<cond>) { <expr_stmt>; } => <cond> && <expr_stmt>

if (<cond>) { <then_expr_stmt>; } else { <else_expr_stmt>; } =>
   <cond> ? <then_expr_stmt> : <else_expr_stmt>

Also supports lifting return statement:

if (<cond>) { return <expr1>; } else { return <expr2>; } =>
  return <cond>?<expr1>:<expr2>;

Reduces code size by a tiny amout (approx. 0.5%) mostly because this optimization create a lot of opportunities for inliner to kick in.

Patch by: mike.aizatsky	
Review by: me, spoon



git-svn-id: https://google-web-toolkit.googlecode.com/svn/trunk@6533 8db76d5a-ed1c-0410-87a9-c151d255dfc7
diff --git a/dev/core/src/com/google/gwt/dev/jjs/impl/DeadCodeElimination.java b/dev/core/src/com/google/gwt/dev/jjs/impl/DeadCodeElimination.java
index ee3e874..90505e7 100644
--- a/dev/core/src/com/google/gwt/dev/jjs/impl/DeadCodeElimination.java
+++ b/dev/core/src/com/google/gwt/dev/jjs/impl/DeadCodeElimination.java
@@ -89,6 +89,8 @@
    */
   public class DeadCodeVisitor extends JModVisitor {
 
+    private JMethod currentMethod = null;
+
     /**
      * Expressions whose result does not matter. A parent node should add any
      * children whose result does not matter to this set during the parent's
@@ -358,7 +360,7 @@
     @Override
     public void endVisit(JIfStatement x, Context ctx) {
       JStatement updated = simplifier.ifStatement(x, x.getSourceInfo(),
-          x.getIfExpr(), x.getThenStmt(), x.getElseStmt());
+          x.getIfExpr(), x.getThenStmt(), x.getElseStmt(), currentMethod);
       if (updated != x) {
         ctx.replaceMe(updated);
       }
@@ -373,6 +375,11 @@
       }
     }
 
+    @Override
+    public void endVisit(JMethod x, Context ctx) {
+      currentMethod = null;
+    }
+
     /**
      * Resolve method calls that can be computed statically.
      */
@@ -596,6 +603,12 @@
     }
 
     @Override
+    public boolean visit(JMethod x, Context ctx) {
+      currentMethod = x;
+      return true;
+    }
+
+    @Override
     public boolean visit(JMethodCall x, Context ctx) {
       JMethod target = x.getTarget();
       if (target.isStatic() && x.getInstance() != null) {
diff --git a/dev/core/src/com/google/gwt/dev/jjs/impl/LongCastNormalizer.java b/dev/core/src/com/google/gwt/dev/jjs/impl/LongCastNormalizer.java
index 8e4f00e..d5929ac 100644
--- a/dev/core/src/com/google/gwt/dev/jjs/impl/LongCastNormalizer.java
+++ b/dev/core/src/com/google/gwt/dev/jjs/impl/LongCastNormalizer.java
@@ -17,6 +17,7 @@
 
 import com.google.gwt.dev.jjs.ast.Context;
 import com.google.gwt.dev.jjs.ast.JBinaryOperation;
+import com.google.gwt.dev.jjs.ast.JBinaryOperator;
 import com.google.gwt.dev.jjs.ast.JConditional;
 import com.google.gwt.dev.jjs.ast.JDeclarationStatement;
 import com.google.gwt.dev.jjs.ast.JExpression;
@@ -56,14 +57,21 @@
       JType lhsType = x.getLhs().getType();
       JType rhsType = x.getRhs().getType();
       JType resultType = x.getType();
+      JBinaryOperator op = x.getOp();
 
       if (resultType == program.getTypeJavaLangString()) {
         // Don't mess with concat.
         return;
       }
 
+      if (lhsType == JPrimitiveType.BOOLEAN
+          && (op == JBinaryOperator.AND || op == JBinaryOperator.OR)) {
+        // Don't mess with if rewriter.
+        return;
+      }
+
       // Special case: shift operators always coerce a long RHS to int.
-      if (x.getOp().isShiftOperator()) {
+      if (op.isShiftOperator()) {
         if (rhsType == longType) {
           rhsType = program.getTypePrimitiveInt();
         }
@@ -80,7 +88,7 @@
         if ((lhsType == floatType || lhsType == doubleType)) {
           coerceTo = lhsType;
         }
-        if (x.getOp().isAssignment()) {
+        if (op.isAssignment()) {
           // In an assignment, the lhs must coerce the rhs
           coerceTo = lhsType;
         } else if ((rhsType == floatType || rhsType == doubleType)) {
@@ -93,7 +101,7 @@
       JExpression newRhs = checkAndReplace(x.getRhs(), rhsType);
       if (newLhs != x.getLhs() || newRhs != x.getRhs()) {
         JBinaryOperation binOp = new JBinaryOperation(x.getSourceInfo(),
-            resultType, x.getOp(), newLhs, newRhs);
+            resultType, op, newLhs, newRhs);
         ctx.replaceMe(binOp);
       }
     }
diff --git a/dev/core/src/com/google/gwt/dev/jjs/impl/LongEmulationNormalizer.java b/dev/core/src/com/google/gwt/dev/jjs/impl/LongEmulationNormalizer.java
index 5019920..d35dbc9 100644
--- a/dev/core/src/com/google/gwt/dev/jjs/impl/LongEmulationNormalizer.java
+++ b/dev/core/src/com/google/gwt/dev/jjs/impl/LongEmulationNormalizer.java
@@ -56,7 +56,6 @@
       JType lhsType = x.getLhs().getType();
       JType rhsType = x.getRhs().getType();
       if (lhsType != longType) {
-        assert (rhsType != longType);
         return;
       }
 
diff --git a/dev/core/src/com/google/gwt/dev/jjs/impl/Simplifier.java b/dev/core/src/com/google/gwt/dev/jjs/impl/Simplifier.java
index e376f3b..de4e38b 100644
--- a/dev/core/src/com/google/gwt/dev/jjs/impl/Simplifier.java
+++ b/dev/core/src/com/google/gwt/dev/jjs/impl/Simplifier.java
@@ -23,10 +23,13 @@
 import com.google.gwt.dev.jjs.ast.JCastOperation;
 import com.google.gwt.dev.jjs.ast.JConditional;
 import com.google.gwt.dev.jjs.ast.JExpression;
+import com.google.gwt.dev.jjs.ast.JExpressionStatement;
 import com.google.gwt.dev.jjs.ast.JIfStatement;
+import com.google.gwt.dev.jjs.ast.JMethod;
 import com.google.gwt.dev.jjs.ast.JPrefixOperation;
 import com.google.gwt.dev.jjs.ast.JPrimitiveType;
 import com.google.gwt.dev.jjs.ast.JProgram;
+import com.google.gwt.dev.jjs.ast.JReturnStatement;
 import com.google.gwt.dev.jjs.ast.JStatement;
 import com.google.gwt.dev.jjs.ast.JType;
 import com.google.gwt.dev.jjs.ast.JUnaryOperation;
@@ -186,7 +189,8 @@
   }
 
   public JStatement ifStatement(JIfStatement original, SourceInfo sourceInfo,
-      JExpression condExpr, JStatement thenStmt, JStatement elseStmt) {
+      JExpression condExpr, JStatement thenStmt, JStatement elseStmt,
+      JMethod currentMethod) {
     if (condExpr instanceof JMultiExpression) {
       // if(a,b,c) d else e -> {a; b; if(c) d else e; }
       JMultiExpression condMulti = (JMultiExpression) condExpr;
@@ -195,7 +199,7 @@
         newBlock.addStmt(expr.makeStatement());
       }
       newBlock.addStmt(ifStatement(null, sourceInfo, last(condMulti.exprs),
-          thenStmt, elseStmt));
+          thenStmt, elseStmt, currentMethod));
       // TODO(spoon): immediately simplify the resulting block
       return newBlock;
     }
@@ -227,10 +231,17 @@
         // TODO: this goes away when we normalize the Java AST properly.
         thenStmt = ensureBlock(thenStmt);
         elseStmt = ensureBlock(elseStmt);
-        return ifStatement(null, sourceInfo, unflipped, elseStmt, thenStmt);
+        return ifStatement(null, sourceInfo, unflipped, elseStmt, thenStmt,
+            currentMethod);
       }
     }
 
+    JStatement rewritenStatement = rewriteIfIntoBoolean(sourceInfo, condExpr,
+        thenStmt, elseStmt, currentMethod);
+    if (rewritenStatement != null) {
+      return rewritenStatement;
+    }
+
     // no simplification made
     if (original != null) {
       return original;
@@ -304,4 +315,84 @@
     }
     return stmt;
   }
+
+  private JExpression extractExpression(JStatement stmt) {
+    if (stmt instanceof JExpressionStatement) {
+      JExpressionStatement statement = (JExpressionStatement) stmt;
+      return statement.getExpr();
+    }
+
+    return null;
+  }
+
+  private JStatement extractSingleStatement(JStatement stmt) {
+    if (stmt instanceof JBlock) {
+      JBlock block = (JBlock) stmt;
+      if (block.getStatements().size() == 1) {
+        return extractSingleStatement(block.getStatements().get(0));
+      }
+    }
+
+    return stmt;
+  }
+
+  private JStatement rewriteIfIntoBoolean(SourceInfo sourceInfo,
+      JExpression condExpr, JStatement thenStmt, JStatement elseStmt,
+      JMethod currentMethod) {
+    thenStmt = extractSingleStatement(thenStmt);
+    elseStmt = extractSingleStatement(elseStmt);
+
+    if (thenStmt instanceof JReturnStatement
+        && elseStmt instanceof JReturnStatement && currentMethod != null) {
+      // Special case
+      // if () { return ..; } else { return ..; } =>
+      // return ... ? ... : ...;
+      JExpression thenExpression = ((JReturnStatement) thenStmt).getExpr();
+      JExpression elseExpression = ((JReturnStatement) elseStmt).getExpr();
+      if (thenExpression == null || elseExpression == null) {
+        // empty returns are not supported.
+        return null;
+      }
+
+      JConditional conditional = new JConditional(sourceInfo,
+          currentMethod.getType(), condExpr, thenExpression, elseExpression);
+
+      JReturnStatement returnStatement = new JReturnStatement(sourceInfo,
+          conditional);
+      return returnStatement;
+    }
+
+    if (elseStmt != null) {
+      // if () { } else { } -> ... ? ... : ... ;
+      JExpression thenExpression = extractExpression(thenStmt);
+      JExpression elseExpression = extractExpression(elseStmt);
+
+      if (thenExpression != null && elseExpression != null) {
+        JConditional conditional = new JConditional(sourceInfo,
+            JPrimitiveType.VOID, condExpr, thenExpression, elseExpression);
+
+        return conditional.makeStatement();
+      }
+    } else {
+      // if () { } -> ... && ...;
+      JExpression thenExpression = extractExpression(thenStmt);
+
+      if (thenExpression != null) {
+        JBinaryOperator binaryOperator = JBinaryOperator.AND;
+
+        JExpression unflipExpression = maybeUnflipBoolean(condExpr);
+        if (unflipExpression != null) {
+          condExpr = unflipExpression;
+          binaryOperator = JBinaryOperator.OR;
+        }
+
+        JBinaryOperation binaryOperation = new JBinaryOperation(sourceInfo,
+            program.getTypeVoid(), binaryOperator, condExpr, thenExpression);
+
+        return binaryOperation.makeStatement();
+      }
+    }
+
+    return null;
+  }
 }
diff --git a/dev/core/test/com/google/gwt/dev/jjs/impl/DeadCodeEliminationTest.java b/dev/core/test/com/google/gwt/dev/jjs/impl/DeadCodeEliminationTest.java
index b26f401..02b1f2c 100644
--- a/dev/core/test/com/google/gwt/dev/jjs/impl/DeadCodeEliminationTest.java
+++ b/dev/core/test/com/google/gwt/dev/jjs/impl/DeadCodeEliminationTest.java
@@ -39,26 +39,38 @@
       expected = getMainMethodSource(program);
       assertEquals(userCode, expected, optimized);
     }
+
+    /**
+     * Compare without compiling expected, needed when optimizations produce
+     * incorrect java code (e.g. "a" || "b" is incorrect in java).
+     */
+    public void intoString(String expected) {
+      String actual = optimized;
+      assertTrue(actual.startsWith("{"));
+      assertTrue(actual.endsWith("}"));
+      actual = actual.substring(1, actual.length() - 2).trim();
+      // Join lines in actual code and remove indentations
+      actual = actual.replaceAll(" +", " ").replaceAll("\n", "");
+      assertEquals(userCode, expected, actual);
+    }
+  }
+
+  @Override
+  public void setUp() throws Exception {
+    addSnippetClassDecl("static volatile boolean b;");
+    addSnippetClassDecl("static volatile boolean b1;");
+    addSnippetClassDecl("static volatile int i;");
+    addSnippetClassDecl("static volatile long l;");
   }
 
   public void testConditionalOptimizations() throws Exception {
     optimize("int", "return true ? 3 : 4;").into("return 3;");
     optimize("int", "return false ? 3 : 4;").into("return 4;");
 
-    addSnippetClassDecl("static volatile boolean TEST");
-    addSnippetClassDecl("static volatile boolean RESULT");
-
-    optimize("boolean", "return TEST ? true : RESULT;").into(
-        "return TEST || RESULT;");
-
-    optimize("boolean", "return TEST ? false : RESULT;").into(
-        "return !TEST && RESULT;");
-
-    optimize("boolean", "return TEST ? RESULT : true;").into(
-        "return !TEST || RESULT;");
-
-    optimize("boolean", "return TEST ? RESULT : false;").into(
-        "return TEST && RESULT;");
+    optimize("boolean", "return b ? true : b1;").into("return b || b1;");
+    optimize("boolean", "return b ? false : b1;").into("return !b && b1;");
+    optimize("boolean", "return b ? b1 : true;").into("return !b || b1;");
+    optimize("boolean", "return b ? b1 : false;").into("return b && b1;");
   }
 
   public void testIfOptimizations() throws Exception {
@@ -69,12 +81,48 @@
 
     optimize("int", "if (true) {} else return 4; return 0;").into("return 0;");
 
-    addSnippetClassDecl("static volatile boolean TEST");
-    addSnippetClassDecl("static boolean test() { return TEST; }");
+    addSnippetClassDecl("static boolean test() { return b; }");
     optimize("int", "if (test()) {} else {}; return 0;").into(
         "test(); return 0;");
   }
 
+  public void testIfStatementToBoolean_NotOptimization() throws Exception {
+    optimize("void", "if (!b) i = 1;").intoString(
+        "EntryPoint.b || (EntryPoint.i = 1);");
+    optimize("void", "if (!b) i = 1; else i = 2;").intoString(
+        "EntryPoint.b ? (EntryPoint.i = 2) : (EntryPoint.i = 1);");
+    optimize("int", "if (!b) { return 1;} else {return 2;}").into(
+        "return b ? 2 : 1;");
+  }
+
+  public void testIfStatementToBoolean_ReturnLifting() throws Exception {
+    optimize("int", "if (b) return 1; return 2;").into(
+        "if (b) return 1; return 2;");
+    optimize("int", "if (b) { return 1; }  return 2;").into(
+        "if (b) { return 1; } return 2;");
+    optimize("int", "if (b) { return 1;} else {return 2;}").into(
+        "return b ? 1 : 2;");
+    optimize("int", "if (b) return 1; else {return 2;}").into(
+        "return b ? 1 : 2;");
+    optimize("int", "if (b) return 1; else return 2;").into("return b ? 1 : 2;");
+    optimize("void", "if (b) return; else return;").into(
+        "if (b) return; else return;");
+  }
+
+  public void testIfStatementToBoolean_ThenElseOptimization() throws Exception {
+    optimize("void", "if (b) i = 1; else i = 2;").intoString(
+        "EntryPoint.b ? (EntryPoint.i = 1) : (EntryPoint.i = 2);");
+    optimize("void", "if (b) {i = 1;} else {i = 2;}").intoString(
+        "EntryPoint.b ? (EntryPoint.i = 1) : (EntryPoint.i = 2);");
+  }
+
+  public void testIfStatementToBoolean_ThenOptimization() throws Exception {
+    optimize("void", "if (b) i = 1;").intoString(
+        "EntryPoint.b && (EntryPoint.i = 1);");
+    optimize("void", "if (b) {i = 1;}").intoString(
+        "EntryPoint.b && (EntryPoint.i = 1);");
+  }
+
   private Result optimize(final String returnType, final String codeSnippet)
       throws UnableToCompleteException {
     JProgram program = compileSnippet(returnType, codeSnippet);