PySpark: DecimalType multiplication precision loss

blu*_*blu 9 python apache-spark pyspark

PySpark seems to be losing precision when doing multiplication.

For example, when multiplying two decimals that are both decimal(38,10), it returns decimal(38,6) and rounds to three decimal places, which is an incorrect result.

from decimal import Decimal

from pyspark.sql import SparkSession
from pyspark.sql.types import DecimalType, StructType, StructField

# In the pyspark shell `spark` already exists; create it when running standalone.
spark = SparkSession.builder.getOrCreate()

schema = StructType([StructField("amount", DecimalType(38,10)), StructField("fx", DecimalType(38,10))])
df = spark.createDataFrame([(Decimal(233.00), Decimal(1.1403218880))], schema=schema)

df.printSchema()
df = df.withColumn("amount_usd", df.amount * df.fx)
df.printSchema()
df.show()

Result:

>>> df = df.withColumn("amount_usd", df.amount * df.fx)
>>> df.printSchema()
root
 |-- amount: decimal(38,10) (nullable = true)
 |-- fx: decimal(38,10) (nullable = true)
 |-- amount_usd: decimal(38,6) (nullable = true)

>>> df.show()
+--------------+------------+----------+
|        amount|          fx|amount_usd|
+--------------+------------+----------+
|233.0000000000|1.1403218880|265.695000|
+--------------+------------+----------+


Is this a bug? Is there a way to get the correct result?

niu*_*uer 7

I think this is expected behavior.

Spark's Catalyst engine converts an expression written in an input language (e.g. Python) into Spark's internal Catalyst representation carrying the same type information. It then operates on that internal representation.
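You can see what Catalyst actually produces for this expression by printing the plans (a quick check, assuming the df from the question; the exact operator names in the output vary by Spark version):

# Prints Spark's logical and physical plans for the multiplication; the
# analyzed logical plan shows the casts and overflow check Catalyst inserts.
df.withColumn("amount_usd", df.amount * df.fx).explain(extended=True)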

If you look at the file sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecision.scala in Spark's source code, you'll see that it is used to:

Calculate and propagate precision for fixed-precision decimals.

 * In particular, if we have expressions e1 and e2 with precision/scale p1/s1 and p2/s2
 * respectively, then the following operations have the following precision / scale:
 *   Operation    Result Precision                        Result Scale
 *   ------------------------------------------------------------------------
 *   e1 * e2      p1 + p2 + 1                             s1 + s2
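Applied to the inputs in the question, before any adjustment (a small illustrative helper of my own, not a Spark API):

# Naive result type of decimal multiplication, per the rule quoted above.
def naive_multiply_result_type(p1, s1, p2, s2):
    return (p1 + p2 + 1, s1 + s2)

print(naive_multiply_result_type(38, 10, 38, 10))  # (77, 20): exceeds MAX_PRECISION = 38

Since 77 exceeds the maximum precision of 38, the result type has to be adjusted.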

Now let's look at the code for multiplication, where the adjustPrecisionScale function is called:

    case Multiply(e1 @ DecimalType.Expression(p1, s1), e2 @ DecimalType.Expression(p2, s2)) =>
      val resultType = if (SQLConf.get.decimalOperationsAllowPrecisionLoss) {
        DecimalType.adjustPrecisionScale(p1 + p2 + 1, s1 + s2)
      } else {
        DecimalType.bounded(p1 + p2 + 1, s1 + s2)
      }
      val widerType = widerDecimalType(p1, s1, p2, s2)
      CheckOverflow(Multiply(promotePrecision(e1, widerType), promotePrecision(e2, widerType)),
        resultType, nullOnOverflow)
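Note the SQLConf.get.decimalOperationsAllowPrecisionLoss check above. It reads the spark.sql.decimalOperationsAllowPrecisionLoss configuration (true by default). As a sketch of the alternative path, assuming the session and DataFrame from the question: with the flag disabled, Spark uses DecimalType.bounded(p1 + p2 + 1, s1 + s2) instead, which for this example should give decimal(38,20) and keep all fractional digits, at the risk of a null result if the value overflows:

# Must be set before the new column is analyzed; the default is "true".
spark.conf.set("spark.sql.decimalOperationsAllowPrecisionLoss", "false")
df.withColumn("amount_usd", df.amount * df.fx).printSchema()
# amount_usd is expected to come out as decimal(38,20) here.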

adjustPrecisionScale is where the magic happens. I've pasted the function here so you can see the logic:

  private[sql] def adjustPrecisionScale(precision: Int, scale: Int): DecimalType = {
    // Assumption:
    assert(precision >= scale)

    if (precision <= MAX_PRECISION) {
      // Adjustment only needed when we exceed max precision
      DecimalType(precision, scale)
    } else if (scale < 0) {
      // Decimal can have negative scale (SPARK-24468). In this case, we cannot allow a precision
      // loss since we would cause a loss of digits in the integer part.
      // In this case, we are likely to meet an overflow.
      DecimalType(MAX_PRECISION, scale)
    } else {
      // Precision/scale exceed maximum precision. Result must be adjusted to MAX_PRECISION.
      val intDigits = precision - scale
      // If original scale is less than MINIMUM_ADJUSTED_SCALE, use original scale value; otherwise
      // preserve at least MINIMUM_ADJUSTED_SCALE fractional digits
      val minScaleValue = Math.min(scale, MINIMUM_ADJUSTED_SCALE)
      // The resulting scale is the maximum between what is available without causing a loss of
      // digits for the integer part of the decimal and the minimum guaranteed scale, which is
      // computed above
      val adjustedScale = Math.max(MAX_PRECISION - intDigits, minScaleValue)

      DecimalType(MAX_PRECISION, adjustedScale)
    }
  }

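To experiment with this logic from Python, here is a rough port of the function (my own sketch, not part of any Spark API; constants taken from the Scala source):

# Rough Python port of DecimalType.adjustPrecisionScale, for illustration.
MAX_PRECISION = 38
MINIMUM_ADJUSTED_SCALE = 6

def adjust_precision_scale(precision, scale):
    assert precision >= scale
    if precision <= MAX_PRECISION:
        # Adjustment is only needed once we exceed the maximum precision.
        return (precision, scale)
    elif scale < 0:
        # Negative scale (SPARK-24468): no precision loss allowed, since that
        # would drop digits from the integer part.
        return (MAX_PRECISION, scale)
    else:
        int_digits = precision - scale
        # Preserve at least MINIMUM_ADJUSTED_SCALE fractional digits, or the
        # original scale if it was already smaller.
        min_scale_value = min(scale, MINIMUM_ADJUSTED_SCALE)
        adjusted_scale = max(MAX_PRECISION - int_digits, min_scale_value)
        return (MAX_PRECISION, adjusted_scale)

print(adjust_precision_scale(77, 20))  # (38, 6): the type seen in the question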

Now let's walk through your example. We have:

e1 = Decimal(233.00)
e2 = Decimal(1.1403218880)

Each has precision = 38 and scale = 10, i.e. p1 = p2 = 38 and s1 = s2 = 10. The product of the two should have precision = p1 + p2 + 1 = 77 and scale = s1 + s2 = 20.

Note that here MAX_PRECISION = 38 and MINIMUM_ADJUSTED_SCALE = 6.

So p1 + p2 + 1 = 77 > 38, and we fall into the last branch: intDigits = precision - scale = 77 - 20 = 57, and minScaleValue = Math.min(scale, MINIMUM_ADJUSTED_SCALE) = min(20, 6) = 6.

adjustedScale = Math.max(MAX_PRECISION - intDigits, minScaleValue) = max(38 - 57, 6) = 6

Finally, a DecimalType with precision = 38 and scale = 6 is returned. That is why you see amount_usd with the type decimal(38,6).

And in the Multiply case above, both operands are promoted to the wider decimal type before the multiplication, and the result is then checked against DecimalType(38,6) by CheckOverflow.

If you run your code with DecimalType(38,6) instead, i.e.

schema = StructType([StructField("amount", DecimalType(38,6)), StructField("fx", DecimalType(38,6))])
df = spark.createDataFrame([(Decimal(233.00), Decimal(1.1403218880))], schema=schema)

you get:

+----------+--------+----------+
|amount    |fx      |amount_usd|
+----------+--------+----------+
|233.000000|1.140322|265.695026|
+----------+--------+----------+

Why was the final number 265.695000 in your original run rather than 265.695026? With the (38,10) inputs, the exact product 265.694999904 is computed first and then rounded to scale 6, giving 265.695000; in the (38,6) run above, fx had already been rounded to 1.140322 before the multiplication, giving 265.695026. But you get the idea.
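You can verify this with plain Python decimals (standard library only; half-up rounding is assumed here, matching how Spark reduces decimal scale):

from decimal import Decimal, ROUND_HALF_UP

exact = Decimal("233.00") * Decimal("1.1403218880")
print(exact)                                               # 265.694999904000
print(exact.quantize(Decimal("0.000001"), ROUND_HALF_UP))  # 265.695000

# Rounding fx to scale 6 first, as in the (38,6) run above:
fx6 = Decimal("1.1403218880").quantize(Decimal("0.000001"), ROUND_HALF_UP)
print(Decimal("233.00") * fx6)                             # 265.69502600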

From the Multiply code we can see that we want to avoid hitting the maximum precision when doing multiplication: if p1 + p2 + 1 stays within 38, no adjustment happens and the full scale is kept. If we change the precision to 18:

schema = StructType([StructField("amount", DecimalType(18,10)), StructField("fx", DecimalType(18,10))])


we get this:

+--------------+------------+------------------------+
|amount        |fx          |amount_usd              |
+--------------+------------+------------------------+
|233.0000000000|1.1403218880|265.69499990400000000000|
+--------------+------------+------------------------+

This is a much better approximation of the result computed in plain Python:

265.6949999039999754657515041
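For reference, that value comes from building the decimals from float literals, which carries binary floating-point error into the Decimal; building them from strings gives the exact product:

from decimal import Decimal

print(Decimal(233.00) * Decimal(1.1403218880))      # 265.6949999039999754657515041
print(Decimal("233.00") * Decimal("1.1403218880"))  # 265.694999904000 (exact)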

Hope this helps!