PySpark seems to lose precision when multiplying decimals.
For example, when multiplying two decimals that both have precision (38,10), the result comes back as (38,6) and is rounded to three decimal places, which is not the correct result.
from decimal import Decimal
from pyspark.sql.types import DecimalType, StructType, StructField

# (run in the pyspark shell, where `spark` is already defined)
schema = StructType([StructField("amount", DecimalType(38, 10)), StructField("fx", DecimalType(38, 10))])
df = spark.createDataFrame([(Decimal(233.00), Decimal(1.1403218880))], schema=schema)
df.printSchema()
df = df.withColumn("amount_usd", df.amount * df.fx)
df.printSchema()
df.show()
Result:
>>> df.printSchema()
root
 |-- amount: decimal(38,10) (nullable = true)
 |-- fx: decimal(38,10) (nullable = true)
>>> df = df.withColumn("amount_usd", df.amount * df.fx)
>>> df.printSchema()
root
 |-- amount: decimal(38,10) (nullable = true)
 |-- fx: decimal(38,10) (nullable = true)
 |-- amount_usd: decimal(38,6) (nullable = true)
>>> df.show()
+--------------+------------+----------+
| amount| fx|amount_usd|
+--------------+------------+----------+
|233.0000000000|1.1403218880|265.695000|
+--------------+------------+----------+
Is this a bug? Is there a way to get the correct result?
I think this is expected behavior.

Spark's Catalyst engine converts an expression written in an input language (e.g. Python) into Spark's internal Catalyst representation of that same expression, carrying the type information along. It then operates on that internal representation.
If you look at the file sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecision.scala in the Spark source code, you'll see that it:

Calculates and propagates precision for fixed-precision decimals.

and
* In particular, if we have expressions e1 and e2 with precision/scale p1/s1 and p2/s2
* respectively, then the following operations have the following precision / scale:
*
*   Operation    Result Precision    Result Scale
*   ------------------------------------------------------------------------
*   e1 * e2      p1 + p2 + 1         s1 + s2
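To make the rule concrete, here is a tiny Python sketch (the function name is mine, not Spark's) that applies the multiplication row of that table:

def multiply_result_type(p1, s1, p2, s2):
    """Unadjusted precision/scale of e1 * e2, per the table above."""
    return p1 + p2 + 1, s1 + s2

# The question's case: two decimal(38,10) inputs.
print(multiply_result_type(38, 10, 38, 10))  # (77, 20) -- far beyond the maximum precision of 38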
Now let's look at the code for multiplication, which is where the adjustPrecisionScale function gets called:
case Multiply(e1 @ DecimalType.Expression(p1, s1), e2 @ DecimalType.Expression(p2, s2)) =>
  val resultType = if (SQLConf.get.decimalOperationsAllowPrecisionLoss) {
    DecimalType.adjustPrecisionScale(p1 + p2 + 1, s1 + s2)
  } else {
    DecimalType.bounded(p1 + p2 + 1, s1 + s2)
  }
  val widerType = widerDecimalType(p1, s1, p2, s2)
  CheckOverflow(Multiply(promotePrecision(e1, widerType), promotePrecision(e2, widerType)),
    resultType, nullOnOverflow)
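A side note on the SQLConf.get.decimalOperationsAllowPrecisionLoss branch: it is controlled by the setting spark.sql.decimalOperations.allowPrecisionLoss, which defaults to true. If I read the branch correctly, disabling it makes Spark use DecimalType.bounded(77, 20) = decimal(38,20) instead of the adjusted decimal(38,6), keeping the full scale but producing null when a value no longer fits in the remaining 18 integer digits. Roughly:

# Sketch: disable the precision-loss adjustment (the setting exists since Spark 2.3)
spark.conf.set("spark.sql.decimalOperations.allowPrecisionLoss", "false")
df.withColumn("amount_usd", df.amount * df.fx).printSchema()
# amount_usd: decimal(38,20) -- bounded(77, 20) rather than the adjusted (38, 6)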
adjustPrecisionScale is where the magic happens. I've pasted the function here so you can see the logic:
private[sql] def adjustPrecisionScale(precision: Int, scale: Int): DecimalType = {
  // Assumption:
  assert(precision >= scale)

  if (precision <= MAX_PRECISION) {
    // Adjustment only needed when we exceed max precision
    DecimalType(precision, scale)
  } else if (scale < 0) {
    // Decimal can have negative scale (SPARK-24468). In this case, we cannot allow a precision
    // loss since we would cause a loss of digits in the integer part.
    // In this case, we are likely to meet an overflow.
    DecimalType(MAX_PRECISION, scale)
  } else {
    // Precision/scale exceed maximum precision. Result must be adjusted to MAX_PRECISION.
    val intDigits = precision - scale
    // If original scale is less than MINIMUM_ADJUSTED_SCALE, use original scale value; otherwise
    // preserve at least MINIMUM_ADJUSTED_SCALE fractional digits
    val minScaleValue = Math.min(scale, MINIMUM_ADJUSTED_SCALE)
    // The resulting scale is the maximum between what is available without causing a loss of
    // digits for the integer part of the decimal and the minimum guaranteed scale, which is
    // computed above
    val adjustedScale = Math.max(MAX_PRECISION - intDigits, minScaleValue)
    DecimalType(MAX_PRECISION, adjustedScale)
  }
}
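If you want to experiment with this logic without building Spark, here is a rough Python transcription of the function (my port, not a Spark API):

MAX_PRECISION = 38
MINIMUM_ADJUSTED_SCALE = 6

def adjust_precision_scale(precision, scale):
    # Python transcription of DecimalType.adjustPrecisionScale
    assert precision >= scale
    if precision <= MAX_PRECISION:
        return precision, scale
    elif scale < 0:
        # negative scale (SPARK-24468): no precision loss allowed
        return MAX_PRECISION, scale
    else:
        int_digits = precision - scale
        min_scale_value = min(scale, MINIMUM_ADJUSTED_SCALE)
        adjusted_scale = max(MAX_PRECISION - int_digits, min_scale_value)
        return MAX_PRECISION, adjusted_scale

print(adjust_precision_scale(77, 20))  # (38, 6)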
Now let's walk through your example. We have
e1 = Decimal(233.00)
e2 = Decimal(1.1403218880)
Each has precision = 38 and scale = 10, i.e. p1 = p2 = 38 and s1 = s2 = 10. The product of the two should then have precision = p1 + p2 + 1 = 77 and scale = s1 + s2 = 20.

Note that MAX_PRECISION = 38 and MINIMUM_ADJUSTED_SCALE = 6 here, so p1 + p2 + 1 = 77 > 38, and:
val intDigits = precision - scale = 77 - 20 = 57
val minScaleValue = Math.min(scale, MINIMUM_ADJUSTED_SCALE) = min(20, 6) = 6
val adjustedScale = Math.max(MAX_PRECISION - intDigits, minScaleValue) = max(38 - 57, 6) = 6
Finally, a DecimalType with precision = 38 and scale = 6 is returned. That is why you see amount_usd with type decimal(38,6).
In the Multiply case above, both operands are first promoted to the wider decimal type, the multiplication is carried out, and CheckOverflow then rounds the result to the adjusted type, DecimalType(38,6).
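You can see this machinery in the analyzed plan (the exact output varies by Spark version; recent releases removed the PromotePrecision node, so yours may look different):

df.withColumn("amount_usd", df.amount * df.fx).explain(True)
# In the analyzed plan, look for something like:
#   CheckOverflow((promote_precision(amount) * promote_precision(fx)), DecimalType(38,6))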
If you run your code with the inputs declared as DecimalType(38,6), i.e.
schema = StructType([StructField("amount", DecimalType(38,6)), StructField("fx", DecimalType(38,6))])
df = spark.createDataFrame([(Decimal(233.00), Decimal(1.1403218880))], schema=schema)
you get
+----------+--------+----------+
|amount |fx |amount_usd|
+----------+--------+----------+
|233.000000|1.140322|265.695026|
+----------+--------+----------+
Why does your original example show 265.695000, then? Because there the inputs keep their full scale of 10: the exact product is 265.694999904, and rounding that to the adjusted scale of 6 gives 265.695000. In the (38,6) run the inputs themselves were rounded first (fx becomes 1.140322), which is where 265.695026 comes from. But you get the idea.
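You can check both numbers with plain Python decimals; as far as I can tell, Spark rounds half-up when trimming the result to the adjusted scale:

from decimal import Decimal, ROUND_HALF_UP

exact = Decimal("233.00") * Decimal("1.1403218880")
print(exact)                                                        # 265.694999904000
print(exact.quantize(Decimal("1.000000"), rounding=ROUND_HALF_UP))  # 265.695000

# With the inputs already rounded to scale 6, the product really is 265.695026:
print(Decimal("233.000000") * Decimal("1.140322"))                  # 265.695026000000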
As you can see from the Multiply code, we want to avoid using the maximum precision when doing multiplication. If we change the declared precision to 18, i.e.
schema = StructType([StructField("amount", DecimalType(18,10)), StructField("fx", DecimalType(18,10))])
we get this:
+--------------+------------+------------------------+
|amount |fx |amount_usd |
+--------------+------------+------------------------+
|233.0000000000|1.1403218880|265.69499990400000000000|
+--------------+------------+------------------------+
We get a much better approximation of the value Python computes:
265.6949999039999754657515041
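As an aside, the long tail of digits in that reference value comes from building Decimal objects out of Python floats: Decimal(1.1403218880) stores the binary double closest to 1.1403218880, not the exact decimal. Constructing from strings shows the difference:

from decimal import Decimal

print(Decimal(1.1403218880))                        # long binary approximation, not exactly 1.1403218880
print(Decimal("233.00") * Decimal("1.1403218880"))  # exactly 265.694999904000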
Hope this helps!