mxnet symbol 解析

mxnet symbol類定義:https://github.com/apache/incubator-mxnet/blob/master/python/mxnet/symbol/symbol.py

對於一個symbol,可分爲non-grouped和grouped。且symbol具有輸出,和輸出屬性。比如,對於Variable而言,其輸入和輸出就是它自己。對於c = a+b,c的內部有個_plus0 symbol,對於_plus0這個symbol,它的輸入是a,b,輸出是_plus0_output。

class Symbol(SymbolBase):
    """Symbol is symbolic graph of the mxnet."""
    # disable dictionary storage, also do not have parent type.
    # pylint: disable=no-member

其中,Symbol還不是最基礎的類,Symbol類繼承了SymbolBase這個類。
而SymbolBase這個類實際是在

https://github.com/apache/incubator-mxnet/blob/master/python/mxnet/symbol/_internal.py

中引用的,通過以下方式引用:

from .._ctypes.symbol import SymbolBase, _set_symbol_class, _set_np_symbol_class

而SymbolBase的定義是在:https://github.com/apache/incubator-mxnet/blob/master/python/mxnet/_ctypes/symbol.py
這裏暫時先不管SymbolBase,這應該是是python調用c++接口創建的一個類。

回到Symbol中來,對於mxnet符號式編程而言,定義的任何網絡,或者變量,都是symbol類型,所以,瞭解這個類就顯得很重要。

Symbol類中有幾類函數:
1、普通函數
2、__xx__ 函數
3、@property 修飾的函數
4、函數名爲xx,實際調用op.xx的函數

1、普通函數
attr
根據key返回symbol對應的屬性字符串,只對non-grouped symbols起作用。

    def attr(self, key):
        """Returns the attribute string for corresponding input key from the symbol.

list_attr
得到symbol的所有屬性

    def list_attr(self, recursive=False):
        """Gets all attributes from the symbol.

attr_dict
遞歸的得到symbol和孩子的屬性

    def attr_dict(self):
        """Recursively gets all attributes from the symbol and its children.
        Example
        -------
        >>> a = mx.sym.Variable('a', attr={'a1':'a2'})
        >>> b = mx.sym.Variable('b', attr={'b1':'b2'})
        >>> c = a+b
        >>> c.attr_dict()
        {'a': {'a1': 'a2'}, 'b': {'b1': 'b2'}}

_set_attr
通過key-value方式,對attr進行設置

    def _set_attr(self, **kwargs):
        """Sets an attribute of the symbol.
        For example. A._set_attr(foo="bar") adds the mapping ``"{foo: bar}"``
        to the symbol's attribute dictionary.

get_internals
獲取symbol的所有內部節點symbol,是一個group類型(包括輸入,輸出節點symbol)。如果我們想階段一個network,應該獲取它某內部節點的輸出,這樣才能作爲新增加的symbol的輸入。

    def get_internals(self):
        """Gets a new grouped symbol `sgroup`. The output of `sgroup` is a list of
        outputs of all of the internal nodes.

get_children
獲取當前symbol輸出節點的inputs

    def get_children(self):
        """Gets a new grouped symbol whose output contains
        inputs to output nodes of the original symbol.

list_arguments
列出當前symbol的所有參數(可以配合call對symbol進行改造)

    def list_arguments(self):
        """Lists all the arguments in the symbol.

list_outputs
列出當前smybol的所有輸出,如果當前symbol是grouped類型,回遍歷輸出每一個symbol的輸出

    def list_outputs(self):
        """Lists all the outputs in the symbol.

list_auxiliary_states
列出symbol中的輔助狀態參數,比如BN

    def list_auxiliary_states(self):
        """Lists all the auxiliary states in the symbol.
        Example
        -------
        >>> a = mx.sym.var('a')
        >>> b = mx.sym.var('b')
        >>> c = a + b
        >>> c.list_auxiliary_states()
        []
        Example of auxiliary states in `BatchNorm`.

list_inputs
列出當前symbol的所有輸入參數,和輔助狀態,等價於 list_arguments和 list_auxiliary_states

    def list_inputs(self):
        """Lists all arguments and auxiliary states of this Symbol.

2、__xx__函數

__repr__
對於gruop symbol,它是沒有name屬性的,print或者回車,結果就是其內部symbol節點的name
在這裏插入圖片描述
__iter__(self):
普通的symbol長度都只有1,只有Grouped 的symbol,長度才大於1:return (self[i] for i in range(len(self)))
算數及邏輯運算:
+,-,*, /,%,abs,**, 取負(-x),==,!=,>,>=,<,<=, # 使用時,要注意Broadcasting 是否支持

    def __abs__(self):
        """x.__abs__() <=> abs(x) <=> x.abs() <=> mx.symbol.abs(x, y)"""
        return self.abs()

    def __add__(self, other):
        """x.__add__(y) <=> x+y
    其他   

__copy__和__deep_copy__
通過deep_copy,創建一個深拷貝,返回輸入對象的一個拷貝,包括它當前所有參數的當前狀態,比如weight,bias等
在這裏插入圖片描述
__call__
表示symbol的實例是一個可調用對象。可以返回一個新的symbol,這個symbol繼承了之前symbol的權重啥的,但是和之前的symbol是不同的對象,可以輸入參數對symbol進行組合。

    def __call__(self, *args, **kwargs):
        """Composes symbol using inputs.
        Returns
        -------
            The resulting symbol.
        """
        s = self.__copy__()  #  這裏對symbol實例做了一次深拷貝,返回的新的symbol
        s._compose(*args, **kwargs) # 實際調用的_compose函數
        return s
    # 對當前的symbol進行編譯,返回一個新的symbol,可以指定新symbol的name,其他輸入參數必須是symbol類型
    # 當前symbol的輸入參數,可以通過 .list_arguments()獲取
    def _compose(self, *args, **kwargs):
        """Composes symbol using inputs.
        x._compose(y, z) <=> x(y,z)
        This function mutates the current symbol.
        Example
        -------
        Returns
        -------
            The resulting symbol.
        """
        name = kwargs.pop('name', None)

        if name:
            name = c_str(name)
        if len(args) != 0 and len(kwargs) != 0:
            raise TypeError('compose only accept input Symbols \
                either as positional or keyword arguments, not both')

這裏,我改變了b,將其輸入參數的x的值變爲了tt。
在這裏插入圖片描述

__getitem__
如果symbol的長度只有1,那麼返回的就是它的輸出symbol,如果symbol長度>1,可以通過切片訪問其輸出symbol,返回的也是一個Group symbol。symbol可以分爲non-grouped和grouped。
獲取內部節點symbol還可以輸入str,但輸入的str必須屬於list_outputs(),

    def __getitem__(self, index):
        """x.__getitem__(i) <=> x[i]
        Returns a sliced view of the input symbol.
        Parameters
        ----------
        index : int or str
            Indexing key
        """
        output_count = len(self)
        if isinstance(index, py_slice):
			# 輸入切片
        if isinstance(index, string_types):
            # 輸入字符串
            # Returning this list of names is expensive. Some symbols may have hundreds of outputs
            output_names = self.list_outputs()
            idx = None
            for i, name in enumerate(output_names):
                if name == index:
                    if idx is not None:
                        raise ValueError('There are multiple outputs with name \"%s\"' % index)
                    idx = i
            if idx is None:
                raise ValueError('Cannot find output that matches name \"%s\"' % index)
            index = idx

symbol.py 除了Symbol這個類之外,還有遊離在外的函數:

1def var(name, attr=None, shape=None, lr_mult=None, wd_mult=None, dtype=None,
        init=None, stype=None, **kwargs):
    """Creates a symbolic variable with specified name.
# for back compatibility
Variable = var  #  調用 mx.sym.var和mx.sym.Variable 等價

2、
def Group(symbols, create_fn=Symbol):
    """Creates a symbol that contains a collection of other symbols, grouped together.
    A classic symbol (`mx.sym.Symbol`) will be returned if all the symbols in the list
    are of that type; a numpy symbol (`mx.sym.np._Symbol`) will be returned if all the
    symbols in the list are of that type. A type error will be raised if a list of mixed
    classic and numpy symbols are provided.
    Example
    -------
    >>> a = mx.sym.Variable('a')
    >>> b = mx.sym.Variable('b')
    >>> mx.sym.Group([a,b])
    <Symbol Grouped>
    Parameters
    ----------
    symbols : list
        List of symbols to be grouped.

3def load(fname):
    """Loads symbol from a JSON file.
     You also get the benefit being able to directly load/save from cloud storage(S3, HDFS).

    Returns
    -------
    sym : Symbol
        The loaded symbol.
    See Also
    --------
    Symbol.save : Used to save symbol into file.
# 輸入文件可以是hdfs文件
4、
數學相關函數,輸入可爲scalar或者是symbol
def pow(base, exp):
    """Returns element-wise result of base element raised to powers from exp element.
	base 和 exp可以是數字或者symbol
# def power(base, exp):  #  實際調用pow
def maximum(left, right):
def minimum(left, right):
def hypot(left, right):  #  返回直角三角形的斜邊
def eye(N, M=0, k=0, dtype=None, **kwargs):
    """Returns a new symbol of 2-D shpae, filled with ones on the diagonal and zeros elsewhere.  #  返回2D shape的symbol,對角線爲1,其餘位置爲0
def zeros(shape, dtype=None, **kwargs):
    """Returns a new symbol of given shape and type, filled with zeros.  # 返回一個shape的全0 symbol
def ones(shape, dtype=None, **kwargs):
    """Returns a new symbol of given shape and type, filled with ones.
def full(shape, val, dtype=None, **kwargs):
    """Returns a new array of given shape and type, filled with the given value `val`.
def arange(start, stop=None, step=1.0, repeat=1, infer_range=False, name=None, dtype=None):
    """Returns evenly spaced values within a given interval.
def arange(start, stop=None, step=1.0, repeat=1, infer_range=False, name=None, dtype=None):
    """Returns evenly spaced values within a given interval.
def linspace(start, stop, num, endpoint=True, name=None, dtype=None):
    """Return evenly spaced numbers within a specified interval.
def histogram(a, bins=10, range=None, **kwargs):
    """Compute the histogram of the input data.
def split_v2(ary, indices_or_sections, axis=0, squeeze_axis=False):
    """Split an array into multiple sub-arrays.

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章