Method references should correctly implement their intersection cast

Bug: #9653
Change-Id: Icbbc62154fccd673f6eaf9d9c2a55847a58d43a5
Bug-Link: https://github.com/gwtproject/gwt/issues/9653
diff --git a/dev/core/src/com/google/gwt/dev/jjs/impl/GwtAstBuilder.java b/dev/core/src/com/google/gwt/dev/jjs/impl/GwtAstBuilder.java
index 656fc3a..240860b 100644
--- a/dev/core/src/com/google/gwt/dev/jjs/impl/GwtAstBuilder.java
+++ b/dev/core/src/com/google/gwt/dev/jjs/impl/GwtAstBuilder.java
@@ -1214,14 +1214,7 @@
       // And its JInterface container we must implement
       // There may be more than more JInterface containers to be implemented
       // if the lambda expression is cast to a IntersectionCastType.
-      JInterfaceType[] lambdaInterfaces;
-      if (binding instanceof IntersectionTypeBinding18) {
-        IntersectionTypeBinding18 type = (IntersectionTypeBinding18) binding;
-        lambdaInterfaces =
-            processIntersectionType(type, new JInterfaceType[type.intersectingTypes.length]);
-      } else {
-        lambdaInterfaces = new JInterfaceType[] {(JInterfaceType) typeMap.get(binding)};
-      }
+      JInterfaceType[] lambdaInterfaces = getInterfacesToImplement(binding);
       SourceInfo info = makeSourceInfo(x);
 
       // Create an inner class to implement the interface and SAM method.
@@ -1270,6 +1263,14 @@
       newTypes.add(innerLambdaClass);
     }
 
+    private JInterfaceType[] getInterfacesToImplement(TypeBinding binding) {
+      if (binding instanceof IntersectionTypeBinding18) {
+        IntersectionTypeBinding18 type = (IntersectionTypeBinding18) binding;
+        return processIntersectionType(type, new JInterfaceType[type.intersectingTypes.length]);
+      }
+      return new JInterfaceType[]{(JInterfaceType) typeMap.get(binding)};
+    }
+
     private void createFunctionalExpressionBridges(
         JClassType functionalExpressionImplementationClass,
         FunctionalExpression functionalExpression,
@@ -1771,7 +1772,8 @@
           binding.getSingleAbstractMethod(blockScope, false).original();
       // Get the interface method is binds to
       JMethod interfaceMethod = typeMap.get(declarationSamBinding);
-      JInterfaceType funcType = (JInterfaceType) typeMap.get(binding);
+
+      JInterfaceType[] funcType = getInterfacesToImplement(binding);
       SourceInfo info = makeSourceInfo(x);
 
       // Get the method that the Type::method is actually referring to
diff --git a/dev/core/test/com/google/gwt/dev/jjs/impl/Java8AstTest.java b/dev/core/test/com/google/gwt/dev/jjs/impl/Java8AstTest.java
index a85cb10..ac51d85 100644
--- a/dev/core/test/com/google/gwt/dev/jjs/impl/Java8AstTest.java
+++ b/dev/core/test/com/google/gwt/dev/jjs/impl/Java8AstTest.java
@@ -1117,7 +1117,7 @@
         formatSource(samMethod.toSource()));
   }
 
-  public void testIntersectionCastMultipleAbstractMethods() throws Exception {
+  public void testIntersectionCastOfLambdaMultipleAbstractMethods() throws Exception {
     addSnippetClassDecl("interface I1 { public void foo(); }");
     addSnippetClassDecl("interface I2 { public void foo(); }");
     String lambda = "Object o = (I1 & I2) () -> {};";
@@ -1140,6 +1140,166 @@
         formatSource(samMethod.toSource()));
   }
 
+  public void testIntersectionCastOfLambdaMultipleAbstractMethodsWithGenerics() throws Exception {
+    addSnippetClassDecl("interface I1 extends I2<String> { public void foo(String arg0); }");
+    addSnippetClassDecl("interface I2<T> { public void foo(T arg); }");
+    String lambda = "Object o = (I1 & I2<String>) str -> {};";
+    assertEqualBlock("Object o=(EntryPoint$I1)new EntryPoint$lambda$0$Type();", lambda);
+
+    JProgram program = compileSnippet("void", lambda, false);
+
+    assertNotNull(getMethod(program, "lambda$0"));
+
+    JClassType lambdaInnerClass = (JClassType) getType(program, "test.EntryPoint$lambda$0$Type");
+    assertNotNull(lambdaInnerClass);
+    assertEquals("java.lang.Object", lambdaInnerClass.getSuperClass().getName());
+    assertEquals(1, lambdaInnerClass.getImplements().size());
+    assertTrue(
+        lambdaInnerClass.getImplements().contains(program.getFromTypeMap("test.EntryPoint$I1")));
+    // should implement foo method
+    JMethod samMethod = findMethod(lambdaInnerClass, "foo(Ljava/lang/String;)V");
+    assertEquals("public final void foo(String arg0){EntryPoint.lambda$0(arg0);}",
+        formatSource(samMethod.toSource()));
+  }
+  public void testIntersectionCastOfMethodReference() throws Exception {
+    addSnippetClassDecl("static class C { public static void go() {} }");
+    addSnippetClassDecl("interface I1 { public void foo(); }");
+    addSnippetClassDecl("interface I2 { }");
+    String methodReference = "Object o = (I2 & I1) C::go;";
+    assertEqualBlock("Object o=(EntryPoint$I1)(EntryPoint$I2)new EntryPoint$0methodref$go$Type();",
+        methodReference);
+    JProgram program = compileSnippet("void", methodReference, false);
+
+    // created by GwtAstBuilder
+    JClassType lambdaInnerClass = (JClassType) getType(program, "test.EntryPoint$0methodref$go$Type");
+    assertNotNull(lambdaInnerClass);
+
+    // no fields
+    assertEquals(0, lambdaInnerClass.getFields().size());
+
+    // should have constructor taking no args
+    JMethod ctor = findMethod(lambdaInnerClass, "EntryPoint$0methodref$go$Type");
+    assertTrue(ctor instanceof JConstructor);
+    assertEquals(0, ctor.getParams().size());
+
+    // should implements I1 and I2
+    assertTrue(
+        lambdaInnerClass.getImplements().contains(program.getFromTypeMap("test.EntryPoint$I1")));
+    assertTrue(
+        lambdaInnerClass.getImplements().contains(program.getFromTypeMap("test.EntryPoint$I2")));
+    // should implement foo method
+    JMethod samMethod = findMethod(lambdaInnerClass, "foo");
+    assertEquals("public final void foo(){EntryPoint$C.go();}",
+        formatSource(samMethod.toSource()));
+  }
+
+  public void testMultipleIntersectionCastOfMethodReference() throws Exception {
+    addSnippetClassDecl("static class C { public static void go() {} }");
+    addSnippetClassDecl("interface I1 { public void foo(); }");
+    addSnippetClassDecl("interface I2 { }");
+    addSnippetClassDecl("interface I3 { }");
+    String methodReference = "I2 o = (I3 & I2 & I1) C::go;";
+    assertEqualBlock(
+        "EntryPoint$I2 o=(EntryPoint$I1)(EntryPoint$I2)(EntryPoint$I3)new EntryPoint$0methodref$go$Type();",
+        methodReference);
+
+    JProgram program = compileSnippet("void", methodReference, false);
+
+    // created by GwtAstBuilder
+    JClassType lambdaInnerClass = (JClassType) getType(program, "test.EntryPoint$0methodref$go$Type");
+    assertNotNull(lambdaInnerClass);
+
+    // no fields
+    assertEquals(0, lambdaInnerClass.getFields().size());
+
+    // should have constructor taking no args
+    JMethod ctor = findMethod(lambdaInnerClass, "EntryPoint$0methodref$go$Type");
+    assertTrue(ctor instanceof JConstructor);
+    assertEquals(0, ctor.getParams().size());
+
+    // should extends java.lang.Object, implements I1, I2 and I3
+    assertEquals("java.lang.Object", lambdaInnerClass.getSuperClass().getName());
+    assertTrue(
+        lambdaInnerClass.getImplements().contains(program.getFromTypeMap("test.EntryPoint$I1")));
+    assertTrue(
+        lambdaInnerClass.getImplements().contains(program.getFromTypeMap("test.EntryPoint$I2")));
+    assertTrue(
+        lambdaInnerClass.getImplements().contains(program.getFromTypeMap("test.EntryPoint$I3")));
+    // should implement foo method
+    JMethod samMethod = findMethod(lambdaInnerClass, "foo");
+    assertEquals("public final void foo(){EntryPoint$C.go();}",
+        formatSource(samMethod.toSource()));
+  }
+
+  public void testIntersectionCastOfMethodReferenceOneAbstractMethod() throws Exception {
+    addSnippetClassDecl("static class C { public static void go() {} }");
+    addSnippetClassDecl("interface I1 { public void foo(); }");
+    addSnippetClassDecl("interface I2 extends I1{ public void foo();}");
+    String lambda = "Object o = (I1 & I2) C::go;";
+    // (I1 & I2) is resolved to I2 by JDT.
+    assertEqualBlock("Object o=(EntryPoint$I2)new EntryPoint$0methodref$go$Type();",
+        lambda);
+
+    JProgram program = compileSnippet("void", lambda, false);
+
+    JClassType lambdaInnerClass = (JClassType) getType(program, "test.EntryPoint$0methodref$go$Type");
+    assertNotNull(lambdaInnerClass);
+    assertEquals("java.lang.Object", lambdaInnerClass.getSuperClass().getName());
+    assertEquals(1, lambdaInnerClass.getImplements().size()); // only implements I2.
+    assertTrue(
+        lambdaInnerClass.getImplements().contains(program.getFromTypeMap("test.EntryPoint$I2")));
+    // should implement foo method
+    JMethod samMethod = findMethod(lambdaInnerClass, "foo");
+    assertEquals("public final void foo(){EntryPoint$C.go();}",
+        formatSource(samMethod.toSource()));
+  }
+
+  public void testIntersectionCastOfMethodReferenceMultipleAbstractMethods() throws Exception {
+    addSnippetClassDecl("static class C { public static void go() {} }");
+    addSnippetClassDecl("interface I1 { public void foo(); }");
+    addSnippetClassDecl("interface I2 { public void foo(); }");
+    String methodReference = "Object o = (I1 & I2) C::go;";
+    assertEqualBlock("Object o=(EntryPoint$I1)(EntryPoint$I2)new EntryPoint$0methodref$go$Type();",
+        methodReference);
+
+    JProgram program = compileSnippet("void", methodReference, false);
+
+    JClassType lambdaInnerClass = (JClassType) getType(program, "test.EntryPoint$0methodref$go$Type");
+    assertNotNull(lambdaInnerClass);
+    assertEquals("java.lang.Object", lambdaInnerClass.getSuperClass().getName());
+    assertEquals(2, lambdaInnerClass.getImplements().size());
+    assertTrue(
+        lambdaInnerClass.getImplements().contains(program.getFromTypeMap("test.EntryPoint$I1")));
+  assertTrue(
+        lambdaInnerClass.getImplements().contains(program.getFromTypeMap("test.EntryPoint$I2")));
+    // should implement foo method
+    JMethod samMethod = findMethod(lambdaInnerClass, "foo");
+    assertEquals("public final void foo(){EntryPoint$C.go();}",
+        formatSource(samMethod.toSource()));
+  }
+
+  public void testIntersectionCastOfMethodReferenceMultipleAbstractMethodsWithGenerics() throws Exception {
+    addSnippetClassDecl("static class C { public static void go(String arg) {} }");
+    addSnippetClassDecl("interface I1 extends I2<String> { public void foo(String arg); }");
+    addSnippetClassDecl("interface I2<T> { public void foo(T arg); }");
+    String methodReference = "Object o = (I1 & I2<String>) C::go;";
+    assertEqualBlock("Object o=(EntryPoint$I1)new EntryPoint$0methodref$go$Type();",
+        methodReference);
+
+    JProgram program = compileSnippet("void", methodReference, false);
+
+    JClassType lambdaInnerClass = (JClassType) getType(program, "test.EntryPoint$0methodref$go$Type");
+    assertNotNull(lambdaInnerClass);
+    assertEquals("java.lang.Object", lambdaInnerClass.getSuperClass().getName());
+    assertEquals(1, lambdaInnerClass.getImplements().size());
+    assertTrue(
+        lambdaInnerClass.getImplements().contains(program.getFromTypeMap("test.EntryPoint$I1")));
+    // should implement foo method
+    JMethod samMethod = findMethod(lambdaInnerClass, "foo(Ljava/lang/String;)V");
+    assertEquals("public final void foo(String arg){EntryPoint$C.go(arg);}",
+        formatSource(samMethod.toSource()));
+  }
+
   private static final MockJavaResource LAMBDA_METAFACTORY =
       JavaResourceBase.createMockJavaResource("java.lang.invoke.LambdaMetafactory",
           "package java.lang.invoke;",
diff --git a/user/test-super/com/google/gwt/dev/jjs/super/com/google/gwt/dev/jjs/test/Java8Test.java b/user/test-super/com/google/gwt/dev/jjs/super/com/google/gwt/dev/jjs/test/Java8Test.java
index 907d9f3..a251c73 100644
--- a/user/test-super/com/google/gwt/dev/jjs/super/com/google/gwt/dev/jjs/test/Java8Test.java
+++ b/user/test-super/com/google/gwt/dev/jjs/super/com/google/gwt/dev/jjs/test/Java8Test.java
@@ -1617,7 +1617,7 @@
     assertEquals(4, outer.createInner2Param().apply(1, 2).sum);
     assertEquals(7, outer.createInner3Param().apply(1, 2, 3).sum);
     assertEquals(7, outer.createInner2ParamArray().apply(1, new Integer[] {2, 3}).sum);
-    
+
     // inner class constructor varargs + autoboxing
     assertEquals(2, outer.createInner1IntParam().apply(1).sum);
     assertEquals(4, outer.createInner2IntParam().apply(1, 2).sum);
@@ -2006,7 +2006,7 @@
   ////////////////////////////////////////////////////////////
   //
   //   Tests for language features introduced in Java 9
-  
+
   class Resource implements AutoCloseable {
     boolean isOpen = true;
 
@@ -2080,4 +2080,37 @@
     Predicate p = o -> true;
     assertTrue(p.test(null));
   }
+
+  interface I2<T> { public T foo(T arg); }
+
+  interface I1 extends I2<String> { public String foo(String arg0); }
+
+  @SuppressWarnings({"rawtypes", "unchecked"})
+  public void testIntersectionCastLambda() {
+
+    Object instance = (I1 & I2<String>) val -> "#" + val;
+
+    assertTrue(instance instanceof I1);
+    assertTrue(instance instanceof I2);
+
+    I1 lambda = (I1) instance;
+    I2 raw = lambda;
+    assertEquals("#1", raw.foo("1")); // tests that the bridge exists and is correct
+    assertEquals("#2", lambda.foo("2"));
+  }
+
+  static class C { public static String append(String str) { return "#" + str; } }
+  @SuppressWarnings({"rawtypes", "unchecked"})
+  public void testIntersectionCastMethodReference() {
+
+    Object instance = (I1 & I2<String>) C::append;
+
+    assertTrue(instance instanceof I1);
+    assertTrue(instance instanceof I2);
+
+    I1 lambda = (I1) instance;
+    I2 raw = lambda;
+    assertEquals("#1", raw.foo("1")); // tests that the bridge exists and is correct
+    assertEquals("#2", lambda.foo("2"));
+  }
 }
diff --git a/user/test/com/google/gwt/dev/jjs/test/Java8Test.java b/user/test/com/google/gwt/dev/jjs/test/Java8Test.java
index 312d70c..5e4c773 100644
--- a/user/test/com/google/gwt/dev/jjs/test/Java8Test.java
+++ b/user/test/com/google/gwt/dev/jjs/test/Java8Test.java
@@ -360,6 +360,14 @@
     assertFalse(isGwtSourceLevel9());
   }
 
+  public void testIntersectionCastLambda() {
+    assertFalse(isGwtSourceLevel9());
+  }
+
+  public void testIntersectionCastMethodReference() {
+    assertFalse(isGwtSourceLevel9());
+  }
+
   private boolean isGwtSourceLevel9() {
     return JUnitShell.getCompilerOptions().getSourceLevel().compareTo(SourceLevel.JAVA9) >= 0;
   }