上下文管理器
上下文管理器(Context Manager)用于自动管理资源的获取和释放,确保代码块执行前后正确处理资源。
0x01. with 语句
基本使用
# 文件操作 - 自动关闭文件
with open('file.txt', 'r') as f:
content = f.read()
# 文件自动关闭,即使发生异常
# 多个上下文管理器
with open('input.txt', 'r') as f_in, open('output.txt', 'w') as f_out:
f_out.write(f_in.read())
# Python 3.10+ 括号形式
with (
open('input.txt', 'r') as f_in,
open('output.txt', 'w') as f_out
):
f_out.write(f_in.read())
0x02. 实现上下文管理器
类实现
class DatabaseConnection:
"""数据库连接上下文管理器"""
def __init__(self, connection_string):
self.connection_string = connection_string
self.connection = None
def __enter__(self):
"""进入上下文时调用"""
print(f'连接到数据库: {self.connection_string}')
self.connection = {'status': 'connected'}
return self.connection
def __exit__(self, exc_type, exc_val, exc_tb):
"""退出上下文时调用
参数:
- exc_type: 异常类型(如果没有异常则为 None)
- exc_val: 异常值
- exc_tb: 异常追踪信息
返回:
- True: 抑制异常
- False: 传播异常
"""
print('关闭数据库连接')
self.connection = None
if exc_type:
print(f'发生异常: {exc_val}')
return False # 不抑制异常
# 使用
with DatabaseConnection('localhost:5432/mydb') as conn:
print(f'连接状态: {conn["status"]}')
# 模拟操作
# raise ValueError('模拟错误') # 测试异常处理
函数实现(contextmanager)
from contextlib import contextmanager
@contextmanager
def timer():
"""计时器上下文管理器"""
import time
start = time.perf_counter()
try:
yield # 这里是 with 块的代码
finally:
end = time.perf_counter()
print(f'耗时: {end - start:.4f} 秒')
# 使用
with timer():
# 模拟耗时操作
sum(range(1000000))
@contextmanager
def temporary_directory():
"""临时目录上下文管理器"""
import tempfile
import shutil
from pathlib import Path
temp_dir = Path(tempfile.mkdtemp())
try:
yield temp_dir
finally:
shutil.rmtree(temp_dir)
# 使用
with temporary_directory() as temp_dir:
# 在临时目录中操作
(temp_dir / 'test.txt').write_text('hello')
print(f'临时目录: {temp_dir}')
# 临时目录自动删除
0x03. 常用上下文管理器
文件操作
# 基本文件操作
with open('file.txt', 'r') as f:
content = f.read()
# 写入文件
with open('file.txt', 'w') as f:
f.write('Hello, World!')
# 追加模式
with open('file.txt', 'a') as f:
f.write('\nNew line')
锁
import threading
# 线程锁
lock = threading.Lock()
with lock:
# 临界区代码
print('获得锁')
# RLock - 可重入锁
rlock = threading.RLock()
with rlock:
with rlock: # 可以多次获取
print('重入锁')
临时文件
import tempfile
# 临时文件
with tempfile.NamedTemporaryFile(mode='w', delete=False) as f:
f.write('临时内容')
temp_path = f.name
print(f'临时文件路径: {temp_path}')
# 临时目录
with tempfile.TemporaryDirectory() as temp_dir:
print(f'临时目录: {temp_dir}')
# 目录自动删除
数据库连接
import sqlite3
# SQLite 连接
with sqlite3.connect('database.db') as conn:
cursor = conn.cursor()
cursor.execute('SELECT * FROM users')
results = cursor.fetchall()
# 自动提交或回滚
网络连接
import socket
# Socket 连接
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
s.connect(('example.com', 80))
s.send(b'GET / HTTP/1.1\r\nHost: example.com\r\n\r\n')
response = s.recv(4096)
# 自动关闭连接
0x04. 高级用法
嵌套上下文管理器
from contextlib import contextmanager
@contextmanager
def database_transaction(connection):
"""数据库事务上下文管理器"""
cursor = connection.cursor()
try:
yield cursor
connection.commit()
print('事务提交')
except Exception:
connection.rollback()
print('事务回滚')
raise
finally:
cursor.close()
@contextmanager
def database_connection(db_path):
"""数据库连接上下文管理器"""
import sqlite3
conn = sqlite3.connect(db_path)
try:
yield conn
finally:
conn.close()
# 嵌套使用
with database_connection('test.db') as conn:
with database_transaction(conn) as cursor:
cursor.execute('CREATE TABLE IF NOT EXISTS test (id INTEGER PRIMARY KEY)')
cursor.execute('INSERT INTO test VALUES (1)')
条件上下文管理器
from contextlib import contextmanager, ExitStack
@contextmanager
def conditional_context(condition, context_manager):
"""条件上下文管理器"""
if condition:
with context_manager as value:
yield value
else:
yield None
# 使用
debug_mode = True
with conditional_context(debug_mode, open('debug.log', 'w')) as log_file:
if log_file:
log_file.write('调试信息')
else:
print('调试模式关闭')
# ExitStack - 动态管理多个上下文
with ExitStack() as stack:
files = [
stack.enter_context(open(f'file{i}.txt', 'w'))
for i in range(3)
]
for i, f in enumerate(files):
f.write(f'文件 {i}')
异常处理上下文
from contextlib import contextmanager, suppress
# suppress - 抑制特定异常
with suppress(FileNotFoundError):
import os
os.remove('nonexistent.txt')
# 不会抛出异常
@contextmanager
def error_handler(error_type, handler):
"""自定义异常处理"""
try:
yield
except error_type as e:
handler(e)
# 使用
def handle_value_error(e):
print(f'处理错误: {e}')
with error_handler(ValueError, handle_value_error):
raise ValueError('测试错误')
0x05. 实际应用
性能监控
from contextlib import contextmanager
import time
import functools
@contextmanager
def performance_monitor(name):
"""性能监控上下文管理器"""
import psutil
import os
process = psutil.Process(os.getpid())
start_time = time.perf_counter()
start_memory = process.memory_info().rss
try:
yield
finally:
end_time = time.perf_counter()
end_memory = process.memory_info().rss
print(f'[{name}] 耗时: {end_time - start_time:.4f} 秒')
print(f'[{name}] 内存变化: {(end_memory - start_memory) / 1024 / 1024:.2f} MB')
# 使用
with performance_monitor('数据处理'):
data = [i ** 2 for i in range(1000000)]
配置管理
from contextlib import contextmanager
import os
@contextmanager
def environment_variable(key, value):
"""临时设置环境变量"""
old_value = os.environ.get(key)
os.environ[key] = value
try:
yield
finally:
if old_value is None:
os.environ.pop(key, None)
else:
os.environ[key] = old_value
@contextmanager
def working_directory(path):
"""临时切换工作目录"""
old_dir = os.getcwd()
os.chdir(path)
try:
yield
finally:
os.chdir(old_dir)
# 使用
with environment_variable('DEBUG', 'true'):
print(f'DEBUG={os.environ.get("DEBUG")}')
with working_directory('/tmp'):
print(f'当前目录: {os.getcwd()}')
资源池
from contextlib import contextmanager
from queue import Queue
import threading
class ResourcePool:
"""资源池"""
def __init__(self, max_size=5):
self.pool = Queue(max_size)
self.lock = threading.Lock()
self.created = 0
self.max_size = max_size
def _create_resource(self):
"""创建资源"""
self.created += 1
return f'Resource-{self.created}'
@contextmanager
def acquire(self):
"""获取资源"""
try:
resource = self.pool.get_nowait()
except:
with self.lock:
if self.created < self.max_size:
resource = self._create_resource()
else:
resource = self.pool.get()
try:
yield resource
finally:
self.pool.put(resource)
# 使用
pool = ResourcePool(3)
def worker(pool, worker_id):
with pool.acquire() as resource:
print(f'Worker {worker_id} 使用 {resource}')
threads = [threading.Thread(target=worker, args=(pool, i)) for i in range(5)]
for t in threads:
t.start()
for t in threads:
t.join()
缓存上下文
from contextlib import contextmanager
from functools import lru_cache
@contextmanager
def cache_enabled(maxsize=128):
"""启用缓存的上下文管理器"""
original_functions = {}
def enable_cache(func):
original_functions[func.__name__] = func
return lru_cache(maxsize=maxsize)(func)
try:
yield enable_cache
finally:
# 清理缓存
for func in original_functions.values():
if hasattr(func, 'cache_clear'):
func.cache_clear()
# 使用
def expensive_calculation(n):
print(f'计算 {n}...')
return sum(range(n))
with cache_enabled() as cache:
cached_calc = cache(expensive_calculation)
print(cached_calc(1000000)) # 第一次计算
print(cached_calc(1000000)) # 使用缓存
参考
目录