mar*_*ine 9 scala tail-recursion
我在向朋友解释我期望Scala中的非尾递归函数比尾递归函数慢,所以我决定验证它.我用两种方式编写了一个很好的旧因子函数,并试图比较结果.这是代码:
def main(args: Array[String]): Unit = {
val N = 2000 // not too much or else stackoverflows
var spent1: Long = 0
var spent2: Long = 0
for ( i <- 1 to 100 ) { // repeat to average the results
val t0 = System.nanoTime
factorial(N)
val t1 = System.nanoTime
tailRecFact(N)
val t2 = System.nanoTime
spent1 += t1 - t0
spent2 += t2 - t1
}
println(spent1/1000000f) // get milliseconds
println(spent2/1000000f)
}
@tailrec
def tailRecFact(n: BigInt, s: BigInt = 1): BigInt = if (n == 1) s else tailRecFact(n - 1, s * n)
def factorial(n: BigInt): BigInt = if (n == 1) 1 else n * factorial(n - 1)
Run Code Online (Sandbox Code Playgroud)
结果令我困惑,我得到这种输出:
578.2985
870.22125
这意味着非尾递归函数比尾递归函数快30%,并且操作次数相同!
什么能解释这些结果?
它实际上不是你第一次看的地方.原因在于你的尾递归方法,你正在用它的乘法做更多的工作.尝试在递归调用中交换参数n和s的顺序,它会均匀.
def tailRecFact(n: BigInt, s: BigInt): BigInt = if (n == 1) s else tailRecFact(n - 1, n * s)
Run Code Online (Sandbox Code Playgroud)
此外,此示例中的大部分时间都采用BigInt操作,这使得递归调用的时间相形见绌.如果我们将这些转换为Ints(编译为Java原语),那么您可以看到尾递归(goto)与方法调用的比较.
object Test extends App {
val N = 2000
val t0 = System.nanoTime()
for ( i <- 1 to 1000 ) {
factorial(N)
}
val t1 = System.nanoTime
for ( i <- 1 to 1000 ) {
tailRecFact(N, 1)
}
val t2 = System.nanoTime
println((t1 - t0) / 1000000f) // get milliseconds
println((t2 - t1) / 1000000f)
def factorial(n: Int): Int = if (n == 1) 1 else n * factorial(n - 1)
@tailrec
final def tailRecFact(n: Int, s: Int): Int = if (n == 1) s else tailRecFact(n - 1, s * n)
}
95.16733
3.987605
Run Code Online (Sandbox Code Playgroud)
为了兴趣,反编译输出
public final scala.math.BigInt tailRecFact(scala.math.BigInt, scala.math.BigInt);
Code:
0: aload_1
1: iconst_1
2: invokestatic #16 // Method scala/runtime/BoxesRunTime.boxToInteger:(I)Ljava/lang/Integer;
5: invokestatic #20 // Method scala/runtime/BoxesRunTime.equalsNumObject:(Ljava/lang/Number;Ljava/lang/Object;)Z
8: ifeq 13
11: aload_2
12: areturn
13: aload_1
14: getstatic #26 // Field scala/math/BigInt$.MODULE$:Lscala/math/BigInt$;
17: iconst_1
18: invokevirtual #30 // Method scala/math/BigInt$.int2bigInt:(I)Lscala/math/BigInt;
21: invokevirtual #36 // Method scala/math/BigInt.$minus:(Lscala/math/BigInt;)Lscala/math/BigInt;
24: aload_1
25: aload_2
26: invokevirtual #39 // Method scala/math/BigInt.$times:(Lscala/math/BigInt;)Lscala/math/BigInt;
29: astore_2
30: astore_1
31: goto 0
public scala.math.BigInt factorial(scala.math.BigInt);
Code:
0: aload_1
1: iconst_1
2: invokestatic #16 // Method scala/runtime/BoxesRunTime.boxToInteger:(I)Ljava/lang/Integer;
5: invokestatic #20 // Method scala/runtime/BoxesRunTime.equalsNumObject:(Ljava/lang/Number;Ljava/lang/Object;)Z
8: ifeq 21
11: getstatic #26 // Field scala/math/BigInt$.MODULE$:Lscala/math/BigInt$;
14: iconst_1
15: invokevirtual #30 // Method scala/math/BigInt$.int2bigInt:(I)Lscala/math/BigInt;
18: goto 40
21: aload_1
22: aload_0
23: aload_1
24: getstatic #26 // Field scala/math/BigInt$.MODULE$:Lscala/math/BigInt$;
27: iconst_1
28: invokevirtual #30 // Method scala/math/BigInt$.int2bigInt:(I)Lscala/math/BigInt;
31: invokevirtual #36 // Method scala/math/BigInt.$minus:(Lscala/math/BigInt;)Lscala/math/BigInt;
34: invokevirtual #47 // Method factorial:(Lscala/math/BigInt;)Lscala/math/BigInt;
37: invokevirtual #39 // Method scala/math/BigInt.$times:(Lscala/math/BigInt;)Lscala/math/BigInt;
40: areturn
Run Code Online (Sandbox Code Playgroud)
除了@monkjack所显示的问题(即乘以小*大比大*小,这确实占了差异的更大块),你的算法在每种情况下都是不同的,所以它们不是真正可比的.
在尾递归版本中,你是大到小的:
n * n-1 * n-2 * ... * 2 * 1
Run Code Online (Sandbox Code Playgroud)
在非尾递归版本中,您将从小到大倍增:
n * (n-1 * (n-2 * (... * (2 * 1))))
Run Code Online (Sandbox Code Playgroud)
如果你改变了尾递归版本,那么它会从小到大:
def tailRecFact2(n: BigInt) = {
def loop(x: BigInt, out: BigInt): BigInt =
if (x > n) out else loop(x + 1, x * out)
loop(1, 1)
}
Run Code Online (Sandbox Code Playgroud)
然后尾部递归比正常递归快20%,而不是像你刚刚进行monkjack校正那样慢10%.这是因为将小BigInts放在一起比将大BigI倍增要快.