单元测试
单元测试是软件开发中确保代码质量的重要实践。Python 标准库提供了 unittest 模块,此外还有流行的第三方库 pytest。
0x01. unittest 基本使用
编写测试
# calculator.py
def add(a, b):
    """Return the sum of ``a`` and ``b``."""
    total = a + b
    return total
def subtract(a, b):
    """Return ``a`` minus ``b``."""
    difference = a - b
    return difference
def multiply(a, b):
    """Return the product of ``a`` and ``b``."""
    product = a * b
    return product
def divide(a, b):
    """Return ``a / b`` (true division).

    Raises:
        ValueError: if ``b`` is zero.
    """
    if b != 0:
        return a / b
    raise ValueError("Cannot divide by zero")
# test_calculator.py
import unittest
from calculator import add, subtract, multiply, divide
class TestCalculator(unittest.TestCase):
    """Unit tests for the calculator's add, subtract, multiply and divide."""

    def test_add(self):
        for (a, b), expected in [((2, 3), 5), ((-1, 1), 0), ((0, 0), 0)]:
            self.assertEqual(add(a, b), expected)

    def test_subtract(self):
        for (a, b), expected in [((5, 3), 2), ((3, 5), -2)]:
            self.assertEqual(subtract(a, b), expected)

    def test_multiply(self):
        for (a, b), expected in [((2, 3), 6), ((-2, 3), -6), ((0, 100), 0)]:
            self.assertEqual(multiply(a, b), expected)

    def test_divide(self):
        self.assertEqual(divide(6, 2), 3)
        self.assertEqual(divide(7, 2), 3.5)
        # Division by zero must be rejected with ValueError.
        self.assertRaises(ValueError, divide, 1, 0)
if __name__ == '__main__':
    # Run the TestCase classes defined in this module when executed as a script.
    unittest.main(verbosity=1)
运行测试
# 运行测试文件
python -m unittest test_calculator.py
# 运行特定测试类
python -m unittest test_calculator.TestCalculator
# 运行特定测试方法
python -m unittest test_calculator.TestCalculator.test_add
# 发现并运行所有测试
python -m unittest discover
# 详细输出
python -m unittest -v test_calculator.py
0x02. 断言方法
import unittest
class TestAssertions(unittest.TestCase):
    """Tour of the most commonly used ``unittest.TestCase`` assertion methods."""

    def test_equality(self):
        # Equality / inequality
        self.assertEqual(1 + 1, 2)
        self.assertNotEqual(1 + 1, 3)

    def test_boolean(self):
        # Truth values
        self.assertTrue(True)
        self.assertFalse(False)

    def test_identity(self):
        # Identity: `is` compares objects, not values
        original = [1, 2, 3]
        alias = original
        lookalike = [1, 2, 3]
        self.assertIs(original, alias)          # same object
        self.assertIsNot(original, lookalike)   # equal value, different object
        self.assertIsNone(None)
        self.assertIsNotNone(1)

    def test_membership(self):
        # Containment
        self.assertIn(1, [1, 2, 3])
        self.assertNotIn(4, [1, 2, 3])

    def test_comparison(self):
        # Ordering
        self.assertGreater(5, 3)
        self.assertGreaterEqual(5, 5)
        self.assertLess(3, 5)
        self.assertLessEqual(5, 5)

    def test_type(self):
        # isinstance checks
        self.assertIsInstance(1, int)
        self.assertNotIsInstance(1, float)

    def test_regex(self):
        # Regular-expression search
        self.assertRegex('hello world', r'hello')
        self.assertNotRegex('hello', r'world')

    def test_almost_equal(self):
        # Floating-point comparison rounded to a number of decimal places
        self.assertAlmostEqual(0.1 + 0.2, 0.3, places=7)
        self.assertNotAlmostEqual(0.1 + 0.2, 0.4)

    def test_sequence(self):
        # Type-specific container comparisons
        self.assertSequenceEqual([1, 2, 3], [1, 2, 3])
        self.assertListEqual([1, 2, 3], [1, 2, 3])
        self.assertTupleEqual((1, 2), (1, 2))
        self.assertSetEqual({1, 2, 3}, {3, 2, 1})
        self.assertDictEqual({'a': 1}, {'a': 1})

    def test_exception(self):
        # Expected exceptions: callable form, regex form, and inspecting the exception object
        self.assertRaises(ValueError, int, 'not a number')
        with self.assertRaisesRegex(ValueError, 'invalid literal'):
            int('not a number')
        with self.assertRaises(ValueError) as ctx:
            int('not a number')
        self.assertIn('invalid literal', str(ctx.exception))
0x03. 测试夹具
import unittest
import tempfile
import os
class TestWithFixtures(unittest.TestCase):
    """Prints the order in which unittest runs class-level and per-test fixtures."""

    @classmethod
    def setUpClass(cls):
        """Run once, before any test method of this class."""
        cls.shared_resource = 'setup once'
        print('setUpClass')

    @classmethod
    def tearDownClass(cls):
        """Run once, after every test method of this class has finished."""
        print('tearDownClass')

    def setUp(self):
        """Run before each test method: give it a fresh three-element list."""
        self.test_data = list(range(1, 4))
        print('setUp')

    def tearDown(self):
        """Run after each test method: drop the per-test data."""
        self.test_data = None
        print('tearDown')

    def test_example1(self):
        print('test_example1')
        self.assertEqual(len(self.test_data), 3)

    def test_example2(self):
        # Mutating the list is safe: setUp rebuilds it for every test.
        print('test_example2')
        self.test_data.append(4)
        self.assertEqual(len(self.test_data), 4)
class TestFileOperations(unittest.TestCase):
    """Reads back a temporary file that ``setUp`` prepares for each test."""

    def setUp(self):
        # delete=False so the file survives close() and can be reopened by name.
        self.temp_file = tempfile.NamedTemporaryFile(
            mode='w', encoding='utf-8', delete=False
        )
        # Register removal right away: cleanups run even if the rest of setUp
        # raises, whereas tearDown is skipped when setUp fails.
        self.addCleanup(os.unlink, self.temp_file.name)
        with self.temp_file:  # closes the file even if write() fails
            self.temp_file.write('test content')

    def test_read_file(self):
        with open(self.temp_file.name, 'r', encoding='utf-8') as f:
            content = f.read()
        self.assertEqual(content, 'test content')
0x04. 跳过测试
import unittest
import sys
class TestSkipping(unittest.TestCase):
    """Demonstrates the ways unittest can skip a test or mark it as an expected failure."""

    @unittest.skip('无条件跳过')
    def test_skip(self):
        # Never executed: the decorator skips it unconditionally.
        self.fail('should not reach here')

    @unittest.skipIf(sys.platform == 'win32', 'Windows 上跳过')
    def test_skip_if(self):
        """Skipped when sys.platform == 'win32'."""

    @unittest.skipUnless(sys.platform == 'linux', '仅在 Linux 上运行')
    def test_skip_unless(self):
        """Runs only when sys.platform == 'linux'; skipped elsewhere."""

    @unittest.expectedFailure
    def test_expected_failure(self):
        # The failing assertion is intentional; it is reported as an expected failure.
        self.assertEqual(1, 2)

    def test_skip_dynamically(self):
        should_skip = True  # stand-in for a condition only known at run time
        if should_skip:
            self.skipTest('动态跳过')
0x05. pytest
安装和基本使用
pip install pytest
# test_example.py
def add(a, b):
    """Return ``a + b``; the function under test in this pytest example."""
    result = a + b
    return result
def test_add():
    """pytest uses plain ``assert`` statements; no TestCase class is needed."""
    for (a, b), expected in [((2, 3), 5), ((-1, 1), 0), ((0, 0), 0)]:
        assert add(a, b) == expected
def test_add_negative():
    """Adding two negative numbers yields their negative sum."""
    expected = -5
    assert add(-2, -3) == expected
# 运行测试
pytest test_example.py
pytest test_example.py::test_add
pytest -v test_example.py # 详细输出
pytest -s test_example.py # 显示打印
pytest -x test_example.py # 第一个失败时停止
pytest --lf test_example.py # 只运行上次失败的测试
pytest --ff test_example.py # 先运行上次失败的测试
pytest 夹具
import pytest
@pytest.fixture
def sample_data():
    """Provide a small user record to tests that request ``sample_data``."""
    record = {'name': 'Alice', 'age': 25}
    return record
@pytest.fixture
def temp_file(tmp_path):
    """Create ``test.txt`` inside pytest's per-test ``tmp_path`` directory and return its path."""
    path = tmp_path / 'test.txt'
    path.write_text('test content')
    return path
def test_sample_data(sample_data):
assert sample_data['name'] == 'Alice'
assert sample_data['age'] == 25
def test_temp_file(temp_file):
assert temp_file.read_text() == 'test content'
# 夹具作用域
@pytest.fixture(scope='session')
def database_connection():
    """Provide one connection shared by the whole test session.

    Code before ``yield`` is setup; code after it runs as teardown.
    NOTE(review): ``create_connection`` is not defined in this snippet; it is a
    placeholder for the project's own connection factory.
    """
    connection = create_connection()
    yield connection
    connection.close()
@pytest.fixture(scope='class')
def setup_class():
    """Placeholder fixture; ``scope='class'`` means one instance per test class."""
@pytest.fixture(scope='module')
def setup_module():
    """Placeholder fixture; ``scope='module'`` means one instance per test module."""
@pytest.fixture(scope='function')
def setup_function():
    """Placeholder fixture; ``scope='function'`` (the default) means one instance per test."""
pytest 参数化
import pytest
@pytest.mark.parametrize(('a', 'b', 'expected'), [
    (2, 3, 5),
    (-1, 1, 0),
    (0, 0, 0),
    (100, 200, 300),
])
def test_add(a, b, expected):
    """Each tuple becomes a separate test case checking ``a + b == expected``."""
    total = a + b
    assert total == expected
# 多个参数化装饰器
@pytest.mark.parametrize('x', [0, 1])
@pytest.mark.parametrize('y', [2, 3])
def test_multiply(x, y):
    """Stacked parametrize decorators run every (x, y) combination: 2 x 2 = 4 cases."""
    # Check against an independent oracle (repeated addition). Comparing
    # ``x * y`` with itself could never fail and therefore tested nothing.
    assert x * y == sum([y] * x)
# 参数化 ID
@pytest.mark.parametrize('text, expected', [
    ('hello', 5),
    ('world', 5),
    pytest.param('', 0, id='empty'),
    pytest.param('a' * 100, 100, id='long'),
])
def test_len(text, expected):
    """``pytest.param(..., id=...)`` gives individual cases readable test IDs.

    The argument is named ``text`` rather than ``input`` so it does not shadow
    the built-in ``input`` function.
    """
    assert len(text) == expected
pytest 标记
import sys

import pytest
@pytest.mark.slow
def test_slow_operation():
    """Marked ``slow`` so it can be selected or excluded with ``pytest -m``."""
@pytest.mark.skip(reason='尚未实现')
def test_not_implemented():
    """Always skipped: the feature it would cover is not implemented yet."""
@pytest.mark.skipif(sys.platform == 'win32', reason='Windows 不支持')
def test_linux_only():
    """Skipped on Windows only.

    The condition checks just for ``'win32'``, so this test also runs on other
    non-Windows platforms such as macOS. It needs ``import sys`` in this snippet's
    imports, or evaluating the decorator raises NameError.
    """
@pytest.mark.xfail(reason='已知问题')
def test_known_bug():
    """Expected to fail because of a known issue.

    NOTE(review): the body is empty, so the test currently passes and pytest
    reports it as XPASS rather than XFAIL.
    """
# 自定义标记
@pytest.mark.database
def test_database():
    """Carries the custom ``database`` marker, so ``pytest -m database`` selects it."""
# 运行特定标记
# pytest -m slow
# pytest -m "not slow"
# pytest -m "database and not slow"
0x06. Mock
from unittest.mock import Mock, patch, MagicMock
# Create a Mock object and configure the return value of one of its methods.
mock = Mock()
mock.method.return_value = 42
assert mock.method() == 42
# Inspect calls. The assert above already called mock.method() once, so clear
# the recorded calls first; otherwise assert_called_once() below would fail
# with "Called 2 times". reset_mock() keeps the configured return_value.
mock.method.reset_mock()
mock.method('arg1', 'arg2')
mock.method.assert_called()
mock.method.assert_called_once()
mock.method.assert_called_with('arg1', 'arg2')
# 使用 patch 装饰器
@patch('module.function')
def test_with_mock(mock_function):
    """``patch`` used as a decorator passes the replacement mock in as an argument.

    NOTE(review): ``module`` is a placeholder; the snippet needs a real importable
    module (and ``import module``) for both the patch target and the call below.
    """
    mock_function.return_value = 'mocked'
    assert module.function() == 'mocked'
# 使用 patch 上下文管理器
def test_with_context():
    """``patch`` used as a context manager: the replacement only lasts inside the ``with``.

    NOTE(review): ``module`` is a placeholder for a real, imported module.
    """
    with patch('module.function') as mock_function:
        mock_function.return_value = 'mocked'
        assert module.function() == 'mocked'
# Mock 类方法
class MyClass:
    """Example class used to demonstrate mocking an instance method."""

    def method(self):
        value = 'original'
        return value
def test_mock_method():
    """Replace a method on a single instance with a Mock that returns a canned value."""
    instance = MyClass()
    instance.method = Mock(return_value='mocked')
    assert instance.method() == 'mocked'
# MagicMock: like Mock, it creates child attributes and methods on first access;
# unlike Mock, it also supports the magic (dunder) methods.
mock = MagicMock()
mock.attr.method().other.return_value = 42
assert mock.attr.method().other() == 42
# Magic methods can be configured like ordinary ones. A plain Mock does not
# support len() at all by default.
mock.__len__.return_value = 3
assert len(mock) == 3
实际应用
import requests
from unittest.mock import patch, Mock
def fetch_user(user_id, timeout=10):
    """Fetch one user from the example API and return the decoded JSON body.

    Args:
        user_id: Identifier interpolated into the request URL.
        timeout: Seconds to wait for the server, passed to ``requests.get``.
            Without a timeout, requests can wait indefinitely.

    Raises:
        requests.HTTPError: if the response has an error status code.
    """
    response = requests.get(
        f'https://api.example.com/users/{user_id}', timeout=timeout
    )
    # Fail loudly on 4xx/5xx instead of returning an error payload as a user.
    response.raise_for_status()
    return response.json()


@patch('requests.get')
def test_fetch_user(mock_get):
    # Arrange: configure the fake response that requests.get will return
    mock_response = Mock()
    mock_response.json.return_value = {'id': 1, 'name': 'Alice'}
    mock_get.return_value = mock_response
    # Act: call the function under test
    user = fetch_user(1)
    # Assert on the result
    assert user['name'] == 'Alice'
    # Assert on the interaction with requests
    mock_get.assert_called_once_with('https://api.example.com/users/1', timeout=10)
    mock_response.raise_for_status.assert_called_once()
0x07. 测试最佳实践
测试目录结构
project/
├── src/
│ ├── __init__.py
│ ├── calculator.py
│ └── utils.py
├── tests/
│ ├── __init__.py
│ ├── test_calculator.py
│ └── test_utils.py
└── pyproject.toml
pytest 配置
# pyproject.toml
[tool.pytest.ini_options]
testpaths = ["tests"]
python_files = ["test_*.py"]
python_classes = ["Test*"]
python_functions = ["test_*"]
addopts = "-v --tb=short"
markers = [
"slow: marks tests as slow",
"database: marks tests that need database",
]
测试原则
"""
1. 测试应该是独立的
2. 测试应该是可重复的
3. 测试应该是快速的
4. 测试应该有清晰的命名
5. 每个测试只测试一件事
6. 使用夹具减少重复
7. 测试边界条件和异常情况
"""
# 好的测试命名
def test_user_creation_with_valid_data():
    """Naming example: the name states the scenario under test."""
def test_user_creation_with_invalid_email_raises_error():
    """Naming example: the name states both the input condition and the expected outcome."""
def test_user_age_must_be_positive():
    """Naming example: the name reads as the rule being verified."""
# 好的测试结构 (AAA 模式)
def test_example():
    """Arrange-Act-Assert: prepare inputs, run the operation, verify the outcome."""
    numbers = [1, 2, 3]   # Arrange: test data
    total = sum(numbers)  # Act: operation under test
    assert total == 6     # Assert: expected result
参考
目录