在单元测试时,我们希望测试环境尽可能单纯、可控。因此我们不希望依赖于用户输入,不希望连接无法独占的数据库或者第三方微服务等。这时候,我们需要通 mock 来模拟出这些外部接口。mock 可能是单元测试中最核心的技术。
无论是 unittest 还是 pytest,都是直接或者间接使用了 unittest 中的 mock 模块。所以,当你遇到 mock 相关的问题,请参阅 mock。我们接下来关于 mock 的介绍,也将以 Unittest 中的 mock 为主。不过,两个框架的 mock,在基本概念上都是相通的。
unittest.mock 模块提供了最核心的 Mock 类。当我们用 Mock 类的一个实例来替换被测试系统的某些部分之后,我们就可以对它的使用方式做出断言。这包括检查哪些方法(属性)被调用以及调用它们的参数如何。我们还可以设定返回值或者令其抛出异常,以改变执行路径。
除此之外,mock 模块还提供了 patch 方法和 MagicMock 子类。MagicMock 区别于 Mock 的地方在于,它自动实现了对 Python 中类对象中的魔法函数的 mock(这是它的名字的来源!),比如__iter__等。patch 则是一个带上下文管理的工具,它能自动复原我们对系统的更改。
最基础的 mock 的概念可以通过下面的代码得到演示:
# 示例 7 - 9
from unittest.mock import MagicMock
thing = ProductionClass()
thing.method = MagicMock(return_value=3)
thing.method(3, 4, 5, key='value')
thing.method.assert_called_with(3, 4, 5, key='value')
这段代码假设我们有一个被测试类 ProductionClass,当我们调用它的 method 方法时,它有一些不便在单元测试下运行的情况(比如需要连接数据库),因此,我们希望能跳过对它的调用,而直接返回我们指定的一些值。
在这里我们能拿到 ProductionClass 实例对像的引用,所以,我们可以直接修改它的 method 属性,使之指向一个 MagicMock 对象。MagicMock 对象有一些重要的属性和方法。
这里出现的 return_value 是第一个重要的属性。它的意思时,当被替换的对象(这里是 method)被调用时,返回值应该是 3。与之类似的另一个属性是 side_effect。它同样也在 mock 被调用时,返回设置的值。但 return_value 与 side_effect 有重要区别:两者的返回值都可以设置为数组(或者其它可迭代对象),但通过 side_effect 设置返回值时,每次调用 mock,它都返回 side_effect 中的下一个迭代值;而 return_value 则会将设置值全部返回。另外,如果两者同时设置,side_effect 优先返回。请看下面的示例:
# 示例 7 - 10
import unittest.mock
side_effect = [1, 2, unittest.mock.DEFAULT, 4, 5]
m = unittest.mock.Mock(return_value="foo", side_effect=side_effect)
for _ in side_effect:
print(m())
输出结果将是:
1
2
foo
4
5
我们给 side_effect 设置了 5 个值,在 5 次重复测试时,它分别依次返回下一个迭代值。注意这里我们通过 unittest.mock.DEFAULT,来让其中一次迭代,返回了 return_value 的设置值。当然,本质上,这仍然是对 side_effect 的一个迭代结果。
这里还出现了它的一个重要方法,assert_called_with,即检查被替换的方法是否被以期望的参数调用了。除此之外,还可以断言被调用的次数,等等。
这个例子非常简单。但它也演示了使用 Mock 的精髓,即生成 Mock 实例,设置行为(比如返回值),替换生产系统中的对象(方法、属性等),最后,检验结果。
很多时候,我们会通过 patch 的方式来使用 mock。又有两种主要的方式:
假如我们有一个文件系统相关的操作,为了正常运行,必须在测试环境下构建目录,增加某些文件。为了简单起见,我们希望通过 mock 来模拟这个环境。
# 示例 7 - 11
import os
# FUNCTION UNDER TEST
class Foo:
def get_files(self, dir_: str):
return os.list_dir(dir_)
# TESTING CODE
from unittest.mock import patch
from unittest import TestCase
class FooTest(TestCase):
@patch('__main__.Foo.get_files')
def test_get_files(self, mocked):
mocked.return_value = ["readme.md"]
foo = Foo()
self.assertListEqual(foo.get_files(), ["readme.md"])
test = FooTest()
test.test_get_files()
我们对关键代码进行一些解释。首先,通过装饰器语法进行 mock 时,我们的测试函数会多一个参数(这里是 mocked,但名字可以由我们任意指定)。这里使用多个 patch 装饰器也是可以的,每增加一个装饰器,测试函数就会多增加一个参数。
其次,我们要对 Foo.get_files 进行 mock,但我们在 Foo.get_files 之前,加上了一个__main__的前缀。这是由于类 Foo 的定义处在顶层模块中。在 Python 中,任何一个符号(类、方法或者变量)都处在某个模块(module)之下。如果这段代码存为磁盘文件 foo.py,那么模块名就是 foo;我们在别的模块中引入 Foo.get_files 时,应该使用 foo.Foo.get_files。但在这里,由于我们是同模块引用,因此它的前缀是__main__。
!!! info
使用 mock 的关键,是要找到引用被 mock 对象的正确方式。在 Python 中,一切都是对象。这些对象通过具有层次结构的命名空间来进行寻址。以 patch 方法为例,它处在 mock 模块之中,而 mock 模块又是包 unittest 的下级模块,因此,我们就使用 unittest.mock.patch 来引用它,这也与导入路径是一致的。
但是,像这里的脚本,如果一个对象不是系统内置对象,又不存在于任何包中,那么它的名字空间就是__main__,正如这里的示例__main__.Foo 一样。关于寻址,还存在其它的情况,我们会在后面介绍 builtin 对象以及错误的引用那两节中进行介绍。
通过装饰器语法传入进来的 mock 对象,它的行为是未经设置的。因此,我们要在这里先设置它的返回值,然后再调用业务逻辑函数 foo.get_files – 由于它已经被 mock 了,所以会返回我们设置的返回值。
当我们通过装饰器来使用 mock 时,实际上它仍然是有上下文的,在函数退出之后,mock 对系统的更改就复原了。但是,有时候我们更希望使用代码块级别的 patch,一方面可以更精准地限制 mock 的使用范围,另一方面,它的语法会更简练,因为我们可以一行代码完成 mock 行为的设置。
# 示例 7 - 12
import os
# FUNCTION UNDER TEST
class Foo:
def get_files(self, dir_: str):
return os.list_dir(dir_)
# TESTING CODE
from unittest.mock import patch
from unittest import TestCase
class FooTest(TestCase):
def test_get_files(self):
with patch('__main__.Foo.get_files', return_value=["readme.md"]):
foo = Foo()
self.assertListEqual(foo.get_files(), ["readme.md"])
test = FooTest()
test.test_get_files()
这里仅用一行代码就完成了替换和设置。
在实践中,使用 mock 可能并不像看起来那么容易。有一些情景对初学者而言会比较难以理解。一旦熟悉之后,你会发现,你对 Python 的底层机制,有了更深入的理解。下面,我们就介绍这些场景下如何使用 mock。
前面的例子中,我们给 patch 传入的 target 是一个字符串,显然,在 patch 作用域内,所有的新生成的对象都会被 patch。如果在 patch 之前,对象已经生成了,我们则需要使用patch.object
来完成 patch。这样做的另一个好处是,我们可以有选择性地 patch 部分对象。
# 示例 7 - 19
def bar():
logger = logging.getLogger(__name__)
logger.info("please check if I was called")
root_logger = logging.getLogger()
root_logger.info("this is not intercepted")
# TEST_FOO.PY
from sample.core.foo import bar
logger = logging.getLogger('sample.core.foo')
with mock.patch.object(logger, 'info') as m:
bar()
m.assert_called_once_with("please check if I was called")
在 bar 方法里,两个 logger(root_logger 和’sample.core.foo’对应的 logger) 都被调用,但我们只拦截了后一个 logger 的info
方法,结果验证它被调用,且仅被调用一次。
这里要提及 pytest 中 mocker.patch 与 unitest.mock.patch 的一个细微差别。后者进行 patch 时,可以返回 mock 对象,我们可以通过它进行更多的检查(见上面示例代码中的第 14,16 行);但 mocker.patch 的返回值是 None。
从 3.8 起,unittest.mock 一般就不再区分同步和异步对象,比如:
# FUNCTION UNDER TEST
class Foo:
async def bar():
pass
# TESTING CODE
class FooTest(TestCase):
async def test_bar(self):
foo = Foo()
with patch("__main__.Foo.bar", return_value="hello from async mock!"):
res = await foo.bar()
print(res)
test = FooTest()
await test.test_bar()
原函数 bar 的返回值为空。但输出结果是 “hello from async mock”,说明该函数被 mock 了。
被 mock 的方法 bar 是一个异步函数,如果我们只需要 mock 它的返回值的话,仍然是用同样的方法,直接给 return_value 赋值就好。如果我们要将其替换成另一个函数,也只需要将该函数声明成为异步函数即可。
但是,如果我们要 mock 的是一个异步的生成器,则方法会有所不同:
# FUNCTION UNDER TEST
from unittest import mock
class Foo:
async def bar():
for i in range(5):
yield f"called {i}th"
# TESTING CODE
class FooTest(TestCase):
async def test_bar(self):
foo = Foo()
with mock.patch(
"__main__.Foo.bar"
) as mocked:
mocked.return_value.__aiter__.return_value = [0, 2, 4, 6, 8]
print([i async for i in foo.bar()])
test = FooTest()
await test.test_bar()
理解这段代码的关键是,我们要 mock 的对象是 bar 方法,它的返回值(即 mocked.return_value)是一个 coroutine。我们需要对该 coroutine 的__aiter__方法设置返回值,这样才能得到正确的结果。此外,由于__aiter__本身就是迭代器的意思,所以,即使我们设置它的 return_value,而不是 side_effect 为一个列表,它也会按次返回迭代结果,而不是整个 list。这是与我们前面介绍 return_value 和 side_effect 的区别时所讲的内容相区别的。
同样需要特别注意的是 async with 方法。你需要 mock 住它的__aexit__,将其替换成你要实现的方法。
如果我们有一个程序,读取用户从控制台输入的参数,根据该参数进行计算。显然,我们需要 Mock 用户输入,否则单元测试没法自动化。
在 Python 中,接受用户控制台输入的函数是 input。要 mock 这个方法,按照前面学习中得到的经验,我们需要知道它属于哪个名字空间。在 Python 中,像 input, open, eval 等一类的函数大约有 80 个左右,被称为 builtin(内置函数)。
在 mock 它们时,我们使用 builtins 名字空间来进行引用:
with patch('builtins.input', return_value="input is mocked"):
user_input = input("please say something:")
print(user_input)
执行上述代码时,用户并不会有机会真正输入数据,input 方法被 mock,并且会返回"input is mocked"。
有时候我们会在代码中,通过 datetime.datetime.now() 来获取系统的当前时间。显然,在不同的时间测试,我们会得到不同的取值,导致测试结果无法固定。因此,这也是需要被 mock 的对象。
要实现对这个方法的 mock,可能比我们一开始以为的要难一些。我们的推荐是,使用 freezegun 这个库,而避开自己去 mock 它。
# 请使用 PYTEST 来运行,或者自行改写为 UNITTEST
from freezegun import freeze_time
import datetime
import unittest
# FREEZE TIME FOR A PYTEST STYLE TEST:
@freeze_time("2012-01-14")
def test():
assert datetime.datetime.now() == datetime.datetime(2012, 1, 14)
def test_case2():
assert datetime.datetime.now() != datetime.datetime(2012, 1, 14)
with freeze_time("2012-01-14"):
assert datetime.datetime.now() == datetime.datetime(2012, 1, 14)
assert datetime.datetime.now() != datetime.datetime(2012, 1, 14)
注意 Python 的时间库很多,如果您使用的是其它的库来获取当前时间,则 freeze_gun 很可能会不起作用。不过,对第三方的时间库,一般很容易实现 mock。
假设我们有一个爬虫在抓取百度的热搜词。它的功能主要由 crawl_baidu 来实现。我们另外有一个函数在调用它,以保存 crawl_baidu 的返回结果。我们想知道,如果 crawl_baidu 中抛出异常,那么调用函数是否能够正确处理这种情况。
这里的关键是,我们要让 crawl_baidu 能抛出异常。当然,我们不能靠拔网线来实现这一点。
import httpx
from httpx import get, ConnectError
from unittest.mock import patch
from unittest import TestCase
def crawl_baidu():
return httpx.get("https://www.baidu.com")
class ConnectivityTest(TestCase):
def test_connectivity(self):
with patch('httpx.get', side_effect=["ok", ConnectError("disconnected")]):
print(crawl_baidu())
with self.assertRaises(ConnectError):
crawl_baidu()
case = ConnectivityTest()
case.test_connectivity()
crawl_baidu 依靠 httpx.get 来爬取数据。我们通过 mock httpx.get 方法,让它有时返回正常结果,有时返回异常。这是通过 side_effect 来实现的。
注意第 14 行,我们使用的是 self.assertRaises,而不是 try-except 来捕捉异常。两者都能够实现检查异常是否抛出的功能。但通过 self.assertRaises,我们强调了这里应该抛出一个异常,它是我们测试逻辑的一部分。而 try-except 则应该用来处理真正的异常。
再强调一遍,“使用 mock 的关键,是要找到引用被 mock 对象的正确方式。”而正确引用的关键,则是这样一句“咒语”
!!! Warning
Mock an item where it is used, not where it came from
在对象被使用的地方进行 mock, 而不是在它出生的地方。
我们通过一个简单的例子来说明这一点:
from os import system
from unittest import mock
import pytest
def echo():
system('echo "Hello"')
with mock.patch('os.system', side_effect=[Exception("patched")]) as mocked:
with pytest.raises(Exception) as e:
echo()
我们在 echo 方法中,调用了系统的 echo 命令。在测试中,我们试图 mock 住 os.system 方法,让它一被调用,就返回一个异常。然后我们通过 pytest 来检查,如果异常抛出,则证明 mock 成功,否则,mock 失败。
但是如果我们运行这个示例,只会得到一个友好的问候,No errors, No warnings! 为什么?
因为当我们在 echo() 函数中调用 system 函数时,此时的 system 存在于__main__名字空间,而不是 os 的名字空间。os 名字空间是 system 出生的地方,而__main__名字空间才是使用它的地方。因此,我们应该 patch 的对象是’main.system’,而不是’os.system’。
现在,让我们将os.system
改为__main__.system
,重新运行,你会发现,魔法又生效了!
在配套代码中,还有一个名为 where_to_patch 的示例,我们也来看一下。
# FOO.PY
def get_name():
return "Alice"
# BAR.PY
from .foo import get_name
class Bar:
def name(self):
return get_name()
# TEST.PY
from unittest.mock import patch
from where_to_patch.bar import Bar
tmp = Bar()
with patch('where_to_patch.foo.get_name', return_value="Bob"):
name = tmp.name()
assert name == "Bob"
测试代码会抛出 AssertionError: assert "Alice" == "Bob"的错误。如果我们把
where_to_patch.foo改为
where_to_patch.bar`,则测试通过。这个稍微扩展了一下的例子,进一步清晰地演示了如何正确引用被 mock 对象。