使用 `Decimal`模塊執行精確的浮點數運算

Python v3.10.6

背景

浮點數的一個普遍問題是它們並不能精確的表示十進制數。 並且,即使是最簡單的數學運算也會產生小的誤差,比如:

>>> 4.2+2.1
6.300000000000001

這些錯誤是由底層特徵決定的,因此沒辦法去避免這樣的誤差。這時候可能首先想到的是使用內置的round(value, ndigits)函數,但round函數採用的舍入規則是四捨六入五成雙,也就是說如果value剛好在兩個邊界的中間的時候, round 函數會返回離它最近的偶數(如下示例),除非對精確度沒什麼要求,否則儘量避開用此方法。

>>> round(1.5) == round(2.5) == 2
True
>>> round(2.675, 2)
2.67

解決辦法

在一些對浮點數精度要求較高的領域,需要使用Decimal模塊中的方法來進行精準計算,官方文檔:https://docs.python.org/zh-cn/3/library/decimal.html

基於Decimal封裝的工具類,部分源碼如下:

from typing import Union
from decimal import ROUND_HALF_UP, Context, Decimal, getcontext

Numeric = Union[int, float, str, Decimal]

def isNumeric(n: Numeric, raise_exception=False) -> bool:
    if isinstance(n, (int, float, Decimal)):
        flag = True
    elif isinstance(n, str) and n.replace(".", "").lstrip("+").lstrip("-").isdigit():  # not stirp
        flag = True
    else:
        flag = False

    if raise_exception and flag is False:
        raise ValueError(f"Unsupport value type: {n}, type: {type(n)}")

    return flag


class DecimalTool:
    """十進制浮點運算工具類.
    計算結果與服務端decimal方法的計算結果保持完全一致, 方便在case中對數據進行精確斷言.
    同時實例方法支持鏈式調用, 降低python弱語言類型造成的困擾.

    使用示例:
    DecimalTool("0.348526").round("0.000").scale(3).toString()
    >>> 348.226
    """
    def __init__(self, n: Numeric, rounding: str = None, context: Context = None):
        isNumeric(n, True)

        ctx = context if context else getcontext()
        self.n = Decimal(str(n))

        # 設置舍入模式爲四捨五入, 即若最後一個有效數字小於5則朝0方向取整,否則朝0反方向取整
        self.rounding = ROUND_HALF_UP if not rounding else rounding
        ctx.rounding = self.rounding
        ctx.prec = 28

        # init
        self._cache = self.n

    def round(self, exp: str):
        """將數值四捨五入到指定精度.
        Usage:
        >>> roundPrice1 = DecimalTool('3.14145').round('0.0000').toString()
        >>> Assert.assert_equal(roundPrice1, '3.1415')
        >>> Assert.assert_not_equal(roundPrice1, round(3.14145, 4))
        """
        self._cache = self.n.quantize(Decimal(exp), self.rounding)
        return self

    def truncate(self, exp):
        """將數值截取到指定精度(不四捨五入)
        Usage:
        >>> DecimalTool('3.1415').truncate('0.000').toString()
        3.141
        """
        sourcePrecision = DecimalTool(self._cache).getPrecision()
        targetPrecision = DecimalTool(exp).getPrecision()
        if sourcePrecision > targetPrecision:
            t = str(self._cache).split('.')
            self._cache = Decimal(t[0] + "." + t[1][:targetPrecision])

        return self

    def scale(self, e: int):
        """以10爲底, 將參數進行指數級別縮放(e可爲負數)
        Usage:
        >>> DecimalTool('0.348526').scale(3).toString()
        348.526
        >>> DecimalTool('348.526').scale(-3).toString()
        0.348526
        """
        self._cache = self.n.scaleb(e)
        return self

    def cutdown(self, size: Numeric, mode: Literal['truncate', 'round'] = 'truncate'):
        """根據根據size, 將數值進行向下裁減處理, 並與size保持相同精度.
        Uasge:
        >>> DecimalTool("16000.6").cutdown('0.5').toString()
        16000.5
        >>> DecimalTool("16000.6").cutdown('0.50').toString()
        16000.50
        """
        isNumeric(size, True)
        temp = Decimal(size) * int(self._cache // Decimal(size))

        if mode == 'truncate':
            self._cache = DecimalTool(temp).truncate(size).toDecimal()
        elif mode == 'round':
            self._cache = DecimalTool(temp).round(size).toDecimal()
        else:
            raise ValueError(f"Unsupport mode: {mode}")
        return self

    def isZero(self) -> bool:
        """如果參數爲0, 則返回True, 否則返回False
        Usage:
        >>> DecimalTool('0.001').isZero()
        Flase
        >>> DecimalTool('0.00').isZero()
        True
        """
        return self._cache.is_zero()

    def isSigend(self) -> bool:
        """如果參數帶有負號,則返回爲True, 否則返回False
        特殊地, -0會返回Flase
        """
        return self._cache.is_signed()

    def deleteExtraZero(self):
        """刪除小數點後面多餘的0
        """
        self._cache = Decimal(str(self._cache).rstrip('0').rstrip("."))
        return self

    def getPrecision(self) -> int:
        """獲取小數位精度
        """
        self._cache_str = str(self._cache).lower()
        if "." in self._cache_str:
            precision = len(self._cache_str.split('.')[-1])
        elif "e-" in self._cache_str:  # 某些情況下, 如當小數位數>6位時, 會變爲科學計數法

            _, precision = self._cache_str.replace('-', '').split('e')
            precision = int(precision)
        else:
            precision = 0

        return precision

    def toDecimal(self) -> Decimal:
        return self._cache

    def toInt(self) -> int:
        return self._cache.to_integral_value()

    def toFloat(self) -> float:
        return float(self._cache)

    def toString(self) -> str:
        return str(self._cache)

    def toEngString(self) -> str:
        return self._cache.to_eng_string()

Testcase (Pytest)

...

def test_decimal_tool():
    precision1 = DecimalTool('0.348226000').getPrecision()
    Assert.assert_equal(precision1, 9)
    precision2 = DecimalTool('0.00000001').getPrecision()
    Assert.assert_equal(precision2, 8)
    precision3 = DecimalTool('123456').getPrecision()
    Assert.assert_equal(precision3, 0)

    roundPrice1 = DecimalTool('3.14145').round("0.0000").toString()
    Assert.assert_equal(roundPrice1, '3.1415')
    Assert.assert_not_equal(roundPrice1, round(3.14145, 4))
    roundPrice2 = DecimalTool('0.348526').round("0.000").toString()
    Assert.assert_equal(roundPrice2, '0.349')

    truncatePrice1 = DecimalTool('0.0348').truncate("0.000").toString()
    Assert.assert_equal(truncatePrice1, '0.034')
    truncatePrice2 = DecimalTool('0.0348').truncate("0.0000000").toString()
    Assert.assert_equal(truncatePrice2, '0.0348')

    scaledPrice1 = DecimalTool('0.348526').round("0.000").scale(3).toString()
    Assert.assert_equal(scaledPrice1, '348.526')
    scaledPrice2 = DecimalTool('348.526').round("0.000").scale(-3).toString()
    Assert.assert_equal(scaledPrice2, '0.348526')

    cutdownPrice1 = DecimalTool("16000.6").cutdown('0.5').toString()
    Assert.assert_equal(cutdownPrice1, '16000.5')
    cutdownPrice1 = DecimalTool("16000.6").cutdown('0.50').toString()
    Assert.assert_equal(cutdownPrice1, '16000.50')  # 與size保持相同精度

    isNotZero = DecimalTool('0.001').isZero()
    Assert.assert_not_true(isNotZero)
    isZero1 = DecimalTool('0.00').isZero()
    Assert.assert_true(isZero1)
    isZero2 = DecimalTool('0').isZero()
    Assert.assert_true(isZero2)
發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章