在Python中,不仅对象支持多态,类也支持多态。
多态,使得继承体系中的多个类能够以各自独有的方式来实现某个方法。这些类,都满足相同的接口或继承自相同的抽象类,但却有着各自不同的功能。
案例1:实现MapReduce
流程,定义表示输入数据的公共基类。
class InputData(object):
def read(self):
raise NotImplementedError
现在编写InputData
类的具体子类,实现磁盘文件中的数据读取。
class PathInputData(InputData):
def __init__(self,path) -> None:
super().__init__()
self.path = path
def read(self):
return open(self.path).read()
同时,我们可能需要多个像PathInputData
这样的类充当InputData
的子类,以实现多个标准接口的read
方法,比如可以实现网络读取并解压数据。
此外,我们还需要为MapReduce
工作线程定义一套类似的抽象接口,以便于处理输入的数据。
class Worker(object):
def __init__(self,input_data) -> None:
self.input_data = input_data
self.result = None
def map(self):
raise NotImplementedError
def reduce(self,other):
raise NotImplementedError
下面定义具体的子类,以实现我们想要的MapReduce
功能。本例实现简单的换行符计数器。
class LineCountWorker(Worker):
def map(self):
data = self.input_data.read()
self.result += data.count('\n')
def reduce(self.other):
self.result += other.result
在实现了MapReduce
的各个组件后,需要将各个组件串联起来,以实现整个流程,通常的方法是编写辅助函数将这些类对象联系起来。
##生成器函数,生成数据
def generate_inputs(data_dir):
for name in os.listdir(data_dir):
yield PathInputData(os.path.join(data_dir,name))
### 创建多个worker对象
def create_workers(input_list):
workers = []
for input_data in input_list:
workers.append(LineCountWorker(input_data))
return workers
### 将每个对象分发到各个线程执行
def execute(workers):
threads = [Thread(target=w.map for w in workers)]
for thread in threads:thread.start()
for thread in threads:thread.join()
first,rest = workers[0],workers[1:]
for worker in rest:
first.reduce(worker)
return first.result
### 执行mapreduce函数
def mapreduce(data_dir):
inputs = generate_inputs(data_dir)
workers = create_workers(inputs)
return execute(workers)
整个调用的流程图如下所示:
上述写法存在的主要问题是MapReduce
函数不够通用。如果编写其他的InputData
或Worker
子类,那就得重写generate_inputs
、create_workers
和mapreduce
函数。
其实,在C++
或者Java
当中,可以通过构造函数的重载来解决,但是在Python中只允许名为__init__
的构造器方法,所以不能提供多个不同输入参数的__init__
方法。
@classmethod
解决这个问题最好的方法,是使用@classmethod
形式的多态,即类方法的多态机制。
首先修改InputData
类,为它添加通用的generate_inputs
类方法,该方法会根据通用的接口来创建新的InputData
实例
class GenericInputData(object):
def read(self):
raise NotImplementedError
@classmethod
def generate_inputs(cls,config):
raise NotImplementerError
## 修改子类
class PathInputData(GenericInputData):
#...
def read(self):
return open(self.path).read()
@classmethod
def generate_inputs(cls,config):
data_dir = config['data_dir']
for name in os.listdir(data_dir):
yield cls(os.path.join(data_dir,name))
按照同样的方法修改Worker
类,并添加create_workers
方法。
class GenericWorker(object):
#...
def map(self):
raise NotImplementedError
def reduce(self,other):
raise NotImplementedError
@classmethod
def create_inputs(cls,input_class,config):
workers = []
for input_data in input_class.generate_inputs(config):
worker.append(cls(input_data))
return workers
上述代码的重点在于input_class.generate_inputs
的调用,是类级别的多态调用,同时GenericWorker
对象通过cls形式构造。
具体的GenericWorker
子类,只需修改继承的父类即可。
class LineCountWorker(GenericWorker):
#...
最后,重写mapreduce
函数:
def mapreduce(worker_class,input_class,config):
wokers = worker_class.create_workers(input_class,config)
return execute(workers)
if __name__=='__main__':
config = {'data_dir':'./data/'}
result = mapreduce(LineCountWorker,PathInputData,config)
最后,我们可以编写GenericInputData
和GenericWorker
的其他子类,而无需修改函数代码,仅需要改动的是mapreduce
的输入参数即可。