Fix handling of method references with type parameters.

Change-Id: I3921730dcffa842d5c41d4b36c10f04d483e22ee
diff --git a/dev/core/src/com/google/gwt/dev/jjs/impl/GwtAstBuilder.java b/dev/core/src/com/google/gwt/dev/jjs/impl/GwtAstBuilder.java
index f8f9beb..cb066da 100644
--- a/dev/core/src/com/google/gwt/dev/jjs/impl/GwtAstBuilder.java
+++ b/dev/core/src/com/google/gwt/dev/jjs/impl/GwtAstBuilder.java
@@ -877,117 +877,113 @@
 
     @Override
     public void endVisit(ForeachStatement x, BlockScope scope) {
-      try {
-        SourceInfo info = makeSourceInfo(x);
+      SourceInfo info = makeSourceInfo(x);
 
-        JBlock body = popBlock(info, x.action);
-        JExpression collection = pop(x.collection);
-        JDeclarationStatement elementDecl = pop(x.elementVariable);
-        assert (elementDecl.initializer == null);
+      JBlock body = popBlock(info, x.action);
+      JExpression collection = pop(x.collection);
+      JDeclarationStatement elementDecl = pop(x.elementVariable);
+      assert (elementDecl.initializer == null);
 
-        JLocal elementVar = (JLocal) curMethod.locals.get(x.elementVariable.binding);
-        String elementVarName = elementVar.getName();
+      JLocal elementVar = (JLocal) curMethod.locals.get(x.elementVariable.binding);
+      String elementVarName = elementVar.getName();
 
-        JForStatement result;
-        if (x.collectionVariable != null) {
-          /**
-           * <pre>
-         * for (final T[] i$array = collection,
-         *          int i$index = 0,
-         *          final int i$max = i$array.length;
-         *      i$index < i$max; ++i$index) {
-         *   T elementVar = i$array[i$index];
+      JForStatement result;
+      if (x.collectionVariable != null) {
+        /**
+         * <pre>
+       * for (final T[] i$array = collection,
+       *          int i$index = 0,
+       *          final int i$max = i$array.length;
+       *      i$index < i$max; ++i$index) {
+       *   T elementVar = i$array[i$index];
+       *   // user action
+       * }
+       * </pre>
+         */
+        JLocal arrayVar = JProgram.createLocal(info, elementVarName + "$array",
+            typeMap.get(x.collection.resolvedType), true, curMethod.body);
+        JLocal indexVar =
+            JProgram.createLocal(info, elementVarName + "$index", JPrimitiveType.INT, false,
+                curMethod.body);
+        JLocal maxVar =
+            JProgram.createLocal(info, elementVarName + "$max", JPrimitiveType.INT, true,
+                curMethod.body);
+
+        List<JStatement> initializers = Lists.newArrayListWithCapacity(3);
+        // T[] i$array = arr
+        initializers.add(makeDeclaration(info, arrayVar, collection));
+        // int i$index = 0
+        initializers.add(makeDeclaration(info, indexVar, JIntLiteral.get(0)));
+        // int i$max = i$array.length
+        initializers.add(makeDeclaration(info, maxVar,
+            new JArrayLength(info, arrayVar.makeRef(info))));
+
+        // i$index < i$max
+        JExpression condition =
+            new JBinaryOperation(info, JPrimitiveType.BOOLEAN, JBinaryOperator.LT,
+                indexVar.makeRef(info), maxVar.makeRef(info));
+
+        // ++i$index
+        JExpression increments = new JPrefixOperation(info, JUnaryOperator.INC,
+            indexVar.makeRef(info));
+
+        // T elementVar = i$array[i$index];
+        elementDecl.initializer =
+            new JArrayRef(info, arrayVar.makeRef(info), indexVar.makeRef(info));
+        body.addStmt(0, elementDecl);
+
+        result = new JForStatement(info, initializers, condition, increments, body);
+      } else {
+        /**
+         * <pre>
+         * for (Iterator&lt;T&gt; i$iterator = collection.iterator(); i$iterator.hasNext();) {
+         *   T elementVar = i$iterator.next();
          *   // user action
          * }
          * </pre>
-           */
-          JLocal arrayVar = JProgram.createLocal(info, elementVarName + "$array",
-              typeMap.get(x.collection.resolvedType), true, curMethod.body);
-          JLocal indexVar =
-              JProgram.createLocal(info, elementVarName + "$index", JPrimitiveType.INT, false,
-                  curMethod.body);
-          JLocal maxVar =
-              JProgram.createLocal(info, elementVarName + "$max", JPrimitiveType.INT, true,
-                  curMethod.body);
+         */
+        CompilationUnitScope cudScope = scope.compilationUnitScope();
+        ReferenceBinding javaUtilIterator = scope.getJavaUtilIterator();
+        ReferenceBinding javaLangIterable = scope.getJavaLangIterable();
+        MethodBinding iterator = javaLangIterable.getExactMethod(ITERATOR_, NO_TYPES, cudScope);
+        MethodBinding hasNext = javaUtilIterator.getExactMethod(HAS_NEXT_, NO_TYPES, cudScope);
+        MethodBinding next = javaUtilIterator.getExactMethod(NEXT_, NO_TYPES, cudScope);
+        JLocal iteratorVar =
+            JProgram.createLocal(info, (elementVarName + "$iterator"), typeMap
+                .get(javaUtilIterator), false, curMethod.body);
 
-          List<JStatement> initializers = Lists.newArrayListWithCapacity(3);
-          // T[] i$array = arr
-          initializers.add(makeDeclaration(info, arrayVar, collection));
-          // int i$index = 0
-          initializers.add(makeDeclaration(info, indexVar, JIntLiteral.get(0)));
-          // int i$max = i$array.length
-          initializers.add(makeDeclaration(info, maxVar,
-              new JArrayLength(info, arrayVar.makeRef(info))));
+        List<JStatement> initializers = Lists.newArrayListWithCapacity(1);
+        // Iterator<T> i$iterator = collection.iterator()
+        initializers.add(makeDeclaration(info, iteratorVar, new JMethodCall(info, collection,
+            typeMap.get(iterator))));
 
-          // i$index < i$max
-          JExpression condition =
-              new JBinaryOperation(info, JPrimitiveType.BOOLEAN, JBinaryOperator.LT,
-                  indexVar.makeRef(info), maxVar.makeRef(info));
+        // i$iterator.hasNext()
+        JExpression condition =
+            new JMethodCall(info, iteratorVar.makeRef(info), typeMap.get(hasNext));
 
-          // ++i$index
-          JExpression increments = new JPrefixOperation(info, JUnaryOperator.INC,
-              indexVar.makeRef(info));
+        // T elementVar = (T) i$iterator.next();
+        elementDecl.initializer =
+            new JMethodCall(info, iteratorVar.makeRef(info), typeMap.get(next));
 
-          // T elementVar = i$array[i$index];
-          elementDecl.initializer =
-              new JArrayRef(info, arrayVar.makeRef(info), indexVar.makeRef(info));
-          body.addStmt(0, elementDecl);
-
-          result = new JForStatement(info, initializers, condition, increments, body);
-        } else {
-          /**
-           * <pre>
-           * for (Iterator&lt;T&gt; i$iterator = collection.iterator(); i$iterator.hasNext();) {
-           *   T elementVar = i$iterator.next();
-           *   // user action
-           * }
-           * </pre>
-           */
-          CompilationUnitScope cudScope = scope.compilationUnitScope();
-          ReferenceBinding javaUtilIterator = scope.getJavaUtilIterator();
-          ReferenceBinding javaLangIterable = scope.getJavaLangIterable();
-          MethodBinding iterator = javaLangIterable.getExactMethod(ITERATOR_, NO_TYPES, cudScope);
-          MethodBinding hasNext = javaUtilIterator.getExactMethod(HAS_NEXT_, NO_TYPES, cudScope);
-          MethodBinding next = javaUtilIterator.getExactMethod(NEXT_, NO_TYPES, cudScope);
-          JLocal iteratorVar =
-              JProgram.createLocal(info, (elementVarName + "$iterator"), typeMap
-                  .get(javaUtilIterator), false, curMethod.body);
-
-          List<JStatement> initializers = Lists.newArrayListWithCapacity(1);
-          // Iterator<T> i$iterator = collection.iterator()
-          initializers.add(makeDeclaration(info, iteratorVar, new JMethodCall(info, collection,
-              typeMap.get(iterator))));
-
-          // i$iterator.hasNext()
-          JExpression condition =
-              new JMethodCall(info, iteratorVar.makeRef(info), typeMap.get(hasNext));
-
-          // T elementVar = (T) i$iterator.next();
-          elementDecl.initializer =
-              new JMethodCall(info, iteratorVar.makeRef(info), typeMap.get(next));
-
-          // Perform any implicit reference type casts (due to generics).
-          // Note this occurs before potential unboxing.
-          if (elementVar.getType() != javaLangObject) {
-            TypeBinding collectionElementType = (TypeBinding) collectionElementTypeField.get(x);
-            JType toType = typeMap.get(collectionElementType);
-            assert (toType instanceof JReferenceType);
-            elementDecl.initializer = maybeCast(toType, elementDecl.initializer);
-          }
-
-          body.addStmt(0, elementDecl);
-
-          result = new JForStatement(info, initializers, condition,
-              null, body);
+        // Perform any implicit reference type casts (due to generics).
+        // Note this occurs before potential unboxing.
+        if (elementVar.getType() != javaLangObject) {
+          TypeBinding collectionElementType = getCollectionElementTypeBinding(x);
+          JType toType = typeMap.get(collectionElementType);
+          assert (toType instanceof JReferenceType);
+          elementDecl.initializer = maybeCast(toType, elementDecl.initializer);
         }
 
-        // May need to box or unbox the element assignment.
-        elementDecl.initializer =
-            maybeBoxOrUnbox(elementDecl.initializer, x.elementVariableImplicitWidening);
-        push(result);
-      } catch (Throwable e) {
-        throw translateException(x, e);
+        body.addStmt(0, elementDecl);
+
+        result = new JForStatement(info, initializers, condition,
+            null, body);
       }
+
+      // May need to box or unbox the element assignment.
+      elementDecl.initializer =
+          maybeBoxOrUnbox(elementDecl.initializer, x.elementVariableImplicitWidening);
+      push(result);
     }
 
     @Override
@@ -1100,9 +1096,12 @@
           body.getBlock().addStmt(newArray.makeReturnStatement());
           synthMethod.setBody(body);
         }
-        push(null); // no qualifier
       }
-      return true;
+
+      if (hasQualifier(x)) {
+        x.lhs.traverse(this, blockScope);
+      }
+      return false;
     }
 
     @Override
@@ -1755,17 +1754,12 @@
         }
       }
       JMethod referredMethod = typeMap.get(referredMethodBinding);
-      boolean haveReceiver = false;
-      try {
-        haveReceiver = (Boolean) haveReceiverField.get(x);
-      } catch (IllegalAccessException e) {
-        throw translateException(x, e);
-      }
+      boolean hasQualifier = hasQualifier(x);
 
       // Constructors and overloading means we need generate unique names
       String lambdaName = classNameForMethodReference(funcType,
           referredMethod,
-          haveReceiver);
+          hasQualifier);
 
       List<JExpression> enclosingThisRefs = Lists.newArrayList();
 
@@ -1784,7 +1778,7 @@
 
         List<JField> enclosingInstanceFields = new ArrayList<JField>();
         // If we have a qualifier instance, we have to stash it in the constructor
-        if (haveReceiver) {
+        if (hasQualifier) {
           // this.$$outer = $$outer
           JField outerField = createAndBindCapturedLambdaParameter(info, OUTER_LAMBDA_PARAM_NAME,
               innerLambdaClass.getEnclosingType(), ctor, ctorBody);
@@ -1828,7 +1822,7 @@
         // Comparator<T>.
         // The first argument serves as the qualifier, so for example, the method dispatch looks
         // like this: int compare(T a, T b) { a.compareTo(b); }
-        if (!haveReceiver && !referredMethod.isStatic() && instance == null &&
+        if (!hasQualifier && !referredMethod.isStatic() && instance == null &&
             samMethod.getParams().size() == referredMethod.getParams().size() + 1) {
           // the instance qualifier is the first parameter in this case.
           // Needs to be cast the actual type due to generics.
@@ -1927,8 +1921,9 @@
       // Replace the ReferenceExpression qualifier::method with new lambdaType(qualifier)
       assert lambdaCtor.getEnclosingType() == innerLambdaClass;
       JNewInstance allocLambda = new JNewInstance(info, lambdaCtor);
-      JExpression qualifier = (JExpression) pop();
-      if (haveReceiver) {
+
+      if (hasQualifier) {
+        JExpression qualifier =  (JExpression) pop();
         // pop qualifier from stack
         allocLambda.addArg(qualifier);
       } else {
@@ -1938,7 +1933,6 @@
           allocLambda.addArg(enclosingRef);
         }
       }
-
       push(allocLambda);
     }
 
@@ -3636,35 +3630,11 @@
   private static final char[] ITERATOR_ = ITERATOR_METHOD_NAME.toCharArray();
   private static final char[] HAS_NEXT_ = HAS_NEXT_METHOD_NAME.toCharArray();
 
-  /**
-   * Reflective access to {@link ForeachStatement#collectionElementType}.
-   */
-  private static final Field collectionElementTypeField;
-  /**
-   * Reflective access to {@link ReferenceExpression#haveReceiver}.
-   */
-  private static final Field haveReceiverField;
-
   private static final TypeBinding[] NO_TYPES = new TypeBinding[0];
   private static final Interner<String> stringInterner = StringInterner.get();
 
   static {
     InternalCompilerException.preload();
-    try {
-      collectionElementTypeField = ForeachStatement.class.getDeclaredField("collectionElementType");
-      collectionElementTypeField.setAccessible(true);
-    } catch (Exception e) {
-      throw new RuntimeException(
-          "Unexpectedly unable to access ForeachStatement.collectionElementType via reflection", e);
-    }
-
-    try {
-      haveReceiverField = ReferenceExpression.class.getDeclaredField("haveReceiver");
-      haveReceiverField.setAccessible(true);
-    } catch (Exception e) {
-      throw new RuntimeException(
-          "Unexpectedly unable to access ReferenceExpression.haveReceiver via reflection", e);
-    }
   }
 
   /**
@@ -4296,4 +4266,51 @@
         return new JMultiExpression(info, incrementsExpressions);
     }
   }
+
+  private boolean hasQualifier(ReferenceExpression x) {
+    return (Boolean) accessPrivateField(JdtPrivateHacks.haveReceiverField, x);
+  }
+
+  private TypeBinding getCollectionElementTypeBinding(ForeachStatement x) {
+  return (TypeBinding) accessPrivateField(JdtPrivateHacks.collectionElementTypeField, x);
+  }
+
+  private Object accessPrivateField(Field field, ASTNode astNode) {
+    try {
+      return field.get(astNode);
+    } catch (IllegalAccessException e) {
+      throw translateException(astNode, e);
+    }
+  }
+
+  static class JdtPrivateHacks {
+    /**
+     * Reflective access to {@link ForeachStatement#collectionElementType}.
+     */
+    private static final Field collectionElementTypeField;
+    /**
+     * Reflective access to {@link ReferenceExpression#haveReceiver}.
+     */
+    private static final Field haveReceiverField;
+
+    static {
+      try {
+        collectionElementTypeField =
+            ForeachStatement.class.getDeclaredField("collectionElementType");
+        collectionElementTypeField.setAccessible(true);
+      } catch (Exception e) {
+        throw new RuntimeException(
+            "Unexpectedly unable to access ForeachStatement.collectionElementType via reflection",
+            e);
+      }
+
+      try {
+        haveReceiverField = ReferenceExpression.class.getDeclaredField("haveReceiver");
+        haveReceiverField.setAccessible(true);
+      } catch (Exception e) {
+        throw new RuntimeException(
+            "Unexpectedly unable to access ReferenceExpression.haveReceiver via reflection", e);
+      }
+    }
+  }
 }
diff --git a/user/test-super/com/google/gwt/dev/jjs/super/com/google/gwt/dev/jjs/test/Java8Test.java b/user/test-super/com/google/gwt/dev/jjs/super/com/google/gwt/dev/jjs/test/Java8Test.java
index 913c2b6..87e1e1a 100644
--- a/user/test-super/com/google/gwt/dev/jjs/super/com/google/gwt/dev/jjs/test/Java8Test.java
+++ b/user/test-super/com/google/gwt/dev/jjs/super/com/google/gwt/dev/jjs/test/Java8Test.java
@@ -1372,13 +1372,63 @@
     assertSame("a", function.f(0, pars));
   }
 
+  private static <T> T m(T s) {
+    return s;
+  }
+
+  static class Some<T> {
+    T s;
+    MyFunction2<T, T ,T> combine;
+    Some(T s, MyFunction2<T, T, T>  combine) {
+      this.s = s;
+      this.combine = combine;
+    }
+    public T m(T s2) {
+      return combine.apply(s, s2);
+    }
+    public T m1() {
+      return s;
+    }
+  }
+
   @FunctionalInterface
-  interface ToString {
-    String apply(StringBuilder t);
+  interface MyFunction1<T, U> {
+    U apply(T t);
+  }
+
+  @FunctionalInterface
+  interface MyFunction2<T, U, V> {
+    V apply(T t, U u);
   }
 
   public void testMethodReferenceImplementedInSuperclass() {
-    ToString toString = StringBuilder::toString;
+    MyFunction1<StringBuilder, String> toString = StringBuilder::toString;
     assertEquals("Hello", toString.apply(new StringBuilder("Hello")));
   }
+
+  static MyFunction2<String, String, String> concat = (s,t) -> s + t;
+
+  public void testMethodReferenceWithGenericTypeParameters() {
+    testMethodReferencesWithGenericTypeParameters(
+        new Some<String>("Hell", concat), "Hell", "o", concat);
+  }
+
+  private static <T> void testMethodReferencesWithGenericTypeParameters(
+      Some<T> some, T t1, T t2, MyFunction2<T, T, T> combine) {
+    T t1t2 = combine.apply(t1, t2);
+
+    // Test all 4 flavours of methodReference
+    // 1. Static method
+    assertEquals(t1t2, ((MyFunction1<T, T>) Java8Test::m).apply(t1t2));
+    // 2. Qualified instance method
+    assertEquals(t1t2, ((MyFunction1<T, T>) some::m).apply(t2));
+    // 3. Unqualified instance method
+    assertEquals(t1, ((MyFunction1<Some<T>, T>) Some<T>::m1).apply(some));
+    assertEquals("Hello",
+        ((MyFunction1<Some<String>, String>)
+              Some<String>::m1).apply(new Some<>("Hello", concat)));
+    // 4. Constructor reference.
+    assertEquals(t1t2,
+        ((MyFunction2<T, MyFunction2<T, T, T>, Some<T>>) Some<T>::new).apply(t1t2, combine).m1());
+  }
 }
diff --git a/user/test/com/google/gwt/dev/jjs/test/Java8Test.java b/user/test/com/google/gwt/dev/jjs/test/Java8Test.java
index 2f0a8f4..4370746 100644
--- a/user/test/com/google/gwt/dev/jjs/test/Java8Test.java
+++ b/user/test/com/google/gwt/dev/jjs/test/Java8Test.java
@@ -272,7 +272,11 @@
     assertFalse(isGwtSourceLevel8());
   }
 
-  private boolean isGwtSourceLevel8() {
+  public void testMethodReferenceWithGenericTypeParameters() {
+    assertFalse(isGwtSourceLevel8());
+  }
+
+    private boolean isGwtSourceLevel8() {
     return JUnitShell.getCompilerOptions().getSourceLevel().compareTo(SourceLevel.JAVA8) >= 0;
   }
 }
\ No newline at end of file