使用 `Decimal`模块执行精确的浮点数运算

Python v3.10.6

背景

浮点数的一个普遍问题是它们并不能精确的表示十进制数。 并且,即使是最简单的数学运算也会产生小的误差,比如:

>>> 4.2+2.1
6.300000000000001

这些错误是由底层特征决定的,因此没办法去避免这样的误差。这时候可能首先想到的是使用内置的round(value, ndigits)函数,但round函数采用的舍入规则是四舍六入五成双,也就是说如果value刚好在两个边界的中间的时候, round 函数会返回离它最近的偶数(如下示例),除非对精确度没什么要求,否则尽量避开用此方法。

>>> round(1.5) == round(2.5) == 2
True
>>> round(2.675, 2)
2.67

解决办法

在一些对浮点数精度要求较高的领域,需要使用Decimal模块中的方法来进行精准计算,官方文档:https://docs.python.org/zh-cn/3/library/decimal.html

基于Decimal封装的工具类,部分源码如下:

from typing import Union
from decimal import ROUND_HALF_UP, Context, Decimal, getcontext

Numeric = Union[int, float, str, Decimal]

def isNumeric(n: Numeric, raise_exception=False) -> bool:
    if isinstance(n, (int, float, Decimal)):
        flag = True
    elif isinstance(n, str) and n.replace(".", "").lstrip("+").lstrip("-").isdigit():  # not stirp
        flag = True
    else:
        flag = False

    if raise_exception and flag is False:
        raise ValueError(f"Unsupport value type: {n}, type: {type(n)}")

    return flag


class DecimalTool:
    """十进制浮点运算工具类.
    计算结果与服务端decimal方法的计算结果保持完全一致, 方便在case中对数据进行精确断言.
    同时实例方法支持链式调用, 降低python弱语言类型造成的困扰.

    使用示例:
    DecimalTool("0.348526").round("0.000").scale(3).toString()
    >>> 348.226
    """
    def __init__(self, n: Numeric, rounding: str = None, context: Context = None):
        isNumeric(n, True)

        ctx = context if context else getcontext()
        self.n = Decimal(str(n))

        # 设置舍入模式为四舍五入, 即若最后一个有效数字小于5则朝0方向取整,否则朝0反方向取整
        self.rounding = ROUND_HALF_UP if not rounding else rounding
        ctx.rounding = self.rounding
        ctx.prec = 28

        # init
        self._cache = self.n

    def round(self, exp: str):
        """将数值四舍五入到指定精度.
        Usage:
        >>> roundPrice1 = DecimalTool('3.14145').round('0.0000').toString()
        >>> Assert.assert_equal(roundPrice1, '3.1415')
        >>> Assert.assert_not_equal(roundPrice1, round(3.14145, 4))
        """
        self._cache = self.n.quantize(Decimal(exp), self.rounding)
        return self

    def truncate(self, exp):
        """将数值截取到指定精度(不四舍五入)
        Usage:
        >>> DecimalTool('3.1415').truncate('0.000').toString()
        3.141
        """
        sourcePrecision = DecimalTool(self._cache).getPrecision()
        targetPrecision = DecimalTool(exp).getPrecision()
        if sourcePrecision > targetPrecision:
            t = str(self._cache).split('.')
            self._cache = Decimal(t[0] + "." + t[1][:targetPrecision])

        return self

    def scale(self, e: int):
        """以10为底, 将参数进行指数级别缩放(e可为负数)
        Usage:
        >>> DecimalTool('0.348526').scale(3).toString()
        348.526
        >>> DecimalTool('348.526').scale(-3).toString()
        0.348526
        """
        self._cache = self.n.scaleb(e)
        return self

    def cutdown(self, size: Numeric, mode: Literal['truncate', 'round'] = 'truncate'):
        """根据根据size, 将数值进行向下裁减处理, 并与size保持相同精度.
        Uasge:
        >>> DecimalTool("16000.6").cutdown('0.5').toString()
        16000.5
        >>> DecimalTool("16000.6").cutdown('0.50').toString()
        16000.50
        """
        isNumeric(size, True)
        temp = Decimal(size) * int(self._cache // Decimal(size))

        if mode == 'truncate':
            self._cache = DecimalTool(temp).truncate(size).toDecimal()
        elif mode == 'round':
            self._cache = DecimalTool(temp).round(size).toDecimal()
        else:
            raise ValueError(f"Unsupport mode: {mode}")
        return self

    def isZero(self) -> bool:
        """如果参数为0, 则返回True, 否则返回False
        Usage:
        >>> DecimalTool('0.001').isZero()
        Flase
        >>> DecimalTool('0.00').isZero()
        True
        """
        return self._cache.is_zero()

    def isSigend(self) -> bool:
        """如果参数带有负号,则返回为True, 否则返回False
        特殊地, -0会返回Flase
        """
        return self._cache.is_signed()

    def deleteExtraZero(self):
        """删除小数点后面多余的0
        """
        self._cache = Decimal(str(self._cache).rstrip('0').rstrip("."))
        return self

    def getPrecision(self) -> int:
        """获取小数位精度
        """
        self._cache_str = str(self._cache).lower()
        if "." in self._cache_str:
            precision = len(self._cache_str.split('.')[-1])
        elif "e-" in self._cache_str:  # 某些情况下, 如当小数位数>6位时, 会变为科学计数法

            _, precision = self._cache_str.replace('-', '').split('e')
            precision = int(precision)
        else:
            precision = 0

        return precision

    def toDecimal(self) -> Decimal:
        return self._cache

    def toInt(self) -> int:
        return self._cache.to_integral_value()

    def toFloat(self) -> float:
        return float(self._cache)

    def toString(self) -> str:
        return str(self._cache)

    def toEngString(self) -> str:
        return self._cache.to_eng_string()

Testcase (Pytest)

...

def test_decimal_tool():
    precision1 = DecimalTool('0.348226000').getPrecision()
    Assert.assert_equal(precision1, 9)
    precision2 = DecimalTool('0.00000001').getPrecision()
    Assert.assert_equal(precision2, 8)
    precision3 = DecimalTool('123456').getPrecision()
    Assert.assert_equal(precision3, 0)

    roundPrice1 = DecimalTool('3.14145').round("0.0000").toString()
    Assert.assert_equal(roundPrice1, '3.1415')
    Assert.assert_not_equal(roundPrice1, round(3.14145, 4))
    roundPrice2 = DecimalTool('0.348526').round("0.000").toString()
    Assert.assert_equal(roundPrice2, '0.349')

    truncatePrice1 = DecimalTool('0.0348').truncate("0.000").toString()
    Assert.assert_equal(truncatePrice1, '0.034')
    truncatePrice2 = DecimalTool('0.0348').truncate("0.0000000").toString()
    Assert.assert_equal(truncatePrice2, '0.0348')

    scaledPrice1 = DecimalTool('0.348526').round("0.000").scale(3).toString()
    Assert.assert_equal(scaledPrice1, '348.526')
    scaledPrice2 = DecimalTool('348.526').round("0.000").scale(-3).toString()
    Assert.assert_equal(scaledPrice2, '0.348526')

    cutdownPrice1 = DecimalTool("16000.6").cutdown('0.5').toString()
    Assert.assert_equal(cutdownPrice1, '16000.5')
    cutdownPrice1 = DecimalTool("16000.6").cutdown('0.50').toString()
    Assert.assert_equal(cutdownPrice1, '16000.50')  # 与size保持相同精度

    isNotZero = DecimalTool('0.001').isZero()
    Assert.assert_not_true(isNotZero)
    isZero1 = DecimalTool('0.00').isZero()
    Assert.assert_true(isZero1)
    isZero2 = DecimalTool('0').isZero()
    Assert.assert_true(isZero2)
發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章