Move multi-expressions out of the test of an if.  That is:

  (a,b,c)?d:e  ->  (a,b,(c?d:e))

Review by: scottb (desk check)


git-svn-id: https://google-web-toolkit.googlecode.com/svn/trunk@2977 8db76d5a-ed1c-0410-87a9-c151d255dfc7
diff --git a/dev/core/src/com/google/gwt/dev/jjs/impl/DeadCodeElimination.java b/dev/core/src/com/google/gwt/dev/jjs/impl/DeadCodeElimination.java
index a9def73..c20e519 100644
--- a/dev/core/src/com/google/gwt/dev/jjs/impl/DeadCodeElimination.java
+++ b/dev/core/src/com/google/gwt/dev/jjs/impl/DeadCodeElimination.java
@@ -47,7 +47,6 @@
 import com.google.gwt.dev.jjs.ast.JParameterRef;
 import com.google.gwt.dev.jjs.ast.JPostfixOperation;
 import com.google.gwt.dev.jjs.ast.JPrefixOperation;
-import com.google.gwt.dev.jjs.ast.JPrimitiveType;
 import com.google.gwt.dev.jjs.ast.JProgram;
 import com.google.gwt.dev.jjs.ast.JReferenceType;
 import com.google.gwt.dev.jjs.ast.JStatement;
@@ -55,7 +54,6 @@
 import com.google.gwt.dev.jjs.ast.JSwitchStatement;
 import com.google.gwt.dev.jjs.ast.JTryStatement;
 import com.google.gwt.dev.jjs.ast.JType;
-import com.google.gwt.dev.jjs.ast.JUnaryOperation;
 import com.google.gwt.dev.jjs.ast.JUnaryOperator;
 import com.google.gwt.dev.jjs.ast.JValueLiteral;
 import com.google.gwt.dev.jjs.ast.JVariableRef;
@@ -85,8 +83,8 @@
    * operations in favor of pure side effects.
    * 
    * TODO(spoon): move more simplifications into methods like
-   * {@link #simplifyCast(JExpression, JType, JExpression) simplifyCast}, so
-   * that more simplifications can be made on a single pass through a tree.
+   * {@link #cast(JExpression, SourceInfo, JType, JExpression) simplifyCast},
+   * so that more simplifications can be made on a single pass through a tree.
    */
   public class DeadCodeVisitor extends JModVisitor {
 
@@ -105,14 +103,14 @@
      * existing uses of <code>ignoringExpressionOutput</code> are with mutable
      * nodes.
      */
-    private Set<JExpression> ignoringExpressionOutput = new HashSet<JExpression>();
+    private final Set<JExpression> ignoringExpressionOutput = new HashSet<JExpression>();
 
     /**
      * Expressions being used as lvalues.
      */
-    private Set<JExpression> lvalues = new HashSet<JExpression>();
+    private final Set<JExpression> lvalues = new HashSet<JExpression>();
 
-    private Set<JBlock> switchBlocks = new HashSet<JBlock>();
+    private final Set<JBlock> switchBlocks = new HashSet<JBlock>();
 
     /**
      * Short circuit binary operations.
@@ -218,6 +216,21 @@
           }
         }
 
+        if (stmt instanceof JExpressionStatement) {
+          JExpressionStatement stmtExpr = (JExpressionStatement) stmt;
+          if (stmtExpr.getExpr() instanceof JMultiExpression) {
+            // Promote a multi's expressions to the current block
+            x.statements.remove(i);
+            int start = i;
+            JMultiExpression multi = ((JMultiExpression) stmtExpr.getExpr());
+            for (JExpression expr : multi.exprs) {
+              x.statements.add(i++, expr.makeStatement());
+            }
+            i = start - 1;
+            continue;
+          }
+        }
+
         if (stmt.unconditionalControlBreak()) {
           // Abrupt change in flow, chop the remaining items from this block
           for (int j = i + 1; j < x.statements.size();) {
@@ -235,7 +248,8 @@
 
     @Override
     public void endVisit(JCastOperation x, Context ctx) {
-      JExpression updated = simplifyCast(x, x.getCastType(), x.getExpr());
+      JExpression updated = simplifier.cast(x, x.getSourceInfo(),
+          x.getCastType(), x.getExpr());
       if (updated != x) {
         ctx.replaceMe(updated);
       }
@@ -248,57 +262,10 @@
 
     @Override
     public void endVisit(JConditional x, Context ctx) {
-      JExpression condExpr = x.getIfTest();
-      JExpression thenExpr = x.getThenExpr();
-      JExpression elseExpr = x.getElseExpr();
-      if (condExpr instanceof JBooleanLiteral) {
-        if (((JBooleanLiteral) condExpr).getValue()) {
-          // e.g. (true ? then : else) -> then
-          ctx.replaceMe(thenExpr);
-        } else {
-          // e.g. (false ? then : else) -> else
-          ctx.replaceMe(elseExpr);
-        }
-      } else if (thenExpr instanceof JBooleanLiteral) {
-        if (((JBooleanLiteral) thenExpr).getValue()) {
-          // e.g. (cond ? true : else) -> cond || else
-          JBinaryOperation binOp = new JBinaryOperation(program,
-              x.getSourceInfo(), x.getType(), JBinaryOperator.OR, condExpr,
-              elseExpr);
-          ctx.replaceMe(binOp);
-        } else {
-          // e.g. (cond ? false : else) -> !cond && else
-          JPrefixOperation notCondExpr = new JPrefixOperation(program,
-              condExpr.getSourceInfo(), JUnaryOperator.NOT, condExpr);
-          JBinaryOperation binOp = new JBinaryOperation(program,
-              x.getSourceInfo(), x.getType(), JBinaryOperator.AND, notCondExpr,
-              elseExpr);
-          ctx.replaceMe(binOp);
-        }
-      } else if (elseExpr instanceof JBooleanLiteral) {
-        if (((JBooleanLiteral) elseExpr).getValue()) {
-          // e.g. (cond ? then : true) -> !cond || then
-          JPrefixOperation notCondExpr = new JPrefixOperation(program,
-              condExpr.getSourceInfo(), JUnaryOperator.NOT, condExpr);
-          JBinaryOperation binOp = new JBinaryOperation(program,
-              x.getSourceInfo(), x.getType(), JBinaryOperator.OR, notCondExpr,
-              thenExpr);
-          ctx.replaceMe(binOp);
-        } else {
-          // e.g. (cond ? then : false) -> cond && then
-          JBinaryOperation binOp = new JBinaryOperation(program,
-              x.getSourceInfo(), x.getType(), JBinaryOperator.AND, condExpr,
-              thenExpr);
-          ctx.replaceMe(binOp);
-        }
-      } else {
-        // e.g. (!cond ? then : else) -> (cond ? else : then)
-        JExpression unflipped = maybeUnflipBoolean(condExpr);
-        if (unflipped != null) {
-          ctx.replaceMe(new JConditional(program, x.getSourceInfo(),
-              x.getType(), unflipped, elseExpr, thenExpr));
-          return;
-        }
+      JExpression updated = simplifier.conditional(x, x.getSourceInfo(),
+          x.getType(), x.getIfTest(), x.getThenExpr(), x.getElseExpr());
+      if (updated != x) {
+        ctx.replaceMe(updated);
       }
     }
 
@@ -392,42 +359,10 @@
      */
     @Override
     public void endVisit(JIfStatement x, Context ctx) {
-      JExpression expr = x.getIfExpr();
-      JStatement thenStmt = x.getThenStmt();
-      JStatement elseStmt = x.getElseStmt();
-      if (expr instanceof JBooleanLiteral) {
-        JBooleanLiteral booleanLiteral = (JBooleanLiteral) expr;
-        boolean boolVal = booleanLiteral.getValue();
-        if (boolVal && !isEmpty(thenStmt)) {
-          // If true, replace myself with then statement
-          ctx.replaceMe(thenStmt);
-        } else if (!boolVal && !isEmpty(elseStmt)) {
-          // If false, replace myself with else statement
-          ctx.replaceMe(elseStmt);
-        } else {
-          // just prune me
-          removeMe(x, ctx);
-        }
-        return;
-      }
-
-      if (isEmpty(thenStmt) && isEmpty(elseStmt)) {
-        ctx.replaceMe(expr.makeStatement());
-        return;
-      }
-
-      if (!isEmpty(elseStmt)) {
-        // if (!cond) foo else bar -> if (cond) bar else foo
-        JExpression unflipped = maybeUnflipBoolean(expr);
-        if (unflipped != null) {
-          // Force sub-parts to blocks, otherwise we break else-if chains.
-          // TODO: this goes away when we normalize the Java AST properly.
-          thenStmt = ensureBlock(thenStmt);
-          elseStmt = ensureBlock(elseStmt);
-          ctx.replaceMe(new JIfStatement(program, x.getSourceInfo(), unflipped,
-              elseStmt, thenStmt));
-          return;
-        }
+      JStatement updated = simplifier.ifStatement(x, x.getSourceInfo(),
+          x.getIfExpr(), x.getThenStmt(), x.getElseStmt());
+      if (updated != x) {
+        ctx.replaceMe(updated);
       }
     }
 
@@ -550,46 +485,11 @@
         }
       }
       if (x.getOp() == JUnaryOperator.NOT) {
-        JExpression arg = x.getArg();
-        if (arg instanceof JBinaryOperation) {
-          // try to invert the binary operator
-          JBinaryOperation argOp = (JBinaryOperation) arg;
-          JBinaryOperator op = argOp.getOp();
-          JBinaryOperator newOp = null;
-          if (op == JBinaryOperator.EQ) {
-            // e.g. !(x == y) -> x != y
-            newOp = JBinaryOperator.NEQ;
-          } else if (op == JBinaryOperator.NEQ) {
-            // e.g. !(x != y) -> x == y
-            newOp = JBinaryOperator.EQ;
-          } else if (op == JBinaryOperator.GT) {
-            // e.g. !(x > y) -> x <= y
-            newOp = JBinaryOperator.LTE;
-          } else if (op == JBinaryOperator.LTE) {
-            // e.g. !(x <= y) -> x > y
-            newOp = JBinaryOperator.GT;
-          } else if (op == JBinaryOperator.GTE) {
-            // e.g. !(x >= y) -> x < y
-            newOp = JBinaryOperator.LT;
-          } else if (op == JBinaryOperator.LT) {
-            // e.g. !(x < y) -> x >= y
-            newOp = JBinaryOperator.GTE;
-          }
-          if (newOp != null) {
-            JBinaryOperation newBinOp = new JBinaryOperation(program,
-                argOp.getSourceInfo(), argOp.getType(), newOp, argOp.getLhs(),
-                argOp.getRhs());
-            ctx.replaceMe(newBinOp);
-          }
-        } else if (arg instanceof JPrefixOperation) {
-          // try to invert the unary operator
-          JPrefixOperation argOp = (JPrefixOperation) arg;
-          JUnaryOperator op = argOp.getOp();
-          // e.g. !!x -> x
-          if (op == JUnaryOperator.NOT) {
-            ctx.replaceMe(argOp.getArg());
-          }
+        JExpression updated = simplifier.not(x, x.getSourceInfo(), x.getArg());
+        if (updated != x) {
+          ctx.replaceMe(updated);
         }
+        return;
       } else if (x.getOp() == JUnaryOperator.NEG) {
         JExpression updated = simplifyNegate(x, x.getArg());
         if (updated != x) {
@@ -637,9 +537,9 @@
       }
 
       // Compute properties regarding the state of this try statement
-      boolean noTry = isEmpty(x.getTryBlock());
+      boolean noTry = Simplifier.isEmpty(x.getTryBlock());
       boolean noCatch = catchArgs.size() == 0;
-      boolean noFinally = isEmpty(x.getFinallyBlock());
+      boolean noFinally = Simplifier.isEmpty(x.getFinallyBlock());
 
       if (noTry) {
         // 2) Prune try statements with no body.
@@ -761,15 +661,6 @@
       return true;
     }
 
-    private JStatement ensureBlock(JStatement stmt) {
-      if (!(stmt instanceof JBlock)) {
-        JBlock block = new JBlock(program, stmt.getSourceInfo());
-        block.statements.add(stmt);
-        stmt = block;
-      }
-      return stmt;
-    }
-
     private void evalConcat(JExpression lhs, JExpression rhs, Context ctx) {
       if (lhs instanceof JValueLiteral && rhs instanceof JValueLiteral) {
         Object lhsObj = ((JValueLiteral) lhs).getValueObj();
@@ -1125,16 +1016,6 @@
       return true;
     }
 
-    /**
-     * TODO: if the AST were normalized, we wouldn't need this.
-     */
-    private boolean isEmpty(JStatement stmt) {
-      if (stmt == null) {
-        return true;
-      }
-      return (stmt instanceof JBlock && ((JBlock) stmt).statements.isEmpty());
-    }
-
     private boolean isLiteralNegativeOne(JExpression exp) {
       if (exp instanceof JValueLiteral) {
         JValueLiteral lit = (JValueLiteral) exp;
@@ -1284,20 +1165,6 @@
       return call;
     }
 
-    /**
-     * Negate the supplied expression if negating it makes the expression
-     * shorter. Otherwise, return null.
-     */
-    private JExpression maybeUnflipBoolean(JExpression expr) {
-      if (expr instanceof JUnaryOperation) {
-        JUnaryOperation unop = (JUnaryOperation) expr;
-        if (unop.getOp() == JUnaryOperator.NOT) {
-          return unop.getArg();
-        }
-      }
-      return null;
-    }
-
     private int numRemovableExpressions(JMultiExpression x) {
       if (ignoringExpressionOutput.contains(x)) {
         // The result doesn't matter: all expressions can be removed.
@@ -1437,11 +1304,11 @@
     private boolean simplifyAdd(JExpression lhs, JExpression rhs, Context ctx,
         JType type) {
       if (isLiteralZero(rhs)) {
-        ctx.replaceMe(simplifyCast(type, lhs));
+        ctx.replaceMe(simplifier.cast(type, lhs));
         return true;
       }
       if (isLiteralZero(lhs)) {
-        ctx.replaceMe(simplifyCast(type, rhs));
+        ctx.replaceMe(simplifier.cast(type, rhs));
         return true;
       }
 
@@ -1482,66 +1349,14 @@
       }
     }
 
-    /**
-     * Simplify a cast operation. Return <code>original</code> if it is
-     * equivalent to the desired return value.
-     * 
-     * TODO: Simplify casts of casts, e.g. (int)(long)foo.
-     * 
-     * @param original Either <code>null</code>, or a cast from
-     *          <code>exp</code> to <code>type</code>
-     * @param type The type to cast to
-     * @param exp The expression being cast
-     * @return An expression equivalent to a cast from <code>exp</code> to
-     *         <code>type</code>, but possibly simplified
-     */
-    private JExpression simplifyCast(JExpression original, JType type,
-        JExpression exp) {
-      if (type == exp.getType()) {
-        return exp;
-      }
-      if ((type instanceof JPrimitiveType) && (exp instanceof JValueLiteral)) {
-        // Statically evaluate casting literals.
-        JPrimitiveType typePrim = (JPrimitiveType) type;
-        JValueLiteral expLit = (JValueLiteral) exp;
-        JValueLiteral casted = typePrim.coerceLiteral(expLit);
-        if (casted != null) {
-          return casted;
-        }
-      }
-
-      /*
-       * Discard casts from byte or short to int, because such casts are always
-       * implicit anyway. Cannot coerce char since that would change the
-       * semantics of concat.
-       */
-      if (type == program.getTypePrimitiveInt()) {
-        JType expType = exp.getType();
-        if ((expType == program.getTypePrimitiveShort())
-            || (expType == program.getTypePrimitiveByte())) {
-          return exp;
-        }
-      }
-
-      // no simplification made
-      if (original != null) {
-        return original;
-      }
-      return new JCastOperation(program, exp.getSourceInfo(), type, exp);
-    }
-
-    private JExpression simplifyCast(JType type, JExpression exp) {
-      return simplifyCast(null, type, exp);
-    }
-
     private boolean simplifyDiv(JExpression lhs, JExpression rhs, Context ctx,
         JType type) {
       if (isLiteralOne(rhs)) {
-        ctx.replaceMe(simplifyCast(type, lhs));
+        ctx.replaceMe(simplifier.cast(type, lhs));
         return true;
       }
       if (isLiteralNegativeOne(rhs)) {
-        ctx.replaceMe(simplifyNegate(simplifyCast(type, lhs)));
+        ctx.replaceMe(simplifyNegate(simplifier.cast(type, lhs)));
         return true;
       }
 
@@ -1563,27 +1378,27 @@
     private boolean simplifyMul(JExpression lhs, JExpression rhs, Context ctx,
         JType type) {
       if (isLiteralOne(rhs)) {
-        ctx.replaceMe(simplifyCast(type, lhs));
+        ctx.replaceMe(simplifier.cast(type, lhs));
         return true;
       }
       if (isLiteralOne(lhs)) {
-        ctx.replaceMe(simplifyCast(type, rhs));
+        ctx.replaceMe(simplifier.cast(type, rhs));
         return true;
       }
       if (isLiteralNegativeOne(rhs)) {
-        ctx.replaceMe(simplifyNegate(simplifyCast(type, lhs)));
+        ctx.replaceMe(simplifyNegate(simplifier.cast(type, lhs)));
         return true;
       }
       if (isLiteralNegativeOne(lhs)) {
-        ctx.replaceMe(simplifyNegate(simplifyCast(type, rhs)));
+        ctx.replaceMe(simplifyNegate(simplifier.cast(type, rhs)));
         return true;
       }
       if (isLiteralZero(rhs) && !lhs.hasSideEffects()) {
-        ctx.replaceMe(simplifyCast(type, rhs));
+        ctx.replaceMe(simplifier.cast(type, rhs));
         return true;
       }
       if (isLiteralZero(lhs) && !rhs.hasSideEffects()) {
-        ctx.replaceMe(simplifyCast(type, lhs));
+        ctx.replaceMe(simplifier.cast(type, lhs));
         return true;
       }
       return false;
@@ -1621,11 +1436,11 @@
     private boolean simplifySub(JExpression lhs, JExpression rhs, Context ctx,
         JType type) {
       if (isLiteralZero(rhs)) {
-        ctx.replaceMe(simplifyCast(type, lhs));
+        ctx.replaceMe(simplifier.cast(type, lhs));
         return true;
       }
       if (isLiteralZero(lhs)) {
-        ctx.replaceMe(simplifyNegate(simplifyCast(type, rhs)));
+        ctx.replaceMe(simplifyNegate(simplifier.cast(type, rhs)));
         return true;
       }
       return false;
@@ -1906,11 +1721,13 @@
   }
 
   private final JProgram program;
+  private final Simplifier simplifier;
 
   private final Map<JType, Class<?>> typeClassMap = new IdentityHashMap<JType, Class<?>>();
 
   public DeadCodeElimination(JProgram program) {
     this.program = program;
+    simplifier = new Simplifier(program);
     typeClassMap.put(program.getTypeJavaLangObject(), Object.class);
     typeClassMap.put(program.getTypeJavaLangString(), String.class);
     typeClassMap.put(program.getTypePrimitiveBoolean(), boolean.class);
diff --git a/dev/core/src/com/google/gwt/dev/jjs/impl/LongCastNormalizer.java b/dev/core/src/com/google/gwt/dev/jjs/impl/LongCastNormalizer.java
index cba4ebe..74ba0cb 100644
--- a/dev/core/src/com/google/gwt/dev/jjs/impl/LongCastNormalizer.java
+++ b/dev/core/src/com/google/gwt/dev/jjs/impl/LongCastNormalizer.java
@@ -17,10 +17,9 @@
 
 import com.google.gwt.dev.jjs.ast.Context;
 import com.google.gwt.dev.jjs.ast.JBinaryOperation;
-import com.google.gwt.dev.jjs.ast.JCastOperation;
 import com.google.gwt.dev.jjs.ast.JConditional;
-import com.google.gwt.dev.jjs.ast.JExpression;
 import com.google.gwt.dev.jjs.ast.JDeclarationStatement;
+import com.google.gwt.dev.jjs.ast.JExpression;
 import com.google.gwt.dev.jjs.ast.JMethod;
 import com.google.gwt.dev.jjs.ast.JMethodCall;
 import com.google.gwt.dev.jjs.ast.JModVisitor;
@@ -30,7 +29,6 @@
 import com.google.gwt.dev.jjs.ast.JProgram;
 import com.google.gwt.dev.jjs.ast.JReturnStatement;
 import com.google.gwt.dev.jjs.ast.JType;
-import com.google.gwt.dev.jjs.ast.JValueLiteral;
 
 import java.util.List;
 
@@ -182,28 +180,14 @@
     /**
      * Returns an explicit cast if the target type is long and the input
      * expression is not a long, or if the target type is floating point and the
-     * expression is a long.
+     * expression is a long. TODO(spoon): there is no floating point in this
+     * method; update the comment
      */
     private JExpression checkAndReplace(JExpression arg, JType targetType) {
-      JType argType = arg.getType();
-      if (targetType == argType) {
+      if (targetType != longType && arg.getType() != longType) {
         return arg;
       }
-      if (targetType != longType && argType != longType) {
-        return arg;
-      }
-      if (arg instanceof JValueLiteral && targetType instanceof JPrimitiveType) {
-        // Attempt to coerce the literal.
-        JPrimitiveType primitiveType = (JPrimitiveType) targetType;
-        JValueLiteral coerced = primitiveType.coerceLiteral((JValueLiteral) arg);
-        if (coerced != null) {
-          return coerced;
-        }
-      }
-      // Synthesize a cast to long to force explicit conversion.
-      JCastOperation cast = new JCastOperation(program, arg.getSourceInfo(),
-          targetType, arg);
-      return cast;
+      return simplifier.cast(targetType, arg);
     }
   }
 
@@ -212,9 +196,11 @@
   }
 
   private final JProgram program;
+  private final Simplifier simplifier;
 
   private LongCastNormalizer(JProgram program) {
     this.program = program;
+    simplifier = new Simplifier(program);
   }
 
   private void execImpl() {
@@ -222,5 +208,4 @@
         program.getTypePrimitiveLong());
     visitor.accept(program);
   }
-
 }
diff --git a/dev/core/src/com/google/gwt/dev/jjs/impl/Simplifier.java b/dev/core/src/com/google/gwt/dev/jjs/impl/Simplifier.java
new file mode 100644
index 0000000..f103083
--- /dev/null
+++ b/dev/core/src/com/google/gwt/dev/jjs/impl/Simplifier.java
@@ -0,0 +1,315 @@
+/*
+ * Copyright 2008 Google Inc.
+ * 
+ * Licensed under the Apache License, Version 2.0 (the "License"); you may not
+ * use this file except in compliance with the License. You may obtain a copy of
+ * the License at
+ * 
+ * http://www.apache.org/licenses/LICENSE-2.0
+ * 
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+ * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+ * License for the specific language governing permissions and limitations under
+ * the License.
+ */
+package com.google.gwt.dev.jjs.impl;
+
+import com.google.gwt.dev.jjs.SourceInfo;
+import com.google.gwt.dev.jjs.ast.JBinaryOperation;
+import com.google.gwt.dev.jjs.ast.JBinaryOperator;
+import com.google.gwt.dev.jjs.ast.JBlock;
+import com.google.gwt.dev.jjs.ast.JBooleanLiteral;
+import com.google.gwt.dev.jjs.ast.JCastOperation;
+import com.google.gwt.dev.jjs.ast.JConditional;
+import com.google.gwt.dev.jjs.ast.JExpression;
+import com.google.gwt.dev.jjs.ast.JIfStatement;
+import com.google.gwt.dev.jjs.ast.JPrefixOperation;
+import com.google.gwt.dev.jjs.ast.JPrimitiveType;
+import com.google.gwt.dev.jjs.ast.JProgram;
+import com.google.gwt.dev.jjs.ast.JStatement;
+import com.google.gwt.dev.jjs.ast.JType;
+import com.google.gwt.dev.jjs.ast.JUnaryOperation;
+import com.google.gwt.dev.jjs.ast.JUnaryOperator;
+import com.google.gwt.dev.jjs.ast.JValueLiteral;
+import com.google.gwt.dev.jjs.ast.js.JMultiExpression;
+
+import java.util.List;
+
+/**
+ * Methods that both construct and try to simplify AST nodes. If simplification
+ * fails, then the methods will return an original, unmodified version of the
+ * node if one is supplied. The routines do not recurse into their arguments;
+ * the arguments are assumed to already be simplified as much as possible.
+ */
+public class Simplifier {
+  /**
+   * TODO: if the AST were normalized, we wouldn't need this.
+   */
+  public static boolean isEmpty(JStatement stmt) {
+    if (stmt == null) {
+      return true;
+    }
+    return (stmt instanceof JBlock && ((JBlock) stmt).statements.isEmpty());
+  }
+
+  /**
+   * Negate the supplied expression if negating it makes the expression shorter.
+   * Otherwise, return null.
+   */
+  static JExpression maybeUnflipBoolean(JExpression expr) {
+    if (expr instanceof JUnaryOperation) {
+      JUnaryOperation unop = (JUnaryOperation) expr;
+      if (unop.getOp() == JUnaryOperator.NOT) {
+        return unop.getArg();
+      }
+    }
+    return null;
+  }
+
+  private static <T> List<T> allButLast(List<T> list) {
+    return list.subList(0, list.size() - 1);
+  }
+
+  private static <T> T last(List<T> list) {
+    return list.get(list.size() - 1);
+  }
+
+  private final JProgram program;
+
+  public Simplifier(JProgram program) {
+    this.program = program;
+  }
+
+  public JExpression cast(JExpression original, SourceInfo sourceInfo,
+      JType type, JExpression exp) {
+    if (type == exp.getType()) {
+      return exp;
+    }
+    if ((type instanceof JPrimitiveType) && (exp instanceof JValueLiteral)) {
+      // Statically evaluate casting literals.
+      JPrimitiveType typePrim = (JPrimitiveType) type;
+      JValueLiteral expLit = (JValueLiteral) exp;
+      JValueLiteral casted = typePrim.coerceLiteral(expLit);
+      if (casted != null) {
+        return casted;
+      }
+    }
+
+    /*
+     * Discard casts from byte or short to int, because such casts are always
+     * implicit anyway. Cannot coerce char since that would change the semantics
+     * of concat.
+     */
+    if (type == program.getTypePrimitiveInt()) {
+      JType expType = exp.getType();
+      if ((expType == program.getTypePrimitiveShort())
+          || (expType == program.getTypePrimitiveByte())) {
+        return exp;
+      }
+    }
+
+    // no simplification made
+    if (original != null) {
+      return original;
+    }
+    return new JCastOperation(program, exp.getSourceInfo(), type, exp);
+  }
+
+  public JExpression cast(JType type, JExpression exp) {
+    return cast(null, exp.getSourceInfo(), type, exp);
+  }
+
+  public JExpression conditional(JConditional original, SourceInfo sourceInfo,
+      JType type, JExpression condExpr, JExpression thenExpr,
+      JExpression elseExpr) {
+    if (condExpr instanceof JMultiExpression) {
+      // (a,b,c)?d:e -> a,b,(c?d:e)
+      // TODO(spoon): do this outward multi movement for all AST nodes
+      JMultiExpression condMulti = (JMultiExpression) condExpr;
+      JMultiExpression newMulti = new JMultiExpression(program, sourceInfo);
+      newMulti.exprs.addAll(allButLast(condMulti.exprs));
+      newMulti.exprs.add(conditional(null, sourceInfo, type,
+          last(condMulti.exprs), thenExpr, elseExpr));
+      // TODO(spoon): immediately simplify the resulting multi
+      return newMulti;
+    }
+    if (condExpr instanceof JBooleanLiteral) {
+      if (((JBooleanLiteral) condExpr).getValue()) {
+        // e.g. (true ? then : else) -> then
+        return thenExpr;
+      } else {
+        // e.g. (false ? then : else) -> else
+        return elseExpr;
+      }
+    } else if (thenExpr instanceof JBooleanLiteral) {
+      if (((JBooleanLiteral) thenExpr).getValue()) {
+        // e.g. (cond ? true : else) -> cond || else
+        JBinaryOperation binOp = new JBinaryOperation(program,
+            original.getSourceInfo(), original.getType(), JBinaryOperator.OR,
+            condExpr, elseExpr);
+        return binOp;
+      } else {
+        // e.g. (cond ? false : else) -> !cond && else
+        JPrefixOperation notCondExpr = new JPrefixOperation(program,
+            condExpr.getSourceInfo(), JUnaryOperator.NOT, condExpr);
+        JBinaryOperation binOp = new JBinaryOperation(program,
+            original.getSourceInfo(), original.getType(), JBinaryOperator.AND,
+            notCondExpr, elseExpr);
+        return binOp;
+      }
+    } else if (elseExpr instanceof JBooleanLiteral) {
+      if (((JBooleanLiteral) elseExpr).getValue()) {
+        // e.g. (cond ? then : true) -> !cond || then
+        JPrefixOperation notCondExpr = new JPrefixOperation(program,
+            condExpr.getSourceInfo(), JUnaryOperator.NOT, condExpr);
+        JBinaryOperation binOp = new JBinaryOperation(program,
+            original.getSourceInfo(), original.getType(), JBinaryOperator.OR,
+            notCondExpr, thenExpr);
+        return binOp;
+      } else {
+        // e.g. (cond ? then : false) -> cond && then
+        JBinaryOperation binOp = new JBinaryOperation(program,
+            original.getSourceInfo(), original.getType(), JBinaryOperator.AND,
+            condExpr, thenExpr);
+        return binOp;
+      }
+    } else {
+      // e.g. (!cond ? then : else) -> (cond ? else : then)
+      JExpression unflipped = maybeUnflipBoolean(condExpr);
+      if (unflipped != null) {
+        return new JConditional(program, original.getSourceInfo(),
+            original.getType(), unflipped, elseExpr, thenExpr);
+      }
+    }
+
+    // no simplification made
+    if (original != null) {
+      return original;
+    }
+    return new JConditional(program, sourceInfo, type, condExpr, thenExpr,
+        elseExpr);
+  }
+
+  public JStatement ifStatement(JIfStatement original, SourceInfo sourceInfo,
+      JExpression condExpr, JStatement thenStmt, JStatement elseStmt) {
+    if (condExpr instanceof JMultiExpression) {
+      // if(a,b,c) d else e -> {a; b; if(c) d else e; }
+      JMultiExpression condMulti = (JMultiExpression) condExpr;
+      JBlock newBlock = new JBlock(program, sourceInfo);
+      for (JExpression expr : allButLast(condMulti.exprs)) {
+        newBlock.statements.add(expr.makeStatement());
+      }
+      newBlock.statements.add(ifStatement(null, sourceInfo,
+          last(condMulti.exprs), thenStmt, elseStmt));
+      // TODO(spoon): immediately simplify the resulting block
+      return newBlock;
+    }
+
+    if (condExpr instanceof JBooleanLiteral) {
+      JBooleanLiteral booleanLiteral = (JBooleanLiteral) condExpr;
+      boolean boolVal = booleanLiteral.getValue();
+      if (boolVal && !isEmpty(thenStmt)) {
+        // If true, replace myself with then statement
+        return thenStmt;
+      } else if (!boolVal && !isEmpty(elseStmt)) {
+        // If false, replace myself with else statement
+        return elseStmt;
+      } else {
+        // just prune me
+        return condExpr.makeStatement();
+      }
+    }
+
+    if (isEmpty(thenStmt) && isEmpty(elseStmt)) {
+      return condExpr.makeStatement();
+    }
+
+    if (!isEmpty(elseStmt)) {
+      // if (!cond) foo else bar -> if (cond) bar else foo
+      JExpression unflipped = Simplifier.maybeUnflipBoolean(condExpr);
+      if (unflipped != null) {
+        // Force sub-parts to blocks, otherwise we break else-if chains.
+        // TODO: this goes away when we normalize the Java AST properly.
+        thenStmt = ensureBlock(thenStmt);
+        elseStmt = ensureBlock(elseStmt);
+        return ifStatement(null, sourceInfo, unflipped, elseStmt, thenStmt);
+      }
+    }
+
+    // no simplification made
+    if (original != null) {
+      return original;
+    }
+    return new JIfStatement(program, condExpr.getSourceInfo(), condExpr,
+        thenStmt, elseStmt);
+  }
+
+  public JExpression not(JPrefixOperation original, SourceInfo sourceInfo,
+      JExpression arg) {
+    if (arg instanceof JMultiExpression) {
+      // !(a,b,c) -> (a,b,!c)
+      JMultiExpression argMulti = (JMultiExpression) arg;
+      JMultiExpression newMulti = new JMultiExpression(program, sourceInfo);
+      newMulti.exprs.addAll(allButLast(argMulti.exprs));
+      newMulti.exprs.add(not(null, sourceInfo, last(argMulti.exprs)));
+      // TODO(spoon): immediately simplify the newMulti
+      return newMulti;
+    }
+    if (arg instanceof JBinaryOperation) {
+      // try to invert the binary operator
+      JBinaryOperation argOp = (JBinaryOperation) arg;
+      JBinaryOperator op = argOp.getOp();
+      JBinaryOperator newOp = null;
+      if (op == JBinaryOperator.EQ) {
+        // e.g. !(x == y) -> x != y
+        newOp = JBinaryOperator.NEQ;
+      } else if (op == JBinaryOperator.NEQ) {
+        // e.g. !(x != y) -> x == y
+        newOp = JBinaryOperator.EQ;
+      } else if (op == JBinaryOperator.GT) {
+        // e.g. !(x > y) -> x <= y
+        newOp = JBinaryOperator.LTE;
+      } else if (op == JBinaryOperator.LTE) {
+        // e.g. !(x <= y) -> x > y
+        newOp = JBinaryOperator.GT;
+      } else if (op == JBinaryOperator.GTE) {
+        // e.g. !(x >= y) -> x < y
+        newOp = JBinaryOperator.LT;
+      } else if (op == JBinaryOperator.LT) {
+        // e.g. !(x < y) -> x >= y
+        newOp = JBinaryOperator.GTE;
+      }
+      if (newOp != null) {
+        JBinaryOperation newBinOp = new JBinaryOperation(program,
+            argOp.getSourceInfo(), argOp.getType(), newOp, argOp.getLhs(),
+            argOp.getRhs());
+        return newBinOp;
+      }
+    } else if (arg instanceof JPrefixOperation) {
+      // try to invert the unary operator
+      JPrefixOperation argOp = (JPrefixOperation) arg;
+      JUnaryOperator op = argOp.getOp();
+      // e.g. !!x -> x
+      if (op == JUnaryOperator.NOT) {
+        return argOp.getArg();
+      }
+    }
+
+    // no simplification made
+    if (original != null) {
+      return original;
+    }
+    return new JPrefixOperation(program, arg.getSourceInfo(),
+        JUnaryOperator.NOT, arg);
+  }
+
+  private JStatement ensureBlock(JStatement stmt) {
+    if (!(stmt instanceof JBlock)) {
+      JBlock block = new JBlock(program, stmt.getSourceInfo());
+      block.statements.add(stmt);
+      stmt = block;
+    }
+    return stmt;
+  }
+}