STL源碼解析 - nth_element

 

nth_element 模板函數具有兩個版本

 

template<class _RanIt>
void nth_element(_RanIt _First, _RanIt _Nth, _RanIt _Last);


template<class _RanIt, class _Pr>
void nth_element(_RanIt _First, _RanIt _Nth, _RanIt _Last, _Pr _Pred);


其功能是對區間 [_First, _Last) 的元素進行重排,其中位於位置 _Nth 的元素與整個區間排序後位於位置 _Nth 的元素相同,並且滿足在位置 _Nth 之前的所有元素都“不大於”它和位置 _Nth 之後的所有元素都“不小於”它,而且並不保證 _Nth 的前後兩個區間的所有元素保持有序。

第一個版本,比較操作默認使用小於操作符(operator<);第二個版本,使用自定義謂詞 "_Pred" 定義“小於”操作(Less Than)。

算法的空間複雜度爲O(1)。

由於算法主要分兩部分實現,第一部分是進行二分法弱分區,第二部分是對包含 _Nth 的位置的區間進行插入排序(STL的閾值爲32)。當元素較多時平均時間複雜度爲O(N),元素較少時最壞情況下時間複雜度爲O(N^2)。

下面針對第一個版本的算法源代碼進行註釋說明,版本爲 Microsoft Visual Studio 2008 SP1 安裝包中的 algorithm 文件

 

template<class _RanIt> inline
void nth_element(_RanIt _First, _RanIt _Nth, _RanIt _Last)
{	// order Nth element, using operator<
    _Nth_element(_CHECKED_BASE(_First), _CHECKED_BASE(_Nth), _CHECKED_BASE(_Last)); // 轉調用內部實現函數
}


_Nth_element 函數實現,其中 _ISORT_MAX 值爲 32。

 

template<class _RanIt> inline
	void _Nth_element(_RanIt _First, _RanIt _Nth, _RanIt _Last)
	{	// order Nth element, using operator<
	_DEBUG_RANGE(_First, _Last);
	for (; _ISORT_MAX < _Last - _First; )
		{	// divide and conquer, ordering partition containing Nth
		pair<_RanIt, _RanIt> _Mid =
			std::_Unguarded_partition(_First, _Last);

		if (_Mid.second <= _Nth)
			_First = _Mid.second;
		else if (_Mid.first <= _Nth)
			return;	// Nth inside fat pivot, done
		else
			_Last = _Mid.first;
		}

	// 插入排序
	std::_Insertion_sort(_First, _Last);	// sort any remainder
	}

 

_Unguarded_partition 函數實現

 

template<class _RanIt> inline
	pair<_RanIt, _RanIt> _Unguarded_partition(_RanIt _First, _RanIt _Last)
	{	// partition [_First, _Last), using operator<
	_RanIt _Mid = _First + (_Last - _First) / 2;	// sort median to _Mid
	std::_Median(_First, _Mid, _Last - 1);	// 端點排序
	_RanIt _Pfirst = _Mid;
	_RanIt _Plast = _Pfirst + 1;	// 起始返回區間爲 [_Mid, _Mid + 1)

	// 以下兩個循環將不處理與 *_Mid 值相同的元素
	while (_First < _Pfirst
		&& !_DEBUG_LT(*(_Pfirst - 1), *_Pfirst)
		&& !(*_Pfirst < *(_Pfirst - 1)))
		--_Pfirst;
	while (_Plast < _Last
		&& !_DEBUG_LT(*_Plast, *_Pfirst)
		&& !(*_Pfirst < *_Plast))
		++_Plast;

	// 當前返回區間爲 [_Pfirst, _Plast),且區間內值均相等
	_RanIt _Gfirst = _Plast;
	_RanIt _Glast = _Pfirst;

	for (; ; )
		{	// partition
		// 後半區間
		for (; _Gfirst < _Last; ++_Gfirst)
			if (_DEBUG_LT(*_Pfirst, *_Gfirst))	// 大於首值,迭代器後移
				;
			else if (*_Gfirst < *_Pfirst)		// 小於首值,退出循環
				break;
			else
				std::iter_swap(_Plast++, _Gfirst);	// 與首值相等,末迭代器後移,更新末值
		// 前半區間
		for (; _First < _Glast; --_Glast)
			if (_DEBUG_LT(*(_Glast - 1), *_Pfirst))	// 小於首值,迭代器前移
				;
			else if (*_Pfirst < *(_Glast - 1))		// 大於首值,退出循環
				break;
			else
				std::iter_swap(--_Pfirst, _Glast - 1);	// 與首值相等,首迭代器前移,更新首值

		// 整體區間已經處理結束
		if (_Glast == _First && _Gfirst == _Last)
			return (pair<_RanIt, _RanIt>(_Pfirst, _Plast));

		// 到達起點
		if (_Glast == _First)
			{	// no room at bottom, rotate pivot upward
			if (_Plast != _Gfirst)
				std::iter_swap(_Pfirst, _Plast);	// if 成立,_Pfirst 暫存大值
			++_Plast;								// 末迭代器後移
			std::iter_swap(_Pfirst++, _Gfirst++);	// if 成立時,小值將存於返回區間首,最終結果是,返回區間整體右移
			}
		else if (_Gfirst == _Last)	// 到達終點
			{	// no room at top, rotate pivot downward
			if (--_Glast != --_Pfirst)
				std::iter_swap(_Glast, _Pfirst);	// if 成立,_Pfirst 暫存大值
			std::iter_swap(_Pfirst, --_Plast);	// if 成立時,大值將存於返回區間尾,最終結果是,返回區間整體左移
			}
		else
			std::iter_swap(_Gfirst++, --_Glast);	// 交換後,*_Glast < *_Pfirst < *(_Gfirst - 1)
		}
	}

 

_Median 和 _Med3 兩個函數,其作用是對區間內的特定幾個數進行排序

 

template<class _RanIt> inline
	void _Med3(_RanIt _First, _RanIt _Mid, _RanIt _Last)
	{	// sort median of three elements to middle - 3 點排序
	if (_DEBUG_LT(*_Mid, *_First))
		std::iter_swap(_Mid, _First);
	if (_DEBUG_LT(*_Last, *_Mid))
		std::iter_swap(_Last, _Mid);
	if (_DEBUG_LT(*_Mid, *_First))
		std::iter_swap(_Mid, _First);
	}

template<class _RanIt> inline
	void _Median(_RanIt _First, _RanIt _Mid, _RanIt _Last)
	{	// sort median element to middle
	if (40 < _Last - _First)
		{	// median of nine - 9 端點排序
		size_t _Step = (_Last - _First + 1) / 8;
		std::_Med3(_First, _First + _Step, _First + 2 * _Step);
		std::_Med3(_Mid - _Step, _Mid, _Mid + _Step);
		std::_Med3(_Last - 2 * _Step, _Last - _Step, _Last);
		std::_Med3(_First + _Step, _Mid, _Last - _Step);
		}
	else
		std::_Med3(_First, _Mid, _Last);
	}

 

對於第二個版本,算法思想相同,只是要做比較操作時,將用 _Pred 替換 operator< 操作符,同時也看到算法的核心主要在於 _Unguarded_partition 這個函數。

_Insertion_sort 函數,插入排序

 

template<class _BidIt> inline
	void _Insertion_sort(_BidIt _First, _BidIt _Last)
	{	// insertion sort [_First, _Last), using operator<
	std::_Insertion_sort1(_First, _Last, _Val_type(_First)); // 轉調用 _Insertion_sort1
	}


_Insertion_sort1 函數

 

template<class _BidIt,
	class _Ty> inline
	void _Insertion_sort1(_BidIt _First, _BidIt _Last, _Ty *)
	{	// insertion sort [_First, _Last), using operator<
	if (_First != _Last)
		for (_BidIt _Next = _First; ++_Next != _Last; )
			{	// order next element
			_BidIt _Next1 = _Next;
			_Ty _Val = *_Next;

			// 小於首值時,整體後移,有可能使用 memmove,因而存在優化
			if (_DEBUG_LT(_Val, *_First))
				{	// found new earliest element, move to front - [_First, _Next) => [..., ++Next1)
				_STDEXT unchecked_copy_backward(_First, _Next, ++_Next1);
				*_First = _Val;
				}
			else
				{	// look for insertion point after first
				for (_BidIt _First1 = _Next1;
					_DEBUG_LT(_Val, *--_First1);
					_Next1 = _First1)
					*_Next1 = *_First1;	// move hole down - 逐項後移
				*_Next1 = _Val;	// insert element in hole
				}
			}
	}

 

至此,我們已經完全理解 nth_element 的算法思想了,並且明白爲何它的時間複雜度和空間複雜度都很低,當不需要對某個數組進行全部排序而想找出滿足某一條件(_Pred)的第 N 個值時,便可採用此算法,同時需要注意的是,此算法只對“隨機訪問迭代器”有效(如 vector),如果需要對 list 使用此算法,可先將 list 的所有元素拷貝至 vector(或者存儲 list::iterator,對自定義類型效率更高),再使用此算法。

 

代碼版本來源於Microsoft Visual Studio 2008 安裝包中<algorithm>文件,版權歸原作者所有!

 

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章