tf_export 的定義在 tensorflow/python/util/tf_export.py 中。
它的主要作用是在函數或類的 __dict__ 中新增屬性(例如 _tf_api_names、_tf_api_names_v1),用來記錄該符號對外導出的 API 名稱。
# tf_export is the api_export class with api_name pre-bound to
# TENSORFLOW_API_NAME, so calling tf_export(...) constructs an api_export
# instance that exports symbols to the TensorFlow API.
tf_export = functools.partial(api_export, api_name=TENSORFLOW_API_NAME)
api_export 是一個類,主要作用是將 symbols 導出(export)到 TensorFlow API。
# Key identifying the TensorFlow API in the API_ATTRS / API_ATTRS_V1 tables.
# (Plain ASCII quotes: typographic quotes are a SyntaxError in Python.)
TENSORFLOW_API_NAME = "tensorflow"
類 api_export 內部實現了 __call__ 方法,因此它的實例可以直接當作裝飾器使用:
def __call__(self, func):
    """Attach this exporter's API names to `func` and return `func`.

    Before tagging `func`, every symbol in ``self._overrides`` has both its
    API-name attribute and its V1 API-name attribute deleted.

    Args:
      func: The decorated symbol (a function or a class).

    Returns:
      `func` itself, the same object that was passed in.

    Raises:
      SymbolAlreadyExposedError: Raised when a symbol already has API names
        and kwarg `allow_multiple_exports` not set.
    """
    # The attribute names come from the per-API tables. For
    # TENSORFLOW_API_NAME they are `_tf_api_names` (API_ATTRS) and
    # `_tf_api_names_v1` (API_ATTRS_V1). The estimator and keras APIs use
    # `_estimator_api_names[_v1]` and `_keras_api_names[_v1]` respectively.
    attr_name = API_ATTRS[self._api_name].names
    attr_name_v1 = API_ATTRS_V1[self._api_name].names

    # Undecorate each overridden symbol and strip its exported names,
    # the current-API attribute first and then the V1 one.
    for overridden in self._overrides:
        _decorators, bare_symbol = tf_decorator.unwrap(overridden)
        for key in (attr_name, attr_name_v1):
            delattr(bare_symbol, key)

    # The names are recorded on the undecorated target, but the caller gets
    # the original `func` back. set_attr checks whether the attribute is
    # already in the target's __dict__ and, if so, raises
    # SymbolAlreadyExposedError.
    _decorators, target = tf_decorator.unwrap(func)
    self.set_attr(target, attr_name, self._names)
    self.set_attr(target, attr_name_v1, self._names_v1)
    return func
例子:
# Demo: apply a tf_export decorator to a plain class and inspect its __dict__.
# NOTE(review): the snippet uses the module name `tf_export` without importing
# it; presumably `from tensorflow.python.util import tf_export` ran first.
class TestClassA(object):
    def __init__(self):
        self.a = 12
# Build the decorator; 'TestClassA' and 'dadsa' are the API names to export.
export_decorator_a = tf_export.tf_export('TestClassA', 'dadsa')
# Before decoration the class __dict__ has no _tf_api_names* entries.
print("1:", TestClassA.__dict__)
# Apply the decorator by calling it directly on the class.
a = export_decorator_a(TestClassA)
# After decoration, _tf_api_names and _tf_api_names_v1 are in the __dict__.
print("2:",TestClassA.__dict__)
# The decorated class is the very same object as the original class,
# so the two ids printed below are identical.
print(id(a))
print(id(TestClassA))
輸出:
1: {'__module__': '__main__', '__init__': <function TestClassA.__init__ at 0x7f5c40211f28>, '__dict__': <attribute '__dict__' of 'TestClassA' objects>, '__weakref__': <attribute '__weakref__' of 'TestClassA' objects>, '__doc__': None}
2: {'__module__': '__main__', '__init__': <function TestClassA.__init__ at 0x7f5c40211f28>, '__dict__': <attribute '__dict__' of 'TestClassA' objects>, '__weakref__': <attribute '__weakref__' of 'TestClassA' objects>, '__doc__': None, '_tf_api_names': ('TestClassA', 'dadsa'), '_tf_api_names_v1': ('TestClassA', 'dadsa')}
93887463023576
93887463023576