std::variant visitor and pattern matching

Amateur notes, mostly borrowed from other people's ideas; don't expect much.

std::variant

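In my 《CppCon 2016: Ben Deane “Using Types Effectively” Notes》, I mentioned that Ben considers std::variant and std::optional the most important new features in C++. In those notes, though, I only said that std::variant is a type-safe union related to pattern matching in ML or Haskell. This post covers std::visit and pattern matching, both closely tied to std::variant.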
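To make "type-safe union" concrete, here is a small sketch (not from the talk): unlike a plain union, a std::variant always knows which alternative is active, and asking for the wrong one with std::get throws std::bad_variant_access instead of being undefined behavior. The alias `Value` and the helper `wrong_access_throws` are illustrative names.

```cpp
#include <cassert>
#include <string>
#include <variant>

// A variant always records which alternative it currently holds.
using Value = std::variant<int, std::string>;

// Hypothetical helper: true if reading the std::string alternative failed safely.
bool wrong_access_throws(const Value& v) {
  try {
    (void)std::get<std::string>(v);  // checked access: throws if int is active
    return false;
  } catch (const std::bad_variant_access&) {
    return true;
  }
}
```

With a raw `union { int i; std::string s; }` the same wrong read would compile silently; here it is detected at run time, and v.index() / std::holds_alternative tell you which alternative is active.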

Beyond being a type-safe union, and more importantly, std::variant can replace some traditional uses of inheritance-based polymorphism. CppCon 2016: David Sankel “Variants: Past, Present, and Future” gives a simple summary:

| Inheritance                 | Variant                     |
|-----------------------------|-----------------------------|
| open to new alternatives    | closed to new alternatives  |
| closed to new operations    | open to new operations      |
| multi-level                 | single level                |
| OO                          | functional                  |
| complex                     | simple                      |
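A minimal sketch of the "Variant" column, with illustrative types `Circle` and `Square`: the list of alternatives is fixed in one place, while each new operation is just another visitor that touches none of the existing types.

```cpp
#include <cassert>
#include <string>
#include <variant>

struct Circle { double r; };
struct Square { double side; };

// The set of alternatives is fixed here: closed to new alternatives.
using Shape = std::variant<Circle, Square>;

// Each operation is a separate visitor; adding one changes no shape type:
// open to new operations.
struct Area {
  double operator()(const Circle& c) const { return 3.14159 * c.r * c.r; }
  double operator()(const Square& s) const { return s.side * s.side; }
};

struct Name {
  std::string operator()(const Circle&) const { return "circle"; }
  std::string operator()(const Square&) const { return "square"; }
};

double area(const Shape& s) { return std::visit(Area{}, s); }
std::string name(const Shape& s) { return std::visit(Name{}, s); }
```

Adding a `Triangle`, by contrast, would mean revisiting every visitor, which is exactly the "closed to new alternatives" row.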

pattern matching

The only functional programming language I know is ML (learned in the Programming Languages course on Coursera), so I'll use ML's pattern matching as the example.

datatype exp = Constant of int
			| Negate    of exp
			| Add       of exp * exp
			| Multiply  of exp * exp

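Note: the code above is taken from 《coursera - Programming Languages》.
datatype declares a new data type which, like std::variant, can be called a sum type, a disjoint union, and so on. The code above recursively defines a type exp that represents arithmetic expressions.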

We can write a function that finds the largest constant in an expression:

fun max_constant e = 
	case e of 
			Constant i => i
		| Negate e2 => max_constant e2
		| Add(e1, e2) => if max_constant e1 > max_constant e2
						then max_constant e1
						else max_constant e2
		| Multiply(e1, e2) => if max_constant e1 > max_constant e2
						then max_constant e1
						else max_constant e2

Note: the code above is taken from 《coursera - Programming Languages》.

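The case expression in the code above is what performs pattern matching. C++ has no language-level pattern matching (not even in C++17), but we can emulate it in OOP style, with one virtual method per operation: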

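#include <algorithm>
#include <iostream>
#include <memory>
#include <utility>

// Named Exp rather than exp to avoid clashing with ::exp from <cmath>.
class Exp {
public:
  virtual ~Exp() = default;  // objects are deleted through std::unique_ptr<Exp>
  virtual int max_constant() const = 0;
};

class Constant : public Exp {
  int v;

public:
  explicit Constant(int v) : v(v) {}
  int max_constant() const override final { return v; }
};

// Like `Negate of exp` in ML, Negate wraps a sub-expression.
class Negate : public Exp {
  std::unique_ptr<Exp> e;

public:
  explicit Negate(std::unique_ptr<Exp> e) : e(std::move(e)) {}
  int max_constant() const override final { return e->max_constant(); }
};

class Add : public Exp {
  std::unique_ptr<Exp> e1;
  std::unique_ptr<Exp> e2;

public:
  Add(std::unique_ptr<Exp> e1, std::unique_ptr<Exp> e2)
      : e1(std::move(e1)), e2(std::move(e2)) {}
  int max_constant() const override final {
    return std::max(e1->max_constant(), e2->max_constant());
  }
};

class Multiply : public Exp {
  std::unique_ptr<Exp> e1;
  std::unique_ptr<Exp> e2;

public:
  Multiply(std::unique_ptr<Exp> e1, std::unique_ptr<Exp> e2)
      : e1(std::move(e1)), e2(std::move(e2)) {}
  int max_constant() const override final {
    return std::max(e1->max_constant(), e2->max_constant());
  }
};

int main() {
  // Multiply(Add(Negate(Constant 10), Constant 10), Constant 30)
  Multiply e(std::make_unique<Add>(std::make_unique<Negate>(std::make_unique<Constant>(10)),
                                   std::make_unique<Constant>(10)),
             std::make_unique<Constant>(30));
  std::cout << e.max_constant() << std::endl;  // prints 30
}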

pattern matching is a dispatch mechanism: choosing which variant of a function is the correct one to call. - 《Pattern Matching》

If std::variant were only a type-safe union, its potential would go unused: it needs a companion mechanism for processing the value it holds.
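Without such a mechanism, processing a variant means testing the active alternative by hand. A minimal sketch using std::get_if and std::holds_alternative (the alias `Token` and the function `to_text` are illustrative) shows why this does not scale: nothing forces the chain to cover every alternative.

```cpp
#include <cassert>
#include <string>
#include <variant>

using Token = std::variant<int, double, std::string>;

// Manual dispatch: ask for each alternative in turn.
// If a new alternative is added to Token, this still compiles and silently
// falls through to "unknown".
std::string to_text(const Token& t) {
  if (const int* i = std::get_if<int>(&t)) return "int " + std::to_string(*i);
  if (std::holds_alternative<double>(t)) return "double";
  if (const std::string* s = std::get_if<std::string>(&t)) return "string " + *s;
  return "unknown";
}
```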

single dispatch

Single dispatch covers what we usually call dynamic dispatch and static dispatch. In 《Programming Language, Part C - Week 1 Lecture Notes》, Dan says that dynamic dispatch is the most essential feature of OOP.

dynamic dispatch

In computer science, dynamic dispatch is the process of selecting which implementation of a polymorphic operation (method or function) to call at run time. It is commonly employed in, and considered a prime characteristic of, object-oriented programming (OOP) languages and systems.

The OOP C++ code above implements single dispatch (dynamic dispatch) through the vtable.

This summary of dynamic dispatch comes from 《CppCon 2018: Mateusz Pusz “Effective replacement of dynamic polymorphism with std::variant”》:

single dynamic dispatch

  • Open to new alternatives
    - new derived types may be added by clients at any point of time (long after base class implementation is finished)
  • Closed to new operations
    - clients cannot add new operations to dynamic dispatch
  • Multi-level
    - many levels of inheritance possible
  • Object Oriented
    - whole framework is based on objects

The corresponding class diagram:
[Figure: dynamic dispatch]
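A small sketch of "open to new alternatives" (illustrative names `Shape`, `total_area`, `Rect`): code written once against the base class keeps working when a client adds a new derived type later, because the virtual call dispatches on the receiver's dynamic type.

```cpp
#include <cassert>
#include <memory>
#include <vector>

// "Library" side: the base class and a function written against it.
class Shape {
public:
  virtual ~Shape() = default;
  virtual double area() const = 0;
};

double total_area(const std::vector<std::unique_ptr<Shape>>& shapes) {
  double sum = 0;
  for (const auto& s : shapes) sum += s->area();  // dispatch on the receiver only
  return sum;
}

// "Client" side, written later: a new alternative, no library change needed.
class Rect : public Shape {
  double w, h;

public:
  Rect(double w, double h) : w(w), h(h) {}
  double area() const override { return w * h; }
};
```

Adding a new operation such as perimeter(), however, means editing Shape and every derived class: the "closed to new operations" bullet.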

static dispatch

In computing, static dispatch is a form of polymorphism fully resolved during compile time. It is a form of method dispatch, which describes how a language or environment will select which implementation of a method or function to use.

In C++, dynamic dispatch is implemented with virtual functions called through the vtable, while static dispatch can be implemented with CRTP. For example, the code below uses CRTP to implement single dispatch (static dispatch).

#include <algorithm>
#include <iostream>
#include <memory>
#include <utility>

// CRTP base: forwards to Derived at compile time, no vtable involved.
template <typename Derived>
class Exp {
public:
  int max_constant() const {
    return static_cast<const Derived*>(this)->max_constant_impl();
  }
};

class Constant : public Exp<Constant> {
  int v;

public:
  explicit Constant(int v) : v(v) {}
  int max_constant_impl() const { return v; }
};

// Children are stored as their concrete types, so they are never deleted
// through the non-virtual Exp<T> base.
template <typename T>
class Negate : public Exp<Negate<T>> {
  std::unique_ptr<T> e;

public:
  explicit Negate(std::unique_ptr<T> e) : e(std::move(e)) {}
  int max_constant_impl() const { return e->max_constant(); }
};

template <typename T, typename U>
class Add : public Exp<Add<T, U>> {
  std::unique_ptr<T> e1;
  std::unique_ptr<U> e2;

public:
  Add(std::unique_ptr<T> e1, std::unique_ptr<U> e2)
      : e1(std::move(e1)), e2(std::move(e2)) {}
  int max_constant_impl() const {
    return std::max(e1->max_constant(), e2->max_constant());
  }
};

template <typename T, typename U>
class Multiply : public Exp<Multiply<T, U>> {
  std::unique_ptr<T> e1;
  std::unique_ptr<U> e2;

public:
  Multiply(std::unique_ptr<T> e1, std::unique_ptr<U> e2)
      : e1(std::move(e1)), e2(std::move(e2)) {}
  int max_constant_impl() const {
    return std::max(e1->max_constant(), e2->max_constant());
  }
};

template <typename T>
int max_const(const Exp<T>& e) { return e.max_constant(); }

int main() {
  // I still haven't found a way to avoid spelling out the class template arguments.
  auto e = Multiply<Add<Negate<Constant>, Constant>, Multiply<Constant, Constant>>(
    std::make_unique<Add<Negate<Constant>, Constant>>(
      std::make_unique<Negate<Constant>>(std::make_unique<Constant>(10)),
      std::make_unique<Constant>(10)),
    std::make_unique<Multiply<Constant, Constant>>(std::make_unique<Constant>(10),
                                                   std::make_unique<Constant>(30)));
  std::cout << max_const(e) << std::endl;  // prints 30
}

The core of the CRTP approach is template inheritance plus static_cast. Writing the templates is painful; the payoff is that everything is resolved at compile time.
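One way to see "everything is resolved at compile time" is to make the CRTP call constexpr and check it with static_assert. This is a minimal sketch with illustrative names (`Shape`, `Square`, `area_impl`, `area_of`), not the Exp code above.

```cpp
#include <cassert>

// Minimal CRTP sketch: Shape<Derived>::area() is bound at compile time,
// so it can even be evaluated inside a constant expression.
template <typename Derived>
struct Shape {
  constexpr double area() const {
    return static_cast<const Derived*>(this)->area_impl();
  }
};

struct Square : Shape<Square> {
  double side;
  constexpr explicit Square(double s) : side(s) {}
  constexpr double area_impl() const { return side * side; }
};

// Works for any CRTP-derived type; no vtable, no virtual call.
template <typename T>
constexpr double area_of(const Shape<T>& s) { return s.area(); }

static_assert(area_of(Square{3.0}) == 9.0, "evaluated at compile time");
```

The flip side is that Shape<Square> and Shape<Circle> are unrelated types, so they cannot share one container without extra type erasure.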

Single dispatch, then, means running the method or function that belongs to one particular type: the type of the receiver.

double dispatch (visitor pattern)

Visitor Pattern gives a brief explanation of double dispatch:

Multiple dispatch is a concept that allows method dispatch to be based not only on the receiving object but also on the parameters of the method’s invocation.
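Why a virtual call alone is not enough: C++ chooses among overloads using the *static* type of the argument, so the argument's dynamic type is ignored. A small sketch with illustrative names (`Node`, `Number`, `describe`):

```cpp
#include <cassert>
#include <string>

struct Node { virtual ~Node() = default; };
struct Number : Node {};

std::string describe(const Node&) { return "Node"; }
std::string describe(const Number&) { return "Number"; }

// Receives a Number, but only through a Node reference, so overload
// resolution can only see Node.
std::string describe_via_base(const Node& n) { return describe(n); }
```

Double dispatch closes this gap by turning the second selection into a virtual call as well.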

To show what double dispatch is for, take an AST as an example: suppose we want to traverse the syntax tree and print information about each node.

#include <iostream>
#include <memory>
#include <string>
#include <vector>

class StatementAST {
public:
	virtual ~StatementAST() = default;
	virtual void print() const = 0;
};

class Expr : public StatementAST {
public:
	void print() const override { std::cout << "This is Expr" << '\n'; }
};

class NumberExpr : public Expr {
public:
	void print() const override { std::cout << "This is NumberExpr" << '\n'; }
};

class StringLiteral : public Expr {
public:
	void print() const override { std::cout << "This is StringLiteral" << '\n'; }
};

// ...
class CallExpr : public Expr {
	std::string name;                         // callee name
	std::vector<std::unique_ptr<Expr>> argu;  // call arguments
public:
	void print() const override {
		std::cout << name << "(";
		for (const auto & a : argu) {
			a->print();
			std::cout << ", ";
		}
		std::cout << ")";
	}
};

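That is simple enough: we add print() as a virtual method to every AST node. But there are many other tasks that need to treat each kind of AST node differently, such as semantic analysis or IR generation. Following the print() pattern, we can add a matching virtual method to every node, e.g. emitIr():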

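// Placeholder for whatever IR type the backend uses (hypothetical).
struct ir {};

class StatementAST {
public:
	virtual ~StatementAST() = default;
	virtual void print() const = 0;
	virtual ir emitIr() = 0;
};

class Expr : public StatementAST {
public:
	void print() const override { std::cout << "This is Expr" << '\n'; }
};

class NumberExpr : public Expr {
public:
	void print() const override { std::cout << "This is NumberExpr" << '\n'; }
	ir emitIr() override { /* ... */ return {}; }
};

class StringLiteral : public Expr {
public:
	void print() const override { std::cout << "This is StringLiteral" << '\n'; }
	ir emitIr() override { /* ... */ return {}; }
};

// ...
class CallExpr : public Expr {
	std::string name;                         // callee name
	std::vector<std::unique_ptr<Expr>> argu;  // call arguments
public:
	void print() const override {
		std::cout << name << "(";
		for (const auto & a : argu) {
			a->print();
			std::cout << ", ";
		}
		std::cout << ")";
	}
	ir emitIr() override {
		// (1) emit code for the arguments, (2) emit the call expression
		return {};
	}
};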

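But there are many, many such tasks. Putting the IR generation code directly into the AST nodes is not modular: the front end, middle end, and back end get mixed together. One workable alternative is: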

#include <iostream>
#include <memory>
#include <string>
#include <vector>

// Forward declarations so Visitor can name every node type.
class StatementAST;
class Expr;
class NumberExpr;
class StringLiteral;
class CallExpr;

class Visitor {
public:
	virtual ~Visitor() = default;
	virtual void visit(const StatementAST *S) = 0;
	virtual void visit(const Expr *E) = 0;
	virtual void visit(const NumberExpr *NE) = 0;
	virtual void visit(const StringLiteral *SL) = 0;
	// ...
	virtual void visit(const CallExpr *CE) = 0;
};

class StatementAST {
public:
	virtual ~StatementAST() = default;
	// First dispatch: the virtual call selects the node's dynamic type.
	virtual void accept(Visitor& visitor) = 0;
};

class Expr : public StatementAST {
public:
	// Not final: NumberExpr etc. must still be able to override accept.
	void accept(Visitor& visitor) override {
		// Second dispatch: the static type of `this` picks the visit overload,
		// and the virtual call picks the concrete visitor.
		visitor.visit(this);
	}
};

class NumberExpr : public Expr {
public:
	void accept(Visitor& visitor) override final {
		visitor.visit(this);
	}
};

class StringLiteral : public Expr {
public:
	void accept(Visitor& visitor) override final {
		visitor.visit(this);
	}
};

// ...
class CallExpr : public Expr {
public:
	std::string name;                         // callee name
	std::vector<std::unique_ptr<Expr>> argu;  // call arguments

	void accept(Visitor& visitor) override final {
		visitor.visit(this);
	}
};

// You can define many visitors, e.g. a print visitor or an IR generation visitor.
class PrintVisitor : public Visitor {
public:
	void visit(const StatementAST *S) override {
		std::cout << "This is StatementAST" << '\n';
	}
	void visit(const Expr *E) override {
		std::cout << "This is Expr" << '\n';
	}
	void visit(const NumberExpr *NE) override {
		std::cout << "This is NumberExpr" << '\n';
	}
	void visit(const StringLiteral *SL) override {
		std::cout << "This is StringLiteral" << '\n';
	}
	// ...
	void visit(const CallExpr *CE) override {
		std::cout << CE->name << "(";
		for (const auto & a : CE->argu) {
			a->accept(*this);
			std::cout << ", ";
		}
		std::cout << ")";
	}
};

class IRGenerateVisitor : public Visitor {
public:
	void visit(const StatementAST *S) override { /**/ }
	void visit(const Expr *E) override { /**/ }
	void visit(const NumberExpr *NE) override { /**/ }
	void visit(const StringLiteral *SL) override { /**/ }
	// ...
	void visit(const CallExpr *CE) override { /**/ }
};

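The corresponding diagram for double dispatch:
[Figure: double dispatch]
This visitor pattern is still not perfect: every time we add a new AST node class, we have to modify Visitor and every visitor derived from it. Here double dispatch comes from two virtual calls (accept on the node, then visit on the visitor), with overload resolution on the static type of `this` choosing which visit gets called. We could also overload the accept method to get different behavior; in that case double dispatch is implemented with vtable + overloading.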

Again, 《CppCon 2018: Mateusz Pusz “Effective replacement of dynamic polymorphism with std::variant”》 gives a summary of double dispatch:
double dynamic dispatch

  • Closed to new alternatives
    - adding a new derived type means changing the visitor interface and every visitor
  • Open to new operations
    - clients can add new operations by writing new visitors
  • Multi-level
    - many levels of inheritance possible
  • Object Oriented
    - whole framework is based on objects
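A compact, self-contained sketch of "open to new operations": the node hierarchy (`Num`, `Neg`, illustrative and simpler than the AST above) is fixed, and a new operation `Eval` is added later as just another visitor, without touching any node class.

```cpp
#include <cassert>
#include <memory>
#include <utility>

struct Num;
struct Neg;

struct Visitor {
  virtual ~Visitor() = default;
  virtual void visit(const Num&) = 0;
  virtual void visit(const Neg&) = 0;
};

struct Node {
  virtual ~Node() = default;
  virtual void accept(Visitor& v) const = 0;  // 1st dispatch: node type
};

struct Num : Node {
  int value;
  explicit Num(int v) : value(v) {}
  void accept(Visitor& v) const override { v.visit(*this); }  // 2nd dispatch: visitor type
};

struct Neg : Node {
  std::unique_ptr<Node> inner;
  explicit Neg(std::unique_ptr<Node> n) : inner(std::move(n)) {}
  void accept(Visitor& v) const override { v.visit(*this); }
};

// A new operation added later, without modifying Num or Neg.
struct Eval : Visitor {
  int result = 0;
  void visit(const Num& n) override { result = n.value; }
  void visit(const Neg& n) override {
    n.inner->accept(*this);  // evaluate the child first
    result = -result;
  }
};
```

Adding a third node type, on the other hand, forces a new pure virtual visit overload in Visitor, which breaks every existing visitor until it is updated.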

std::visit

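So how do we express the Exp example from the beginning with std::variant? There is no direct way. The root cause is that C++ has no recursive variant: std::variant cannot hold an incomplete type, and a class cannot contain a variant of itself by value. The code below does not compile; see the related question 《C++ Mutually Recursive Variant Type (Again)》. The book 《C++ 17 in detail》 lists a comparison of boost::variant and std::variant:
[Figure: boost::variant vs. std::variant comparison]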

class Constant {
  int v;

public:
  Constant(int v) : v(v) {}
  int value() const { return v; }
};

class Negate {
  int v;

public:
  Negate(int v) : v(v) {}
  int value() const { return v; }
};

// Does not compile: Add is still incomplete here and Multiply is not declared yet.
class Add {
  std::variant<Constant, Negate, Add, Multiply> e1;
  std::variant<Constant, Negate, Add, Multiply> e2;

public:
  Add(const std::variant<Constant, Negate, Add, Multiply>& e1, const std::variant<Constant, Negate, Add, Multiply>& e2) : e1(e1), e2(e2) {}
  std::variant<Constant, Negate, Add, Multiply> getOperand1() const { return e1; }
  std::variant<Constant, Negate, Add, Multiply> getOperand2() const { return e2; }
};

class Multiply {
  std::variant<Constant, Negate, Add, Multiply> e1;
  std::variant<Constant, Negate, Add, Multiply> e2;

public:
  Multiply(std::variant<Constant, Negate, Add, Multiply> e1, std::variant<Constant, Negate, Add, Multiply> e2) : e1(e1), e2(e2) {}
  std::variant<Constant, Negate, Add, Multiply> getOperand1() const { return e1; }
  std::variant<Constant, Negate, Add, Multiply> getOperand2() const { return e2; }
};

struct ExpMaxConstVisitor {
  int operator()(Constant c) const { return c.value(); }
  int operator()(Negate c) const { return c.value(); }
  int operator()(Add a) const {
    return std::max(std::visit(*this, a.getOperand1()),
                    std::visit(*this, a.getOperand2()));
  }
  int operator()(Multiply m) const {
    return std::max(std::visit(*this, m.getOperand1()),
                    std::visit(*this, m.getOperand2()));
  }
};

Using Function Objects as Visitors

Let me simplify the code above until it compiles, and use it to introduce the visitor:

#include <algorithm>
#include <iostream>
#include <variant>

class Constant {
  int v;

public:
  Constant(int v) : v(v) {}
  int value() const { return v; }
};

class Negate {
  int v;

public:
  Negate(int v) : v(v) {}
  int value() const { return v; }
};

class Add {
  int e1;
  int e2;

public:
  Add(int e1, int e2) : e1(e1), e2(e2) {}
  int getOperand1() const { return e1; }
  int getOperand2() const { return e2; }
};

class Multiply {
  int e1;
  int e2;

public:
  Multiply(int e1, int e2) : e1(e1), e2(e2) {}
  int getOperand1() const { return e1; }
  int getOperand2() const { return e2; }
};

struct ExpMaxConstVisitor {
  int operator()(Constant c) const { return c.value(); }
  int operator()(Negate c) const { return c.value(); }
  int operator()(Add a) const {
    // If std::variant could be recursive, this would visit the operands recursively:
    // return std::max(std::visit(*this, a.getOperand1()),
    //                 std::visit(*this, a.getOperand2()));
    return std::max(a.getOperand1(), a.getOperand2());
  }
  int operator()(Multiply m) const {
    // If std::variant could be recursive, this would visit the operands recursively:
    // return std::max(std::visit(*this, m.getOperand1()),
    //                 std::visit(*this, m.getOperand2()));
    return std::max(m.getOperand1(), m.getOperand2());
  }
};

int main() {
  std::variant<Constant, Negate, Add, Multiply> exp(Add(10, 30));
  std::cout << std::visit(ExpMaxConstVisitor(), exp) << std::endl;
  return 0;
}

ExpMaxConstVisitor above is the most common kind of std::variant visitor: a function object with one operator() overload per alternative.

The call of visit() is a compile-time error if not all possible types are supported by an operator() or if the call is ambiguous.

std::visit also guarantees type safety: if the visitor does not cover every case, the compiler rejects the call with an error such as:

`std::visit` requires the visitor to be exhaustive
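std::visit also accepts several variants at once and picks the operator() that matches the *combination* of active alternatives, which is double dispatch without any accept() methods. A small sketch with illustrative types `Rock` and `Paper`; again every combination must be covered, or the call does not compile.

```cpp
#include <cassert>
#include <variant>

struct Rock {};
struct Paper {};
using Hand = std::variant<Rock, Paper>;

// One overload per (left, right) pair of alternatives.
struct Beats {
  bool operator()(Rock, Rock) const { return false; }
  bool operator()(Rock, Paper) const { return false; }
  bool operator()(Paper, Rock) const { return true; }
  bool operator()(Paper, Paper) const { return false; }
};

bool beats(const Hand& a, const Hand& b) { return std::visit(Beats{}, a, b); }
```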

Using Generic Lambdas as Visitors

Generic lambdas were introduced in C++14. A traditional lambda looks like this:

auto add = [](int a, int b) -> int { return a + b; };

A generic lambda looks like this:

auto add = [](auto a, auto b) { return a + b; };

Compiling it with the -std=c++11 option produces the following errors:


#1 with x86-64 clang 9.0.0
<source>:3:19: error: 'auto' not allowed in lambda parameter

    auto add = [](auto a, auto b) { return a + b ;};

                  ^~~~

<source>:3:27: error: 'auto' not allowed in lambda parameter

    auto add = [](auto a, auto b) { return a + b ;};

                          ^~~~

<source>:4:15: error: invalid operands to binary expression ('std::ostream' (aka 'basic_ostream<char>') and 'void')

    std::cout << add(10, 20) << std::endl;

Using C++ Insights, we can see that the generic lambda add is equivalent to the following code:

class __lambda_3_16 {
public: 
  template<class type_parameter_0_0, class type_parameter_0_1>
  inline /*constexpr */ auto operator()(type_parameter_0_0 a, type_parameter_0_1 b) const {
    return a + b;
  }
private: 
  template<class type_parameter_0_0, class type_parameter_0_1>
  static inline auto __invoke(type_parameter_0_0 a, type_parameter_0_1 b) {
    return a + b;
  }  
};
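Since a generic lambda is just a class with a templated operator(), a hand-written class of that shape works as a std::visit visitor in the same way, instantiated once per alternative. A minimal sketch; `Value` and `Doubler` are illustrative names, and the return type is fixed to `Value` because std::visit requires every instantiation to return the same type.

```cpp
#include <cassert>
#include <string>
#include <variant>

using Value = std::variant<int, std::string>;

// Roughly what the compiler generates for
//   [](const auto& v) -> Value { return v + v; }
// a class whose operator() is a member template.
struct Doubler {
  template <typename T>
  Value operator()(const T& v) const { return v + v; }
};
```

Usage: `std::visit(Doubler{}, Value{21})` holds the int 42, and for a string alternative it concatenates the string with itself.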

A visitor written as a generic lambda:

#include <type_traits>  // std::is_same_v, std::decay_t

int main() {
  std::variant<Constant, Negate, Add, Multiply> exp(Add(10, 30));
  std::cout << std::visit([](auto& val) {
    // decltype(val) is e.g. Add&, so strip the reference before comparing types
    using T = std::decay_t<decltype(val)>;
    if constexpr (std::is_same_v<T, Constant>) {
      return val.value();
    } else if constexpr (std::is_same_v<T, Negate>) {
      return val.value();
    } else if constexpr (std::is_same_v<T, Add>) {
      return std::max(val.getOperand1(), val.getOperand2());
    } else if constexpr (std::is_same_v<T, Multiply>) {
      return std::max(val.getOperand1(), val.getOperand2());
    }
  }, exp) << std::endl;
  return 0;
}

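The code above uses compile-time if, i.e. if constexpr, introduced in C++17. Expanding it with C++ Insights again shows that all of this is done with templates: the generic lambda's operator() is instantiated once per alternative, with the template argument deduced from that alternative's type, and if constexpr discards the branches that don't apply.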
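An if constexpr chain loses the exhaustiveness check that an overloaded function object gets for free: a missing branch just falls through. A common C++17 idiom restores it with a dependent-false static_assert in the final else. A minimal sketch; `always_false_v`, `Value`, and `kind` are illustrative names.

```cpp
#include <cassert>
#include <string>
#include <type_traits>
#include <variant>

// Always false, but depends on T, so the static_assert below fires only if
// that branch is actually instantiated for some alternative.
template <typename> inline constexpr bool always_false_v = false;

using Value = std::variant<int, double, std::string>;

std::string kind(const Value& v) {
  return std::visit([](const auto& x) -> std::string {
    using T = std::decay_t<decltype(x)>;
    if constexpr (std::is_same_v<T, int>) return "int";
    else if constexpr (std::is_same_v<T, double>) return "double";
    else if constexpr (std::is_same_v<T, std::string>) return "string";
    else static_assert(always_false_v<T>, "non-exhaustive visitor");
  }, v);
}
```

Adding, say, `char` to Value without a matching branch now fails to compile instead of silently misbehaving.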

Using Overloaded Lambdas as Visitors

By using an overloader for function objects and lambdas, you can also define a set of lambdas where the best match is used as a visitor.

template<typename... Ts> struct overload : Ts... {
  using Ts::operator()...;
};
template<typename... Ts>
overload(Ts...) -> overload<Ts...>;

int main() {
  std::variant<Constant, Negate, Add, Multiply> exp(Add(10, 30));
  std::cout << std::visit(overload{
    [](auto a) { return a.value(); },
    [](Add a) {return std::max(a.getOperand1(), a.getOperand2());},
    [](Multiply a) {return std::max(a.getOperand1(), a.getOperand2());}
  }, exp) << std::endl;
  return 0;
}
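"The best match is used" follows ordinary overload resolution. In this sketch (the alias `Value` and function `label` are illustrative; the overload helper is repeated so the block stands alone), the non-template int lambda wins over the generic catch-all for an int because both are exact matches and non-templates are preferred. For a double, the generic lambda wins: it is an exact match, while the int lambda would need a conversion.

```cpp
#include <cassert>
#include <string>
#include <variant>

// Same overload helper as above (the deduction guide is needed in C++17).
template <typename... Ts> struct overload : Ts... { using Ts::operator()...; };
template <typename... Ts> overload(Ts...) -> overload<Ts...>;

using Value = std::variant<int, double, std::string>;

std::string label(const Value& v) {
  return std::visit(overload{
      [](int) -> std::string { return "int"; },
      [](const auto&) -> std::string { return "other"; }},
    v);
}
```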

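std::variant still has plenty of room for improvement, e.g. recursive variants. The visitor part is reasonably simple to write and can replace some uses of class inheritance. Compared with pattern matching in functional programming it is still somewhat more cumbersome, but something is better than nothing.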

There is much more to say about recursive variants. As the Variant design review puts it:

Recursive variants are variants that (conceptually) have itself as one of the alternatives. There are good reasons to add support for a recursive variant; for instance to build AST nodes. There are also good reasons not to do so, and to instead use unique_ptr<variant<…>> as an alternative. A recursive variant can be implemented as an extension to variant, see for instance what is done for boost::variant. The proposals does not contain support for recursive variants; they also do not preclude a proposal for them
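A minimal sketch of the unique_ptr workaround mentioned in the quote, rebuilding the ML exp example: the variant is wrapped in a struct `Exp`, and recursion goes through `std::unique_ptr<Exp>`, so every alternative is complete when the variant is instantiated. The builders `make_constant`, `make_negate`, `make_add`, `make_multiply` are hypothetical helpers; the overload helper is the one shown earlier, repeated so the block stands alone.

```cpp
#include <algorithm>
#include <cassert>
#include <memory>
#include <utility>
#include <variant>

template <typename... Ts> struct overload : Ts... { using Ts::operator()...; };
template <typename... Ts> overload(Ts...) -> overload<Ts...>;

struct Exp;  // incomplete here; only reached through std::unique_ptr
struct Constant { int value; };
struct Negate { std::unique_ptr<Exp> e; };
struct Add { std::unique_ptr<Exp> l, r; };
struct Multiply { std::unique_ptr<Exp> l, r; };

// The variant lives inside a struct so the alternatives can point back to it.
struct Exp { std::variant<Constant, Negate, Add, Multiply> node; };

// Hypothetical builders mirroring the ML constructors.
std::unique_ptr<Exp> make_constant(int v) { return std::make_unique<Exp>(Exp{Constant{v}}); }
std::unique_ptr<Exp> make_negate(std::unique_ptr<Exp> e) {
  return std::make_unique<Exp>(Exp{Negate{std::move(e)}});
}
std::unique_ptr<Exp> make_add(std::unique_ptr<Exp> l, std::unique_ptr<Exp> r) {
  return std::make_unique<Exp>(Exp{Add{std::move(l), std::move(r)}});
}
std::unique_ptr<Exp> make_multiply(std::unique_ptr<Exp> l, std::unique_ptr<Exp> r) {
  return std::make_unique<Exp>(Exp{Multiply{std::move(l), std::move(r)}});
}

// Same shape as the ML `case e of ...`: one lambda per constructor.
int max_constant(const Exp& e) {
  return std::visit(overload{
      [](const Constant& c) { return c.value; },
      [](const Negate& n) { return max_constant(*n.e); },
      [](const Add& a) { return std::max(max_constant(*a.l), max_constant(*a.r)); },
      [](const Multiply& m) { return std::max(max_constant(*m.l), max_constant(*m.r)); }},
    e.node);
}
```

Compared with the ML version, the cost is the extra heap allocation per node and the explicit `*` dereferences; the visitor itself reads much like the case expression.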
