Make sure lambdas box, unbox and insert erasure casts when necessary.

Bug: #9558
Bug-Link: https://github.com/gwtproject/gwt/issues/9558
Change-Id: I2b18b0aab09c8ad1a31bc53041528668059f8067
diff --git a/dev/core/src/com/google/gwt/dev/jjs/impl/GwtAstBuilder.java b/dev/core/src/com/google/gwt/dev/jjs/impl/GwtAstBuilder.java
index eafcf3e..ad4fb56 100644
--- a/dev/core/src/com/google/gwt/dev/jjs/impl/GwtAstBuilder.java
+++ b/dev/core/src/com/google/gwt/dev/jjs/impl/GwtAstBuilder.java
@@ -1299,15 +1299,21 @@
 
       // and add any locals that were storing captured outer variables as arguments to the call
       // first
+      int samArg = 0;
       for (JField localField : locals) {
-        samCall.addArg(new JFieldRef(info, new JThisRef(info, innerLambdaClass),
-            localField, innerLambdaClass));
+        JType samArgumentType = lambdaMethod.getParams().get(samArg).getType();
+        JExpression capture = new JFieldRef(info, new JThisRef(info, innerLambdaClass),
+            localField, innerLambdaClass);
+        samCall.addArg(maybeInsertCasts(capture, samArgumentType));
+        samArg++;
       }
 
       // and now we propagate the rest of the actual interface method parameters on the end
       // (e.g. ClickEvent e)
       for (JParameter param : samMethod.getParams()) {
-        samCall.addArg(param.makeRef(info));
+        JType samArgumentType = lambdaMethod.getParams().get(samArg).getType();
+        samCall.addArg(maybeInsertCasts(param.makeRef(info), samArgumentType));
+        samArg++;
       }
 
       // we either add a return statement, or don't, depending on what the interface wants
@@ -1930,12 +1936,12 @@
             || !referredMethodBinding.isVarargs()
             || (paramNumber < varArg)) {
           destParam = referredMethodBinding.parameters[paramNumber];
-          paramExpr = boxOrUnboxExpression(paramExpr, samParameterBinding, destParam);
+          paramExpr = maybeInsertCasts(paramExpr, samParameterBinding, destParam);
           samCall.addArg(paramExpr);
         } else if (!samParameterBinding.isArrayType()) {
           // else add trailing parameters to var-args initializer list for an array
           destParam = referredMethodBinding.parameters[varArg].leafComponentType();
-          paramExpr = boxOrUnboxExpression(paramExpr, samParameterBinding, destParam);
+          paramExpr = maybeInsertCasts(paramExpr, samParameterBinding, destParam);
           varArgInitializers.add(paramExpr);
         }
         paramNumber++;
@@ -1954,7 +1960,7 @@
       // TODO(rluble): Make this a call to JjsUtils.makeMethodEndStatement once boxing/unboxing
       // is handled there.
       if (samMethod.getType() != JPrimitiveType.VOID) {
-        JExpression samExpression = boxOrUnboxExpression(samCall, referredMethodBinding.returnType,
+        JExpression samExpression = maybeInsertCasts(samCall, referredMethodBinding.returnType,
             declarationSamBinding.returnType);
         samMethodBody.getBlock().addStmt(maybeBoxOrUnbox(samExpression, x).makeReturnStatement());
       } else {
@@ -1995,7 +2001,10 @@
       push(allocLambda);
     }
 
-    private JExpression boxOrUnboxExpression(JExpression expr, TypeBinding fromType,
+    /**
+     * Inserts necessary casts for boxing, unboxing or erasure reasons if needed.
+     */
+    private JExpression maybeInsertCasts(JExpression expr, TypeBinding fromType,
         TypeBinding toType) {
       if (fromType == TypeBinding.VOID || toType == TypeBinding.VOID) {
         return expr;
@@ -2016,6 +2025,16 @@
       return new JCastOperation(expr.getSourceInfo(), typeMap.get(castToType), expr);
     }
 
+    /**
+     * Inserts necessary casts for boxing, unboxing or erasure reasons if needed.
+     */
+    private JExpression maybeInsertCasts(JExpression expr, JType toType) {
+      if (expr.getType() == toType) {
+        return expr;
+      }
+      return new JCastOperation(expr.getSourceInfo(), toType, expr);
+    }
+
     @Override
     public void endVisit(ReturnStatement x, BlockScope scope) {
       try {
diff --git a/user/test-super/com/google/gwt/dev/jjs/super/com/google/gwt/dev/jjs/test/Java8Test.java b/user/test-super/com/google/gwt/dev/jjs/super/com/google/gwt/dev/jjs/test/Java8Test.java
index e9c7aec..e084910 100644
--- a/user/test-super/com/google/gwt/dev/jjs/super/com/google/gwt/dev/jjs/test/Java8Test.java
+++ b/user/test-super/com/google/gwt/dev/jjs/super/com/google/gwt/dev/jjs/test/Java8Test.java
@@ -16,6 +16,8 @@
 package com.google.gwt.dev.jjs.test;
 
 import com.google.gwt.core.client.GwtScriptOnly;
+import com.google.gwt.core.client.JavaScriptObject;
+import com.google.gwt.core.client.JsonUtils;
 import com.google.gwt.dev.jjs.test.defaultmethods.ImplementsWithDefaultMethodAndStaticInitializer;
 import com.google.gwt.dev.jjs.test.defaultmethods.SomeClass;
 import com.google.gwt.junit.client.GWTTestCase;
@@ -23,6 +25,9 @@
 import java.util.ArrayList;
 import java.util.Arrays;
 import java.util.List;
+import java.util.function.BiFunction;
+import java.util.function.IntFunction;
+import java.util.stream.Collectors;
 
 import jsinterop.annotations.JsFunction;
 import jsinterop.annotations.JsOverlay;
@@ -1931,5 +1936,75 @@
     java.util.function.Function<String[], String> function = Java8Test::first;
     assertEquals("Hello", function.apply(new String[] {"Hello", "GoodBye"}));
   }
-}
 
+  interface SingleJsoImplA {
+    String getAData();
+
+    List<SingleJsoImplB> getListOfB();
+  }
+
+  interface SingleJsoImplB {
+    String getBData();
+  }
+
+  private static final class AOverlay extends JavaScriptObject implements SingleJsoImplA {
+    protected AOverlay() { }
+
+    @Override
+    public native String getAData() /*-{
+      return this.data;
+    }-*/;
+
+    @Override
+    public native List<SingleJsoImplB> getListOfB() /*-{
+      return @java.util.Arrays::asList(*)(this.listOfb);
+    }-*/;
+  }
+
+  private static final class BOverlay extends JavaScriptObject implements SingleJsoImplB {
+    protected BOverlay() { }
+
+    @Override
+    public native String getBData() /*-{
+      return this.data;
+    }-*/;
+  }
+
+  private static SingleJsoImplA createA() {
+    return JsonUtils.safeEval(
+        "{\"data\":\"a value\",\"listOfb\":[{\"data\":\"b1\"},{\"data\":\"b2\"}]}");
+  }
+
+  // Regression for issue #9558
+  public void testJSOLivenessSingleImplErasure() {
+    SingleJsoImplA a = createA();
+    String result = a.getListOfB().stream()
+        .map(SingleJsoImplB::getBData).collect(Collectors.joining(","));
+    assertEquals("b1,b2", result);
+    result = a.getListOfB().stream()
+        .map(b -> b.getBData()).collect(Collectors.joining(","));
+    assertEquals("b1,b2", result);
+  }
+
+  @SuppressWarnings({"rawtypes", "unchecked"})
+  public void testLambdaErasureCasts() {
+    List list = new ArrayList<String>();
+    list.add("2");
+    try {
+      ((List<Integer>) list).stream().map(n -> n.intValue() == 2).findAny();
+      fail("Should have thrown.");
+    } catch (ClassCastException expected) {
+    }
+  }
+
+  public void testLambdaBoxing() {
+    BiFunction<Integer, Integer, Boolean> equals = (i, j) -> i + 0 == j;
+    assertTrue(equals.apply(1,1));
+    assertTrue(equals.apply(new Integer(2),2));
+    assertTrue(equals.apply(new Integer(3), new Integer(3)));
+
+    IntFunction<Integer> unboxBox = i -> i;
+    assertEquals(2, (int) unboxBox.apply(2));
+    assertEquals(2, (int) unboxBox.apply(new Integer(2)));
+  }
+}
diff --git a/user/test/com/google/gwt/dev/jjs/test/Java8Test.java b/user/test/com/google/gwt/dev/jjs/test/Java8Test.java
index 1228184..b63929b 100644
--- a/user/test/com/google/gwt/dev/jjs/test/Java8Test.java
+++ b/user/test/com/google/gwt/dev/jjs/test/Java8Test.java
@@ -328,6 +328,18 @@
     assertFalse(isGwtSourceLevel8());
   }
 
+  public void testJSOLivenessSingleImplErasure() {
+    assertFalse(isGwtSourceLevel8());
+  }
+
+  public void testLambdaErasureCasts() {
+    assertFalse(isGwtSourceLevel8());
+  }
+
+  public void testLambdaBoxing() {
+    assertFalse(isGwtSourceLevel8());
+  }
+
   private boolean isGwtSourceLevel8() {
     return JUnitShell.getCompilerOptions().getSourceLevel().compareTo(SourceLevel.JAVA8) >= 0;
   }