30.5.3. 通过@classmethod多态来构造同一体系中的各类对象

在Python中,不仅对象支持多态,类也支持多态。

这里说的对象支持多态,可理解为在超类对象上面调用实例方法,实际触发的是子类对象的同名实例方法; 类支持多态,可理解为在超类上面调用类方法,实际触发的是子类的同名类方法。

多态机制使同一体系中的多个类可以按照各自独有的方式来实现同一个方法,这意味着这些类都可以满足同一套接口,或者都可以当作某个抽象类来使用,同时,它们又能在这个前提下,实现各自的功能。

#!/usr/bin/env python
# -*- coding:utf8 -*-
# auther: 18793
# Date:2021/11/4 21:04
# filename: classmethod_sample01.py
class InputData:
    def read(self):
        raise NotImplementedError


class PathInputData(InputData):
    def __init__(self, path):
        super().__init__()
        self.path = path

    def read(self):
        with open(self.path) as f:
            return f.read()


class Worker:
    def __init__(self, input_data):
        self.input_data = input_data
        self.result = None

    def map(self):
        raise NotImplementedError

    def reduce(self, other):
        raise NotImplementedError


class LineCountWorker(Worker):
    def map(self):
        data = self.input_data.read()
        self.result = data.count('\n')

    def reduce(self, other):
        self.result += other.result


def generate_inputs(data_dir):
    for name in os.listdir(data_dir):
        yield PathInputData(os.path.join(data_dir, name))


def create_workers(input_list):
    workers = []
    for input_data in input_list:
        workers.append(LineCountWorker(input_data))
    return workers


from threading import Thread


def execute(workers):
    threads = [Thread(target=w.map) for w in workers]
    for thread in threads: thread.start()
    for thread in threads: thread.join()

    first, *rest = workers
    for worker in rest:
        first.reduce(worker)
    return first.result


def mapreduce(data_dir):
    inputs = generate_inputs(data_dir)
    workers = create_workers(inputs)
    return execute(workers)


import os
import random


def write_test_files(tmpdir):
    os.makedirs(tmpdir)
    for i in range(100):
        with open(os.path.join(tmpdir, str(i)), 'w') as f:
            f.write('\n' * random.randint(0, 100))


tmpdir = 'test_inputs'
write_test_files(tmpdir)

result = mapreduce(tmpdir)
print(f'There are {result} lines')

然后这样做有个大问题,就是mapreduce函数根本不通用。假如要使用其他的InputData或Worker子类,那就必须修改generate_inputs、create_workers与mapreduce代码。

这个问题的根本原因在于,构造对象的办法不够通用。Python中最好能够通过类方法多态(class method polymorphism)来解决。这种多态与InputData.read所体现的实例方法多态(instance method polymorphism)很像,只不过它针对的是类,而不是这些类的对象。

我们现在运用方法多态来实现MapReduce流程所用到的这些类。首先改写InputData类,把generate_inputs方法放到该类里面并声明成通用的@classmethod,这样它所欲子类都可以通过同一个接口来新建具体的InputData实例。

#!/usr/bin/env python
# -*- coding:utf8 -*-
# auther: 18793
# Date:2021/11/4 21:04
# filename: classmethod_sample01.py
class GenericInputData:
    def read(self):
        raise NotImplementedError

    @classmethod
    def generate_inputs(cls, config):
        raise NotImplementedError


class PathInputData(GenericInputData):
    def __init__(self, path):
        super().__init__()
        self.path = path

    def read(self):
        with open(self.path) as f:
            return f.read()

    @classmethod
    def generate_inputs(cls, config):
        data_dir = config['data_dir']
        for name in os.listdir(data_dir):
            yield cls(os.path.join(data_dir, name))


class GenericWorker:
    def __init__(self, input_data):
        self.input_data = input_data
        self.result = None

    def map(self):
        raise NotImplementedError

    def reduce(self, other):
        raise NotImplementedError

    @classmethod
    def create_workers(cls, input_class, config):
        workers = []
        for input_date in input_class.generate_inputs(config):
            workers.append(cls(input_date))
        return workers


class LineCountWorker(GenericWorker):
    def map(self):
        data = self.input_data.read()
        self.result = data.count('\n')

    def reduce(self, other):
        self.result += other.result


def generate_inputs(data_dir):
    for name in os.listdir(data_dir):
        yield PathInputData(os.path.join(data_dir, name))


from threading import Thread


def execute(workers):
    threads = [Thread(target=w.map) for w in workers]
    for thread in threads: thread.start()
    for thread in threads: thread.join()

    first, *rest = workers
    for worker in rest:
        first.reduce(worker)
    return first.result


def mapreduce(worker_class, input_class, config):
    workers = worker_class.create_workers(input_class, config)
    return execute(workers)


import os
import random


def write_test_files(tmpdir):
    os.makedirs(tmpdir)
    for i in range(100):
        with open(os.path.join(tmpdir, str(i)), 'w') as f:
            f.write('\n' * random.randint(0, 100))


tmpdir = 'test_inputs'
write_test_files(tmpdir)
config = {"data_dir": tmpdir}
result = mapreduce(LineCountWorker, PathInputData, config)
print(f"There are {result} lines")

这套方案让我们能够随意编写其他的GenericInputData与GenericWorker子类,而不用再花时间去调整它们之间的拼接代码(glue code)。

要点:

Python只允许每个类有一个构造方法,也就是__init__方法。

如果想在超类中用通用的代码构造子类实例,那么可以考虑定义@classmethod方法,并在里面用cls(…)的形式构造具体的子类对象。通过类方法多态机制,我们能够以通用的形式构造并拼接具体的子类对象。