Python协程

Author Avatar
patrickcty 4月 27, 2018

generator

如果把一个可迭代对象比作一辆装有电池的四驱车,那么四驱车的马达会在电池电量耗尽之时才会停止;但是 generator 则像是一个发条小车,它每次都只会在拧动发条之后才会前进相应的距离。这个发条就是“yield” 关键字。

作为可迭代对象

generator 也是一个可迭代对象,它可以通过类似于列表生成式的方式来定义:

g = (x for x in range(10))

需要注意的是这里的括号是小括号,如果是中括号的话生成的就是一个普通的列表了。

在这里获得 g 的所有结果有两种方法,一种是多次使用十次 next() 来依次生成每个值:

>>> next(g)
0
>>> next(g)
1
>>> next(g)
2
>>> next(g)
3
>>> next(g)
4
>>> next(g)
5
>>> next(g)
6
>>> next(g)
7
>>> next(g)
8
>>> next(g)
9
>>> next(g)
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
StopIteration

如果调用的太多了就会产生错误。

第二种就是直接把 g 作为可迭代对象在 for in 中循环:

for num in g:
    print(num)

作为函数

我们还可以把一个函数作为 generator 来使用,这里就需要使用我们的 yield 发条来进行改造了

def count(mmax):
    for i in range(mmax):
        yield i

作为函数也同样有两种用法:使用 next 调用和直接放入 for in 循环中:

# 方法一
c = count(10)
next(c)
next(c)
...

# 方法二
for num in count(10):
    print(num)

这里最令人费解的就是 yield 关键字了,在 generator 运行的过程中,当执行到 yield 的时候,程序就会暂停下来,返回 yield 后面对应的值。只有调用 next() 之后才会继续从刚才停下的地方开始执行直到遇到下一个 yield 又再次停止。

Coroutine

Coroutine 叫做协程,又叫做微线程。协程的作用,是在执行函数A时,可以随时中断,去执行函数B,然后中断继续执行函数A(可以自由切换)。但这一过程并不是函数调用(没有调用语句),这一整个过程看似像多线程,然而协程只有一个线程执行【1】

Coroutine 分为三个阶段:

  • generator 变形 yield/send()
  • @asyncio.coroutine 以及 yield from
  • async/await

generator 变形 yield/send()

这一阶段和 generator 不同的是对 yield 的用法进行了扩展,除了 yield 除了可以返回数据之外还可以接收数据。

# getData 是接收到的数据,returnData 是发送出去的数据
getData = yield returnData

我们知道 generator 是通过 next() 或者直接迭代来得到下一个结果,而这里则是通过 coroutine.send() 来得到下一个结果,或者更准确的来说是让子程序继续往后执行,但要注意的是在初始化也就是第一次调用的时候要使用 coroutine.send(None),这里其实是启动 generator。

def funA():
    while True:
        r = yield 'fine, thank you!'
        print(r)
        
def funB(c):
    # 启动 generator
    c.send(None)
    i = 0
    while i < 3:
        r = c.send('how are you?')
        print(r)
        i += 1
    # 关闭 funA
    c.close()
    
c = funA()
funB(c)

运行结果为:
how are you?
fine, thank you!
how are you?
fine, thank you!
how are you?
fine, thank you!

在这里 funB 通过 send 给 funA 发送信息,之后 funB 通过 yield 回复信息给 funA,但这两者在发送的同时(其实是之后)也可以读取到相应的信息。

运行的流程是这样的:

  • funB 启动了 generator
  • funA 运行到 yield 处向 funA 发送 ‘fine, thank you!’,然后切换到 funB
  • funB 运行到 send() 向 funA 发送了 ‘how are you?’,切换回 funA
  • funA 从 yield 接收到 funB 的消息,继续运行到 yield 处向 funA 发送 ‘fine, thank you!’,然后切换到 funB
  • funB 接收到 funA 的消息,继续运行到 send() 向 funA 发送了 ‘how are you?’,切换回 funA
  • 重复以上过程直到 funB 使用 c.close() 关闭 funA,运行结束

需要注意的是其实 send() 和 yield 的作用是非常类似的,只不过 send() 是主动方,yield 是被动方。

@asyncio.coroutine 以及 yield from

在 send() 和 yield 部分是由明显的主动被动关系,但是在这里则更多的是一种并列关系,这里要处理的问题是在一方进行 io 操作时另一方能充分利用这段空闲的时间。

import asyncio

@asyncio.coroutine
def funA():
    print('hello, world!')
    yield from asyncio.sleep(1)
    print('hello, again!')
    
@asyncio.coroutine
def funB():
    print('are you ok?')
    yield from asyncio.sleep(1)
    print('hello, thank you! thank you very much!')
    
loop = asyncio.get_event_loop()
tasks = [funA(), funB()]
loop.run_until_complete(asyncio.wait(tasks))
loop.close()

运行结果:
are you ok?
hello, world!
hello, thank you! thank you very much!
hello, again!

这里 funA funB 是作为两个互相独立的 task 来运行的,yield 后面紧接着的是一个协程的函数过程,这就是说在这里 CPU 空闲出来了,于是切换给另一个 task 来执行。注意的是这里的 funA() funB() 其实都是 generator,因而并不会马上执行。

这里的运行过程是这样的:

  • 我们首先获得了当前的 event loop,然后将 funA funB 两个 task 放在其中,通过 asyncio.wait() 来执行
  • 首先被调用的是 funB,它运行到 yield 处阻塞,此时切换到 funA 执行
  • 同理,funA 也是运行到 yield 处阻塞,此时切换到 funB 执行
  • funB 继续执行直到完成释放资源,此时再次切换到 funA 执行直到结束
  • 所有任务完成,run_until_complete() 返回,接下来可以关掉 event loop 了

async/await

而 async/await 与 generator 变形 yield/send() 的区别就没有之前两者区别那么大了,async/await 对之前的形式进行了优化,让 coroutine 的代码更简洁易读。

如果要使用 async/await 就只用:

  1. 把 @asyncio.coroutine 替换为 async(在 def 前面)
  2. 把 yield from 替换为 await
import asyncio

async def funA():
    print('hello, world!')
    await asyncio.sleep(1)
    print('hello, again!')
    
async def funB():
    print('are you ok?')
    await asyncio.sleep(1)
    print('hello, thank you! thank you very much!')
    
loop = asyncio.get_event_loop()
tasks = [funA(), funB()]
loop.run_until_complete(asyncio.wait(tasks))
loop.close()

这样一看确实简洁了不少,也不会和 generator 中的 yield 搞混淆了。

总结

使用这些高级方法,我们 Python 程序的效率可以进一步提高,特别是充分地利用了 io 的时间。

更多示例代码

生成器生成杨辉三角

def yanghuiTri(mmax):
    tri = [[1], [1, 1]]
    yield tri[0]
    yield tri[1]
    r = 2
    while r < mmax:
        row = [1]
        
        for i in range(1, r):
            row.append(tri[r - 1][i - 1] + tri[r - 1][i])
        row.append(1)
        tri.append(row)
        yield tri[r]
        r += 1


for l in yanghuiTri(10):
    print(l)

@asyncio.coroutine 获取网页头部

来源

import asyncio

# 如果要使用 async/await 就只用:
# 1. 把 @asyncio.coroutine 替换为 async(在 def 前面)
# 2. 把 yield from 替换为 await


@asyncio.coroutine
def wget(host):
    print('wget %s...' % host)
    # asyncio.open_connection 返回一个 (reader, writer) 的二元组
    # 来建立连接
    connet = asyncio.open_connection(host, 80)
    # yield from 一般是 io 请求,其实更准确的来说也是一个 coroutine,如果不是的话就会报错
    # 此时程序直接进入另一个 task
    # 等执行完返回值后再继续切换回来
    reader, writer = yield from connet
    header = 'GET / HTTP/1.0\r\nHost: %s\r\n\r\n' %host
    # 发送 HTTP 请求
    # 因为传输的是字节,因而要 encode
    writer.write(header.encode('utf-8'))
    # 这里也是 io 请求,清空缓冲区
    yield from writer.drain()
    while True:
        line = yield from reader.readline()
        # 回答报文头部和正文之间间隔了一个 \r\n(字节形式的)
        if line == b'\r\n':
            return
        # Python rstrip() 删除 string 字符串末尾的指定字符(默认为空格)
        # decode 是因为数据是以字节流来传输的
        print('%s header > %s' % (host, line.decode('utf-8'.rstrip())))
    # 断开 TCP
    writer.close()

loop = asyncio.get_event_loop()
tasks = [wget(host) for host in ['www.sina.com.cn', 'www.sohu.com', 'www.163.com']]
# asynico.wait 来让任务阻塞时能唤醒另一个任务
loop.run_until_complete(asyncio.wait(tasks))
loop.close()

asyn/await 获取网页头部

import asyncio

async def wget(host):
    print('wget %s...' % host)
    # asyncio.open_connection 返回一个 (reader, writer) 的二元组
    # 来建立连接
    connet = asyncio.open_connection(host, 80)
    # await 一般是 io 请求
    # 此时程序直接进入另一个 task
    # 等执行完返回值后再继续切换回来
    reader, writer = await connet
    header = 'GET / HTTP/1.0\r\nHost: %s\r\n\r\n' %host
    # 发送 HTTP 请求
    # 因为传输的是字节,因而要 encode
    writer.write(header.encode('utf-8'))
    # 这里也是 io 请求,清空缓冲区
    await writer.drain()
    while True:
        line = await reader.readline()
        # 回答报文头部和正文之间间隔了一个 \r\n(字节形式的)
        if line == b'\r\n':
            return
        # Python rstrip() 删除 string 字符串末尾的指定字符(默认为空格)
        # decode 是因为数据是以字节流来传输的
        print('%s header > %s' % (host, line.decode('utf-8'.rstrip())))
    # 断开 TCP
    writer.close()

loop = asyncio.get_event_loop()
tasks = [wget(host) for host in ['www.sina.com.cn', 'www.sohu.com', 'www.163.com']]
# asynico.wait 来让任务阻塞时能唤醒另一个任务
loop.run_until_complete(asyncio.wait(tasks))
loop.close()