[C++] 一個通用協程類模板

[C++] 一個通用協程類模板

源碼

#ifndef __MYCORO_H__
#define __MYCORO_H__

#include <iostream>
#include <experimental/coroutine>

#define DELETE_COPY_FUNCTION(cls) \
	cls (const cls&) = delete; \
	cls& operator= (const cls&) = delete;

#define DELETE_MOVE_FUNCTION(cls) \
	cls (cls&&) = delete; \
	cls& operator= (cls&&) = delete;

#define DELETE_COPY_MOVE_FUNCTION(cls) \
	DELETE_COPY_FUNCTION(cls) \
	DELETE_MOVE_FUNCTION(cls)

namespace my_coro {

	struct Followable {
		virtual bool follow() const noexcept = 0;
		virtual Followable const* next_pfollowable() const noexcept = 0;
		void go() const noexcept {
			Followable const* f = this;
			while(f != nullptr && f->follow()) {
				f = f->next_pfollowable();
			}
		}
	};

	template<typename RsmT>
	struct Resumable : public Followable {
		virtual bool send(RsmT const&) noexcept = 0;
	};
	template<>
	struct Resumable<void> : public Followable {
		virtual bool send() noexcept = 0;
	};

	template<typename RtnT>
	struct CoroPromiseWithReturn {
		std::optional<RtnT> rtn_value;
		void return_value(RtnT&& v) noexcept { rtn_value.emplace(std::move(v)); }
	};
	template<>
	struct CoroPromiseWithReturn<void> {
		void return_void() const noexcept {}
	};

	template<typename RsmT>
	struct CoroPromiseWithResume {
		struct ResumeAwaiter {
			const RsmT* const& d;
			bool await_ready() const noexcept { return false; }
			void await_suspend(std::experimental::coroutine_handle<>) const noexcept {}
			const RsmT& await_resume() const noexcept { return *d; }
		};
		const RsmT* prsm_value;
		auto get_awaiter() noexcept {
			return ResumeAwaiter{ prsm_value };
		}
	};
	template<>
	struct CoroPromiseWithResume<void> {
		auto get_awaiter() noexcept { return std::experimental::suspend_always(); }
	};

	template<typename SpdT>
	struct CoroPromiseWithSuspend {
		const SpdT* pspd_value;
		Resumable<SpdT>* presumable_coroutine;
	};
	template<>
	struct CoroPromiseWithSuspend<void> {};

	template<typename SpdT, typename RsmT, typename RtnT>
	class CoroPromiseCore :
		public CoroPromiseWithSuspend<SpdT>,
		public CoroPromiseWithResume<RsmT>,
		public CoroPromiseWithReturn<RtnT> {
		DELETE_COPY_MOVE_FUNCTION(CoroPromiseCore);
	protected:
		std::optional<std::exception_ptr> pexcept;
	public:
		CoroPromiseCore() noexcept :
			CoroPromiseWithSuspend<SpdT>(),
			CoroPromiseWithResume<RsmT>(),
			CoroPromiseWithReturn<RtnT>(),
			pexcept(std::nullopt) {
		}
	public:
		auto yield_value(const SpdT& v) noexcept {
			this->pspd_value = std::addressof(v);
			return this->get_awaiter();
		}
		auto initial_suspend() const noexcept { return std::experimental::suspend_never(); }
		auto final_suspend() noexcept { return std::experimental::suspend_always(); }
		void unhandled_exception() noexcept {
			try {
				std::rethrow_exception(std::current_exception());
			} catch(std::exception& e) {
				pexcept.emplace(std::make_exception_ptr(new std::exception(e)));
				std::cerr << "exception in stack of the coroutine was copied into heap now: " << e.what() << std::endl;
			} catch(...) {
				pexcept.emplace(std::current_exception());
			}
		}
	public:
		bool follow() const noexcept {
			Resumable<SpdT>* pobj = this->presumable_coroutine;
			if(pobj != nullptr) {
				const SpdT* pspd = this->pspd_value;
				if(pspd != nullptr) {
					return pobj->send(*pspd);
				}
			}
			return false;
		}
		void rethrow_if_failed() const {
			if(*pexcept) {
				std::rethrow_exception(*pexcept);
			}
		}
	};

	template<typename T>
	struct CoroIterator {
		T& co;
		CoroIterator(T & co) noexcept : co(co) {}
		const CoroIterator& operator++ () const noexcept { co.send(); return *this; }
		bool operator!= (CoroIterator const& end) const noexcept { return co; }
		const typename T::suspend_type& operator* () const noexcept { return co.recv(); }
	};

#define CORO_COMMON_FUNCTION(Coro, SpdT, RsmT, RtnT) \
		DELETE_COPY_FUNCTION(Coro);																								\
	Coro& operator= (Coro &&) = delete;																					\
	public:																																			\
		using suspend_type = SpdT;																								\
		using resume_type = RsmT;																									\
		using return_type = RtnT;																									\
		struct promise_type : public CoroPromiseCore<SpdT, RsmT, RtnT> {					\
			using CoroPromiseCore<SpdT, RsmT, RtnT>::CoroPromiseCore;								\
			Coro get_return_object() noexcept {																			\
				return Coro(*this);																										\
			}																																				\
		};																																				\
		using handle_type = std::experimental::coroutine_handle<promise_type>;		\
		static_assert(std::is_void_v<suspend_type> == 0,													\
									"suspend_type can not be void");														\
	protected:																																	\
		promise_type& promise;																										\
		handle_type handle;																												\
	public:																																			\
		explicit Coro(promise_type& promise) noexcept :														\
			promise(promise), handle(handle_type::from_promise(promise)) {}					\
		Coro(Coro &&self) noexcept : promise(self.promise) {}											\
		virtual ~Coro() noexcept { handle.destroy(); }														\
		void rethrow_if_failed() const { return promise.rethrow_if_failed(); }		\
		const SpdT& recv() const noexcept { return *promise.pspd_value; }					\
		bool finalized() const noexcept { return handle.done(); }									\
		template<typename T = RtnT, typename = std::enable_if<										\
			!std::is_void_v<T> && std::is_same_v<T, RtnT>>::type>										\
		const T& get_return() const noexcept { return *promise.rtn_value; }				\
		void link(Resumable<SpdT> & rsmobj) noexcept {														\
			promise.presumable_coroutine = std::addressof(rsmobj);									\
		}																																					\
		void unlink() noexcept { promise.presumable_coroutine = nullptr; }				\
																																							\
	public:																																			\
		operator bool() const noexcept { return !finalized(); }										\
		const SpdT& operator*() const noexcept { return recv(); }									\
		template<typename _SpdT, typename _RtnT>																	\
		Coro<_SpdT, SpdT, _RtnT>& operator | (																		\
			Coro<_SpdT, SpdT, _RtnT>& robj) noexcept {															\
			link(robj);																															\
			return robj;																														\
		}																																					\
																																							\
	public:																																			\
		bool follow() const noexcept override { return promise.follow(); }				\
		Followable const* next_pfollowable() const noexcept override {						\
			Resumable<SpdT> const* prsmobj = promise.presumable_coroutine;					\
			return prsmobj;																													\
		}																																					\
	private:

	template<typename SpdT, typename RsmT = void, typename RtnT = void>
	class Coro final :
		public Resumable<RsmT> {
		CORO_COMMON_FUNCTION(Coro, SpdT, RsmT, RtnT)
	public:
		bool send(RsmT const& v) noexcept override {
			promise.prsm_value = std::addressof(v);
			handle.resume();
			return !finalized();
		}
	};

	template<typename SpdT, typename RtnT>
	class Coro<SpdT, void, RtnT> final :
		public Resumable<void> {
		CORO_COMMON_FUNCTION(Coro, SpdT, void, RtnT)
	public:
		bool send() noexcept override {
			handle.resume();
			return !finalized();
		}
		CoroIterator<Coro> begin() { return CoroIterator<Coro>(*this); }
		CoroIterator<Coro> end() { return CoroIterator<Coro>(*this); }
	};

#undef CORO_COMMON_FUNCTION

}

#undef DELETE_COPY_MOVE_FUNCTION
#undef DELETE_MOVE_FUNCTION
#undef DELETE_COPY_FUNCTION

#endif

使用

基本概念

協程函數

含有co_yieldco_awaitco_return的函數爲協程函數. 協程函數不能有return.

協程函數調用後返回協程對象. 即協程由協程函數而生. 可通過協程對象控制協程運行.

協程函數的返回類型需爲符合要求的協程類. 本文提供的協程類模板可用於生成符合要求的協程類.

協程

協程可向外界發送值, 外界亦可向協程發送值. 前者稱掛起值, 後者稱恢復值.
協程結束後亦可有返回值. 該三類值的類型在本文提供的類模板裏均可自定義.
其中, 協程的掛起值類型不能爲void.

定義協程函數

使用下列語法產生一個協程類. =void表示此處可留空, 默認爲void.

my_coro::Coro<掛起值類型, 恢復值類型=void, 返回值類型=void>

之後可定義協程函數, 令其返回類型爲協程類. 協程使用co_yield向外界發送掛起值和接收恢復值.
下面的協程函數能生成掛起值類型爲int, 恢復值類型爲void, 返回值類型爲void的協程. 這也是生成器型協程.

Coro<int> coro_f() {
	for(int i = 0; i < 10; ++i) {
		co_yield i;
	}
}

生成器型協程

生成器型協程即恢復值類型爲void的協程. 有時簡稱生成器.

生成器型協程不需要恢復值, 因此爲該類協程提供了迭代器. 使用如下代碼依次遍歷協程的所有掛起值.

Coro<掛起值類型> co = coro_f();
for(const 掛起值類型& i : co) {
	對i的操作...
}

若不使用迭代器, 亦可像普通協程那樣遍歷協程.

Coro<int> co = coro_f();
while(co) {
	std::cout << *co << std::endl;
	co.send();
}

STL亦提供有一個生成器, 功能基本相同. 亦可用範圍for循環迭代.

std::experimental::generator<掛起值類型>

普通協程

普通協程即有恢復值的協程. 通過向協程發送恢復值來驅動協程運行.
下面是無返回值的例子.

Coro<int, int> coro_f() {
	int i = -1;
	while(i != 0) {
		i = 10 + (co_yield i);
	}
}

下面是有返回值的例子.

Coro<int, int, int> sum_coro() {
	int s = 0;
	int i;
	do {
		i = co_yield 0;
		s += i;
	} while(i != 0);
	co_return s;
}

外界在協程對象上調用send函數向協程發送恢復值. 使用get_return函數獲取返回值.

Coro<int, int, int> sco = sum_coro();
int i = 0;
while(i <= 100) {
	sco.send(i);
}
std::cout << sco.get_return() << std::endl;

STL的協程類

std::futurestd::experiment::generator是兩個符合要求的協程類. 前者要求協程不能有co_yield, 後者要求協程不能有co_return 非void值;.

當需要快速編寫協程函數時可使用std::future, 以方便快速使用co_await.

常用函數

本文的協程對象實現了operator bool運算符, 可用於直接測試協程是否已返回.

if(co) {
	...
}

同時還實現了operator*運算符, 可用於獲取掛起值. 亦可使用recv()函數.

std::cout << *co << std::endl;
std::cout << co.recv() << std::endl;

可讓一個協程的輸出直接接到另一個的協程的輸入. 只要兩個的協程的掛起值類型和恢復值類型相同. 連接協程可使用管道運算符(按位或運算符)operator|, 亦可使用link()函數.

co1 | co2

co1的掛起值將直接輸入給co2. 若要啓動協程的多米諾式運行, 可在第一個協程上調用go(). 所有協程都實現了Followable接口, 可使用該接口自行控制協程的鏈式運行.

co1.go();

一些調用檢查

send()函數在協程的恢復值類型爲void時無參數, 非void時有參數.
get_return()函數在協程的返回值類型定義爲void時不可調用.

異常

協程內拋出異常後, 將直接退出. 協程對象將記錄該異常. 外界可用rethrow_if_exception()重拋出該異常.
棧上C++異常將被複制, 以讓外界訪問. 即std::exception&型異常將轉爲std::exception*, 注意記得調用delete釋放內存. 建議不要拋出棧上異常, 除非自行修改源碼令其支持複製其他棧上的自定義異常.

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章