Scala:使用 foldRight 实现 flatMap

det*_*jan 1 functional-programming scala

我无法理解函数式编程练习的解决方案:

实现flatMap只使用foldRightNil::(缺点)。

解决方法如下:

def flatMap[A, B](xs: List[A])(f: A => List[B]): List[B] = 
   xs.foldRight(List[B]())((outCurr, outAcc) =>
   f(outCurr).foldRight(outAcc)((inCurr, inAcc) => inCurr :: inAcc))
Run Code Online (Sandbox Code Playgroud)

我试图将匿名函数分解为函数定义,以将解决方案重写为不走运。我无法理解正在发生的事情或想办法将其分解以使其不那么复杂。因此,任何有关解决方案的帮助或解释将不胜感激。

谢谢!

Dat*_*yen 5

首先,忽略约束并考虑flatMap这种情况下的功能。你有一个List[A]和一个函数f: A => List[B]。通常,如果您只是map在列表上执行 a并应用该f功能,您会得到 a List[List[B]],对吗?所以要得到一个List[B],你会怎么做?你会foldRightList[List[B]]找回一个List[B]仅通过附加在所有元素List[List[B]]。所以代码看起来有点像这样:

def flatMap[A, B](xs: List[A])(f: A => List[B]): List[B] = {
    val tmp = xs.map(f) // List[List[B]]
    tmp.foldRight(List[B]())((outCurr, outAcc) => outCurr ++ outAcc)
}
Run Code Online (Sandbox Code Playgroud)

为了验证我们到目前为止所拥有的,在 REPL 中运行代码并根据内置flatMap方法验证结果:

scala> def flatMap[A, B](xs: List[A])(f: A => List[B]): List[B] = {
 |     val tmp = xs.map(f) // List[List[B]]
 |     tmp.foldRight(List[B]())((outCurr, outAcc) => outCurr ++ outAcc)
 | }
flatMap: [A, B](xs: List[A])(f: A => List[B])List[B]

scala> flatMap(List(1, 2, 3))(i => List(i, 2*i, 3*i))
res0: List[Int] = List(1, 2, 3, 2, 4, 6, 3, 6, 9)

scala> List(1,2,3).flatMap(i => List(i, 2*i, 3*i))
res1: List[Int] = List(1, 2, 3, 2, 4, 6, 3, 6, 9)
Run Code Online (Sandbox Code Playgroud)

好的,那么现在,看看我们的约束,我们map这里是不允许使用的。但是我们真的不需要,因为map这里只是用于遍历 list xs。然后我们可以foldRight用于同样的目的。因此,让我们map使用foldRight以下代码重写该部分:

def flatMap[A, B](xs: List[A])(f: A => List[B]): List[B] = {
    val tmp = xs.foldRight(List[List[B]]())((curr, acc) => f(curr) :: acc) // List[List[B]]
    tmp.foldRight(List[B]())((outCurr, outAcc) => outCurr ++ outAcc)
}
Run Code Online (Sandbox Code Playgroud)

好的,让我们验证一下新代码:

scala> def flatMap[A, B](xs: List[A])(f: A => List[B]): List[B] = {
     |         val tmp = xs.foldRight(List[List[B]]())((curr, acc) => f(curr) :: acc) // List[List[B]]
     |         tmp.foldRight(List[B]())((outCurr, outAcc) => outCurr ++ outAcc)
     |     }
flatMap: [A, B](xs: List[A])(f: A => List[B])List[B]

scala> flatMap(List(1, 2, 3))(i => List(i, 2*i, 3*i))
res3: List[Int] = List(1, 2, 3, 2, 4, 6, 3, 6, 9)
Run Code Online (Sandbox Code Playgroud)

好的,到目前为止一切顺利。因此,让我们稍微优化一下代码,不是将两个foldRight按顺序排列,而是将它们合并为一个foldRight。这不应该太难:

def flatMap[A, B](xs: List[A])(f: A => List[B]): List[B] = {
    xs.foldRight(List[B]()) { (curr, acc) => // Note: acc is List[B]
        val tmp2 = f(curr) // List[B]
        tmp2 ++ acc
    }
}
Run Code Online (Sandbox Code Playgroud)

再次验证:

scala> def flatMap[A, B](xs: List[A])(f: A => List[B]): List[B] = {
     |     xs.foldRight(List[B]()) { (curr, acc) => // Note: acc is List[B]
     |         val tmp2 = f(curr) // List[B]
     |         tmp2 ++ acc
     |     }
     | }
flatMap: [A, B](xs: List[A])(f: A => List[B])List[B]

scala> flatMap(List(1, 2, 3))(i => List(i, 2*i, 3*i))
res4: List[Int] = List(1, 2, 3, 2, 4, 6, 3, 6, 9)
Run Code Online (Sandbox Code Playgroud)

好的,那么我们来看看我们的约束,看起来我们不能使用++操作。好吧,++这只是将两者附加List[B]在一起的一种方式,因此我们当然可以使用foldRight方法来实现相同的事情,如下所示:

def flatMap[A, B](xs: List[A])(f: A => List[B]): List[B] = {
    xs.foldRight(List[B]()) { (curr, acc) => // Note: acc is List[B]
        val tmp2 = f(curr) // List[B]
        tmp2.foldRight(acc)((inCurr, inAcc) => inCurr :: inAcc)
    }
}
Run Code Online (Sandbox Code Playgroud)

然后,我们可以通过以下方式将它们全部合并为一行:

def flatMap[A, B](xs: List[A])(f: A => List[B]): List[B] = 
   xs.foldRight(List[B]())((curr, acc) =>
   f(curr).foldRight(acc)((inCurr, inAcc) => inCurr :: inAcc))
Run Code Online (Sandbox Code Playgroud)

这不是给定的答案:)