分布式数据库架构设计与实践:从分片策略到一致性保证的完整方案
引言
随着数据量的爆炸式增长和业务复杂度的不断提升,传统的单机数据库已经无法满足现代应用的需求。分布式数据库作为解决大规模数据存储和处理的关键技术,在保证高可用性、可扩展性和一致性方面发挥着重要作用。本文将深入探讨分布式数据库的架构设计原理和实践方案。
分布式数据库架构概述
1. 核心架构组件
graph TB
subgraph "客户端层"
A[应用程序] --> B[数据库代理]
B --> C[连接池管理器]
end
subgraph "路由层"
C --> D[分片路由器]
D --> E[负载均衡器]
E --> F[查询优化器]
end
subgraph "数据层"
F --> G[分片1]
F --> H[分片2]
F --> I[分片3]
F --> J[分片N]
end
subgraph "元数据层"
K[配置中心] --> D
L[分片映射表] --> D
M[节点状态监控] --> E
end
subgraph "一致性层"
N[分布式锁] --> O[事务协调器]
O --> P[两阶段提交]
P --> Q[Raft共识算法]
end
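上图中,一条查询自上而下依次经过分片路由器(确定目标分片)、负载均衡器(在副本节点间选节点)和查询优化器(生成执行计划)。下面用一段极简的 Python 示意这条调用链(其中的类名与方法名均为本文为说明而虚构,并非某个现成框架的 API,具体实现见后文各节):

```python
# 仅为示意:按架构图中路由层的顺序串联各组件(类与方法名为本文虚构)
from dataclasses import dataclass
from typing import Dict, List

@dataclass
class QueryContext:
    table: str
    params: Dict[str, object]

class DemoShardRouter:
    def route(self, ctx: QueryContext) -> List[str]:
        # 真实实现见后文 ShardRouter.route_query
        return ["shard_0"]

class DemoLoadBalancer:
    def pick_node(self, shard_id: str) -> str:
        # 在该分片的副本节点中选择一个可用节点
        return f"{shard_id}@node1"

class DemoQueryOptimizer:
    def plan(self, ctx: QueryContext, nodes: List[str]) -> Dict[str, object]:
        # 真实实现见后文 DistributedQueryOptimizer
        return {"table": ctx.table, "nodes": nodes}

def handle_query(ctx: QueryContext) -> Dict[str, object]:
    shards = DemoShardRouter().route(ctx)
    nodes = [DemoLoadBalancer().pick_node(s) for s in shards]
    return DemoQueryOptimizer().plan(ctx, nodes)

print(handle_query(QueryContext(table="users", params={"user_id": 12345})))
```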
2. 分布式数据库配置
# config/distributed_database.yaml
distributed_database:
cluster_name: "production_cluster"
# 节点配置
nodes:
- id: "node1"
host: "192.168.1.10"
port: 3306
role: "primary"
datacenter: "dc1"
rack: "rack1"
weight: 100
- id: "node2"
host: "192.168.1.11"
port: 3306
role: "secondary"
datacenter: "dc1"
rack: "rack2"
weight: 100
- id: "node3"
host: "192.168.1.12"
port: 3306
role: "secondary"
datacenter: "dc2"
rack: "rack1"
weight: 100
# 分片配置
sharding:
strategy: "hash" # hash, range, directory
shard_count: 16
replication_factor: 3
# 分片键配置
shard_keys:
users: ["user_id"]
orders: ["user_id", "order_date"]
products: ["category_id"]
# 分片映射
shard_mapping:
shard_0: ["node1", "node2"]
shard_1: ["node2", "node3"]
shard_2: ["node3", "node1"]
# ... 其他分片映射
# 一致性配置
consistency:
level: "eventual" # strong, eventual, weak
read_preference: "primary_preferred"
write_concern: "majority"
# 事务配置
transaction:
isolation_level: "read_committed"
timeout: 30000 # 毫秒
retry_count: 3
# 故障恢复配置
failover:
detection_timeout: 5000 # 毫秒
election_timeout: 10000 # 毫秒
heartbeat_interval: 1000 # 毫秒
# 性能配置
performance:
connection_pool_size: 100
query_timeout: 30000
batch_size: 1000
cache_size: "1GB"
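上面的配置将 users 表以 user_id 作为分片键做哈希分片,共 16 个分片、3 副本。下面给出一个最小示例(假设配置文件保存在 config/distributed_database.yaml,哈希函数沿用后文 ShardRouter 的 MD5 方案),演示加载配置后如何计算某条 users 记录所属的分片:

```python
# 最小示例:读取上述 YAML 配置,按哈希策略计算分片下标
import hashlib
import yaml

with open("config/distributed_database.yaml", "r", encoding="utf-8") as f:
    cfg = yaml.safe_load(f)["distributed_database"]

sharding = cfg["sharding"]
shard_count = sharding["shard_count"]               # 16
users_shard_keys = sharding["shard_keys"]["users"]  # ["user_id"]

def shard_of(row):
    # 将分片键的取值拼接后做 MD5,再对分片数取模,与后文 ShardRouter 的做法一致
    key = "|".join(str(row[col]) for col in users_shard_keys)
    digest = int(hashlib.md5(key.encode()).hexdigest(), 16)
    return f"shard_{digest % shard_count}"

print(shard_of({"user_id": 12345}))  # 输出形如 shard_N,具体取值由哈希结果决定
```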
数据分片策略实现
1. 分片路由器
#!/usr/bin/env python3
# src/sharding/shard_router.py
import hashlib
import bisect
import logging
from typing import Dict, List, Any, Optional, Tuple
from enum import Enum
from dataclasses import dataclass
import yaml
import threading
import time
class ShardingStrategy(Enum):
HASH = "hash"
RANGE = "range"
DIRECTORY = "directory"
CONSISTENT_HASH = "consistent_hash"
@dataclass
class ShardInfo:
shard_id: str
nodes: List[str]
range_start: Optional[Any] = None
range_end: Optional[Any] = None
weight: int = 100
status: str = "active"
@dataclass
class NodeInfo:
node_id: str
host: str
port: int
role: str
datacenter: str
rack: str
weight: int
status: str = "active"
class ShardRouter:
def __init__(self, config_file: str):
self.config = self._load_config(config_file)
self.logger = self._setup_logging()
# 初始化分片信息
self.shards: Dict[str, ShardInfo] = {}
self.nodes: Dict[str, NodeInfo] = {}
self.shard_keys: Dict[str, List[str]] = {}
# 一致性哈希环(用于consistent_hash策略)
self.hash_ring: List[Tuple[int, str]] = []
self.virtual_nodes = 150 # 每个物理节点的虚拟节点数
# 线程锁
self.lock = threading.RLock()
# 初始化路由器
self._initialize_router()
def _load_config(self, config_file: str) -> Dict[str, Any]:
"""加载配置文件"""
with open(config_file, 'r', encoding='utf-8') as f:
return yaml.safe_load(f)
def _setup_logging(self) -> logging.Logger:
"""设置日志记录"""
logger = logging.getLogger('ShardRouter')
logger.setLevel(logging.INFO)
handler = logging.StreamHandler()
formatter = logging.Formatter(
'%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
handler.setFormatter(formatter)
logger.addHandler(handler)
return logger
def _initialize_router(self):
"""初始化路由器"""
# 加载节点信息
for node_config in self.config['distributed_database']['nodes']:
node = NodeInfo(
node_id=node_config['id'],
host=node_config['host'],
port=node_config['port'],
role=node_config['role'],
datacenter=node_config['datacenter'],
rack=node_config['rack'],
weight=node_config['weight']
)
self.nodes[node.node_id] = node
# 加载分片键配置
self.shard_keys = self.config['distributed_database']['sharding']['shard_keys']
# 根据分片策略初始化分片
strategy = ShardingStrategy(self.config['distributed_database']['sharding']['strategy'])
if strategy == ShardingStrategy.HASH:
self._initialize_hash_sharding()
elif strategy == ShardingStrategy.RANGE:
self._initialize_range_sharding()
elif strategy == ShardingStrategy.CONSISTENT_HASH:
self._initialize_consistent_hash_sharding()
elif strategy == ShardingStrategy.DIRECTORY:
self._initialize_directory_sharding()
self.logger.info(f"分片路由器初始化完成,策略: {strategy.value}")
def _initialize_hash_sharding(self):
"""初始化哈希分片"""
shard_count = self.config['distributed_database']['sharding']['shard_count']
shard_mapping = self.config['distributed_database']['sharding']['shard_mapping']
for i in range(shard_count):
shard_id = f"shard_{i}"
nodes = shard_mapping.get(shard_id, [])
self.shards[shard_id] = ShardInfo(
shard_id=shard_id,
nodes=nodes
)
def _initialize_range_sharding(self):
"""初始化范围分片"""
# 这里可以根据具体需求配置范围分片
# 示例:按用户ID范围分片
ranges = [
(0, 10000),
(10000, 20000),
(20000, 30000),
(30000, float('inf'))
]
for i, (start, end) in enumerate(ranges):
shard_id = f"shard_{i}"
self.shards[shard_id] = ShardInfo(
shard_id=shard_id,
nodes=[f"node{(i % len(self.nodes)) + 1}"],
range_start=start,
range_end=end
)
def _initialize_consistent_hash_sharding(self):
"""初始化一致性哈希分片"""
self.hash_ring = []
for node_id, node in self.nodes.items():
# 为每个物理节点创建多个虚拟节点
for i in range(self.virtual_nodes):
virtual_node_key = f"{node_id}:{i}"
hash_value = self._hash_function(virtual_node_key)
self.hash_ring.append((hash_value, node_id))
# 按哈希值排序
self.hash_ring.sort()
self.logger.info(f"一致性哈希环初始化完成,虚拟节点数: {len(self.hash_ring)}")
def _initialize_directory_sharding(self):
"""初始化目录分片"""
# 目录分片需要维护一个分片目录表
# 这里简化实现,实际应该从配置或数据库加载
directory_mapping = {
'users_1': 'shard_0',
'users_2': 'shard_1',
'orders_2023': 'shard_2',
'orders_2024': 'shard_3'
}
for table_partition, shard_id in directory_mapping.items():
if shard_id not in self.shards:
self.shards[shard_id] = ShardInfo(
shard_id=shard_id,
nodes=[f"node{(len(self.shards) % len(self.nodes)) + 1}"]
)
def _hash_function(self, key: str) -> int:
"""哈希函数"""
return int(hashlib.md5(key.encode()).hexdigest(), 16)
def route_query(self, table: str, query_params: Dict[str, Any]) -> List[str]:
"""路由查询到相应的分片"""
with self.lock:
strategy = ShardingStrategy(self.config['distributed_database']['sharding']['strategy'])
if strategy == ShardingStrategy.HASH:
return self._route_hash_query(table, query_params)
elif strategy == ShardingStrategy.RANGE:
return self._route_range_query(table, query_params)
elif strategy == ShardingStrategy.CONSISTENT_HASH:
return self._route_consistent_hash_query(table, query_params)
elif strategy == ShardingStrategy.DIRECTORY:
return self._route_directory_query(table, query_params)
return []
def _route_hash_query(self, table: str, query_params: Dict[str, Any]) -> List[str]:
"""哈希分片路由"""
if table not in self.shard_keys:
# 如果没有配置分片键,广播到所有分片
return list(self.shards.keys())
shard_key_columns = self.shard_keys[table]
# 检查是否包含所有分片键
if not all(col in query_params for col in shard_key_columns):
# 如果缺少分片键,广播到所有分片
return list(self.shards.keys())
# 构建分片键值
shard_key_value = "|".join(str(query_params[col]) for col in shard_key_columns)
# 计算哈希值
hash_value = self._hash_function(shard_key_value)
shard_count = self.config['distributed_database']['sharding']['shard_count']
shard_index = hash_value % shard_count
shard_id = f"shard_{shard_index}"
return [shard_id] if shard_id in self.shards else []
def _route_range_query(self, table: str, query_params: Dict[str, Any]) -> List[str]:
"""范围分片路由"""
if table not in self.shard_keys:
return list(self.shards.keys())
shard_key_columns = self.shard_keys[table]
primary_key = shard_key_columns[0] # 使用第一个分片键作为范围键
if primary_key not in query_params:
return list(self.shards.keys())
key_value = query_params[primary_key]
target_shards = []
for shard_id, shard in self.shards.items():
if (shard.range_start is None or key_value >= shard.range_start) and \
(shard.range_end is None or key_value < shard.range_end):
target_shards.append(shard_id)
return target_shards
def _route_consistent_hash_query(self, table: str, query_params: Dict[str, Any]) -> List[str]:
"""一致性哈希分片路由"""
if table not in self.shard_keys:
return list(set(node_id for _, node_id in self.hash_ring))
shard_key_columns = self.shard_keys[table]
if not all(col in query_params for col in shard_key_columns):
return list(set(node_id for _, node_id in self.hash_ring))
# 构建分片键值
shard_key_value = "|".join(str(query_params[col]) for col in shard_key_columns)
hash_value = self._hash_function(shard_key_value)
# 在哈希环中查找目标节点
index = bisect.bisect_right([h for h, _ in self.hash_ring], hash_value)
if index == len(self.hash_ring):
index = 0
target_node = self.hash_ring[index][1]
return [target_node]
def _route_directory_query(self, table: str, query_params: Dict[str, Any]) -> List[str]:
"""目录分片路由"""
# 简化实现:根据表名和分区信息查找分片
# 实际实现需要维护完整的目录映射表
# 示例:根据时间分区
if 'partition_key' in query_params:
partition_key = query_params['partition_key']
# 查找对应的分片
for shard_id in self.shards:
if partition_key in shard_id:
return [shard_id]
return list(self.shards.keys())
def get_shard_nodes(self, shard_id: str) -> List[NodeInfo]:
"""获取分片对应的节点信息"""
if shard_id not in self.shards:
return []
shard = self.shards[shard_id]
return [self.nodes[node_id] for node_id in shard.nodes if node_id in self.nodes]
def add_shard(self, shard_id: str, nodes: List[str]):
"""添加新分片"""
with self.lock:
self.shards[shard_id] = ShardInfo(
shard_id=shard_id,
nodes=nodes
)
self.logger.info(f"添加新分片: {shard_id}")
def remove_shard(self, shard_id: str):
"""移除分片"""
with self.lock:
if shard_id in self.shards:
del self.shards[shard_id]
self.logger.info(f"移除分片: {shard_id}")
def update_node_status(self, node_id: str, status: str):
"""更新节点状态"""
with self.lock:
if node_id in self.nodes:
self.nodes[node_id].status = status
self.logger.info(f"更新节点状态: {node_id} -> {status}")
def get_cluster_topology(self) -> Dict[str, Any]:
"""获取集群拓扑信息"""
return {
'nodes': {node_id: {
'host': node.host,
'port': node.port,
'role': node.role,
'datacenter': node.datacenter,
'rack': node.rack,
'status': node.status
} for node_id, node in self.nodes.items()},
'shards': {shard_id: {
'nodes': shard.nodes,
'status': shard.status,
'range_start': shard.range_start,
'range_end': shard.range_end
} for shard_id, shard in self.shards.items()}
}
def main():
# 示例用法
router = ShardRouter('config/distributed_database.yaml')
# 测试查询路由
test_queries = [
{
'table': 'users',
'params': {'user_id': 12345}
},
{
'table': 'orders',
'params': {'user_id': 12345, 'order_date': '2024-01-15'}
},
{
'table': 'products',
'params': {'category_id': 'electronics'}
}
]
for query in test_queries:
shards = router.route_query(query['table'], query['params'])
print(f"查询 {query['table']} 路由到分片: {shards}")
for shard_id in shards:
nodes = router.get_shard_nodes(shard_id)
print(f" 分片 {shard_id} 节点: {[node.node_id for node in nodes]}")
# 打印集群拓扑
topology = router.get_cluster_topology()
print("\n集群拓扑:")
print(yaml.dump(topology, default_flow_style=False, allow_unicode=True))
if __name__ == "__main__":
main()
2. 分布式事务协调器
#!/usr/bin/env python3
# src/transaction/distributed_transaction.py
import uuid
import time
import logging
import threading
from typing import Dict, List, Any, Optional, Callable
from enum import Enum
from dataclasses import dataclass, field
import json
from concurrent.futures import ThreadPoolExecutor, Future
class TransactionState(Enum):
ACTIVE = "active"
PREPARING = "preparing"
PREPARED = "prepared"
COMMITTING = "committing"
COMMITTED = "committed"
ABORTING = "aborting"
ABORTED = "aborted"
class ParticipantState(Enum):
ACTIVE = "active"
PREPARED = "prepared"
COMMITTED = "committed"
ABORTED = "aborted"
UNKNOWN = "unknown"
@dataclass
class TransactionParticipant:
participant_id: str
node_id: str
shard_id: str
operations: List[Dict[str, Any]] = field(default_factory=list)
state: ParticipantState = ParticipantState.ACTIVE
prepare_timestamp: Optional[float] = None
commit_timestamp: Optional[float] = None
@dataclass
class DistributedTransaction:
transaction_id: str
coordinator_id: str
participants: Dict[str, TransactionParticipant] = field(default_factory=dict)
state: TransactionState = TransactionState.ACTIVE
start_timestamp: float = field(default_factory=time.time)
timeout: float = 30.0 # 30秒超时
isolation_level: str = "read_committed"
def is_expired(self) -> bool:
return time.time() - self.start_timestamp > self.timeout
class DistributedTransactionCoordinator:
def __init__(self, coordinator_id: str, shard_router):
self.coordinator_id = coordinator_id
self.shard_router = shard_router
self.logger = self._setup_logging()
# 活跃事务
self.active_transactions: Dict[str, DistributedTransaction] = {}
# 线程池
self.executor = ThreadPoolExecutor(max_workers=50)
# 锁
self.lock = threading.RLock()
# 启动清理线程
self.cleanup_thread = threading.Thread(target=self._cleanup_expired_transactions)
self.cleanup_thread.daemon = True
self.cleanup_thread.start()
def _setup_logging(self) -> logging.Logger:
"""设置日志记录"""
logger = logging.getLogger(f'TransactionCoordinator-{self.coordinator_id}')
logger.setLevel(logging.INFO)
handler = logging.StreamHandler()
formatter = logging.Formatter(
'%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
handler.setFormatter(formatter)
logger.addHandler(handler)
return logger
def begin_transaction(self, isolation_level: str = "read_committed",
timeout: float = 30.0) -> str:
"""开始分布式事务"""
transaction_id = str(uuid.uuid4())
with self.lock:
transaction = DistributedTransaction(
transaction_id=transaction_id,
coordinator_id=self.coordinator_id,
isolation_level=isolation_level,
timeout=timeout
)
self.active_transactions[transaction_id] = transaction
self.logger.info(f"开始分布式事务: {transaction_id}")
return transaction_id
def add_operation(self, transaction_id: str, table: str,
operation: str, data: Dict[str, Any]) -> bool:
"""添加事务操作"""
with self.lock:
if transaction_id not in self.active_transactions:
self.logger.error(f"事务不存在: {transaction_id}")
return False
transaction = self.active_transactions[transaction_id]
if transaction.state != TransactionState.ACTIVE:
self.logger.error(f"事务状态无效: {transaction_id}, 状态: {transaction.state}")
return False
if transaction.is_expired():
self.logger.error(f"事务已超时: {transaction_id}")
self._abort_transaction(transaction_id)
return False
# 路由操作到相应的分片
shards = self.shard_router.route_query(table, data)
for shard_id in shards:
nodes = self.shard_router.get_shard_nodes(shard_id)
if not nodes:
continue
# 选择主节点
primary_node = next((node for node in nodes if node.role == 'primary'), nodes[0])
participant_id = f"{shard_id}_{primary_node.node_id}"
with self.lock:
if participant_id not in transaction.participants:
transaction.participants[participant_id] = TransactionParticipant(
participant_id=participant_id,
node_id=primary_node.node_id,
shard_id=shard_id
)
transaction.participants[participant_id].operations.append({
'table': table,
'operation': operation,
'data': data,
'timestamp': time.time()
})
self.logger.info(f"添加事务操作: {transaction_id}, 表: {table}, 操作: {operation}")
return True
def commit_transaction(self, transaction_id: str) -> bool:
"""提交分布式事务(两阶段提交)"""
with self.lock:
if transaction_id not in self.active_transactions:
self.logger.error(f"事务不存在: {transaction_id}")
return False
transaction = self.active_transactions[transaction_id]
if transaction.state != TransactionState.ACTIVE:
self.logger.error(f"事务状态无效: {transaction_id}")
return False
if transaction.is_expired():
self.logger.error(f"事务已超时: {transaction_id}")
self._abort_transaction(transaction_id)
return False
transaction.state = TransactionState.PREPARING
self.logger.info(f"开始提交事务: {transaction_id}")
# 第一阶段:准备阶段
if not self._prepare_phase(transaction):
self.logger.error(f"准备阶段失败,中止事务: {transaction_id}")
self._abort_transaction(transaction_id)
return False
# 第二阶段:提交阶段
if not self._commit_phase(transaction):
self.logger.error(f"提交阶段失败: {transaction_id}")
return False
with self.lock:
transaction.state = TransactionState.COMMITTED
del self.active_transactions[transaction_id]
self.logger.info(f"事务提交成功: {transaction_id}")
return True
def _prepare_phase(self, transaction: DistributedTransaction) -> bool:
"""两阶段提交的准备阶段"""
self.logger.info(f"执行准备阶段: {transaction.transaction_id}")
# 并行向所有参与者发送准备请求
prepare_futures = []
for participant in transaction.participants.values():
future = self.executor.submit(self._prepare_participant, participant)
prepare_futures.append((participant.participant_id, future))
# 等待所有参与者响应
all_prepared = True
for participant_id, future in prepare_futures:
try:
result = future.result(timeout=10.0) # 10秒超时
if not result:
self.logger.error(f"参与者准备失败: {participant_id}")
all_prepared = False
else:
transaction.participants[participant_id].state = ParticipantState.PREPARED
transaction.participants[participant_id].prepare_timestamp = time.time()
except Exception as e:
self.logger.error(f"参与者准备异常: {participant_id}, 错误: {e}")
all_prepared = False
if all_prepared:
transaction.state = TransactionState.PREPARED
self.logger.info(f"所有参与者准备完成: {transaction.transaction_id}")
return all_prepared
def _commit_phase(self, transaction: DistributedTransaction) -> bool:
"""两阶段提交的提交阶段"""
self.logger.info(f"执行提交阶段: {transaction.transaction_id}")
transaction.state = TransactionState.COMMITTING
# 并行向所有参与者发送提交请求
commit_futures = []
for participant in transaction.participants.values():
future = self.executor.submit(self._commit_participant, participant)
commit_futures.append((participant.participant_id, future))
# 等待所有参与者响应
all_committed = True
for participant_id, future in commit_futures:
try:
result = future.result(timeout=10.0) # 10秒超时
if result:
transaction.participants[participant_id].state = ParticipantState.COMMITTED
transaction.participants[participant_id].commit_timestamp = time.time()
else:
self.logger.error(f"参与者提交失败: {participant_id}")
all_committed = False
except Exception as e:
self.logger.error(f"参与者提交异常: {participant_id}, 错误: {e}")
all_committed = False
return all_committed
def _prepare_participant(self, participant: TransactionParticipant) -> bool:
"""向参与者发送准备请求"""
try:
# 这里应该实际调用参与者节点的准备接口
# 简化实现,模拟网络调用
self.logger.info(f"向参与者发送准备请求: {participant.participant_id}")
# 模拟网络延迟
time.sleep(0.1)
# 模拟准备操作
# 实际实现中,这里会:
# 1. 锁定相关资源
# 2. 验证操作的有效性
# 3. 准备回滚日志
# 4. 返回准备结果
# 简化:90%的概率成功
import random
return random.random() > 0.1
except Exception as e:
self.logger.error(f"准备参与者失败: {participant.participant_id}, 错误: {e}")
return False
def _commit_participant(self, participant: TransactionParticipant) -> bool:
"""向参与者发送提交请求"""
try:
self.logger.info(f"向参与者发送提交请求: {participant.participant_id}")
# 模拟网络延迟
time.sleep(0.1)
# 模拟提交操作
# 实际实现中,这里会:
# 1. 执行实际的数据修改
# 2. 释放锁定的资源
# 3. 清理事务日志
# 4. 返回提交结果
# 简化:95%的概率成功
import random
return random.random() > 0.05
except Exception as e:
self.logger.error(f"提交参与者失败: {participant.participant_id}, 错误: {e}")
return False
def _abort_transaction(self, transaction_id: str):
"""中止事务"""
with self.lock:
if transaction_id not in self.active_transactions:
return
transaction = self.active_transactions[transaction_id]
transaction.state = TransactionState.ABORTING
self.logger.info(f"中止事务: {transaction_id}")
# 并行向所有参与者发送中止请求
abort_futures = []
for participant in transaction.participants.values():
future = self.executor.submit(self._abort_participant, participant)
abort_futures.append(future)
# 等待所有中止操作完成
for future in abort_futures:
try:
future.result(timeout=5.0)
except Exception as e:
self.logger.error(f"中止参与者异常: {e}")
with self.lock:
transaction.state = TransactionState.ABORTED
del self.active_transactions[transaction_id]
self.logger.info(f"事务中止完成: {transaction_id}")
def _abort_participant(self, participant: TransactionParticipant) -> bool:
"""向参与者发送中止请求"""
try:
self.logger.info(f"向参与者发送中止请求: {participant.participant_id}")
# 模拟网络延迟
time.sleep(0.05)
# 模拟中止操作
# 实际实现中,这里会:
# 1. 回滚所有修改
# 2. 释放锁定的资源
# 3. 清理事务日志
participant.state = ParticipantState.ABORTED
return True
except Exception as e:
self.logger.error(f"中止参与者失败: {participant.participant_id}, 错误: {e}")
return False
def rollback_transaction(self, transaction_id: str) -> bool:
"""回滚事务"""
self.logger.info(f"回滚事务: {transaction_id}")
self._abort_transaction(transaction_id)
return True
def get_transaction_status(self, transaction_id: str) -> Optional[Dict[str, Any]]:
"""获取事务状态"""
with self.lock:
if transaction_id not in self.active_transactions:
return None
transaction = self.active_transactions[transaction_id]
return {
'transaction_id': transaction.transaction_id,
'state': transaction.state.value,
'start_timestamp': transaction.start_timestamp,
'timeout': transaction.timeout,
'participants': {
p_id: {
'node_id': p.node_id,
'shard_id': p.shard_id,
'state': p.state.value,
'operations_count': len(p.operations)
}
for p_id, p in transaction.participants.items()
}
}
def _cleanup_expired_transactions(self):
"""清理过期事务"""
while True:
try:
time.sleep(10) # 每10秒检查一次
expired_transactions = []
with self.lock:
for transaction_id, transaction in self.active_transactions.items():
if transaction.is_expired():
expired_transactions.append(transaction_id)
for transaction_id in expired_transactions:
self.logger.warning(f"清理过期事务: {transaction_id}")
self._abort_transaction(transaction_id)
except Exception as e:
self.logger.error(f"清理过期事务异常: {e}")
def main():
# 示例用法
from sharding.shard_router import ShardRouter
# 创建分片路由器
router = ShardRouter('config/distributed_database.yaml')
# 创建事务协调器
coordinator = DistributedTransactionCoordinator('coordinator_1', router)
# 开始事务
tx_id = coordinator.begin_transaction()
print(f"开始事务: {tx_id}")
# 添加操作
coordinator.add_operation(tx_id, 'users', 'insert', {'user_id': 12345, 'name': 'John'})
coordinator.add_operation(tx_id, 'orders', 'insert', {
'user_id': 12345,
'order_id': 67890,
'order_date': '2024-01-15'
})
# 获取事务状态
status = coordinator.get_transaction_status(tx_id)
print(f"事务状态: {json.dumps(status, indent=2, ensure_ascii=False)}")
# 提交事务
success = coordinator.commit_transaction(tx_id)
print(f"事务提交结果: {success}")
if __name__ == "__main__":
main()
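上面的协调器把参与者的准备、提交过程用随机成功率模拟掉了。作为补充,下面给出一个极简的参与者端草图(纯内存实现,类名与接口均为本文为说明而设),展示准备阶段暂存写入并记录旧值、提交阶段让写入生效、中止阶段丢弃暂存写入这三种职责的划分;真实实现还需要持久化日志、资源加锁与超时处理:

```python
# 参与者端极简草图:prepare 暂存写入,commit 生效,abort 回滚(仅为示意)
import threading
from typing import Any, Dict, List, Tuple

class SimpleParticipant:
    def __init__(self):
        self.data: Dict[str, Any] = {}
        self.lock = threading.Lock()
        # tx_id -> [(key, 旧值, 新值)],相当于撤销日志 + 待生效写入
        self.pending: Dict[str, List[Tuple[str, Any, Any]]] = {}

    def prepare(self, tx_id: str, writes: Dict[str, Any]) -> bool:
        with self.lock:
            # 记录旧值以便回滚;真实实现此处还应校验约束并持久化日志
            self.pending[tx_id] = [(k, self.data.get(k), v) for k, v in writes.items()]
            return True

    def commit(self, tx_id: str) -> bool:
        with self.lock:
            for key, _old, new in self.pending.pop(tx_id, []):
                self.data[key] = new  # 让写入正式生效
            return True

    def abort(self, tx_id: str) -> bool:
        with self.lock:
            self.pending.pop(tx_id, None)  # 丢弃未生效的写入即完成回滚
            return True

p = SimpleParticipant()
p.prepare("tx-1", {"user:12345": "John"})
p.commit("tx-1")
print(p.data)  # {'user:12345': 'John'}
```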
一致性协议实现
1. Raft共识算法
#!/usr/bin/env python3
# src/consensus/raft_consensus.py
import time
import random
import threading
import logging
import json
from typing import Dict, List, Any, Optional, Callable
from enum import Enum
from dataclasses import dataclass, field
import socket
import pickle
class NodeState(Enum):
FOLLOWER = "follower"
CANDIDATE = "candidate"
LEADER = "leader"
@dataclass
class LogEntry:
term: int
index: int
command: Dict[str, Any]
timestamp: float = field(default_factory=time.time)
@dataclass
class VoteRequest:
term: int
candidate_id: str
last_log_index: int
last_log_term: int
@dataclass
class VoteResponse:
term: int
vote_granted: bool
@dataclass
class AppendEntriesRequest:
term: int
leader_id: str
prev_log_index: int
prev_log_term: int
entries: List[LogEntry]
leader_commit: int
@dataclass
class AppendEntriesResponse:
term: int
success: bool
match_index: int = 0
class RaftNode:
def __init__(self, node_id: str, cluster_nodes: List[str],
host: str = "localhost", port: int = 8000):
self.node_id = node_id
self.cluster_nodes = cluster_nodes
self.host = host
self.port = port
# Raft状态
self.current_term = 0
self.voted_for: Optional[str] = None
self.log: List[LogEntry] = []
self.state = NodeState.FOLLOWER
# 易失性状态
self.commit_index = 0
self.last_applied = 0
# Leader状态
self.next_index: Dict[str, int] = {}
self.match_index: Dict[str, int] = {}
# 定时器
self.election_timeout = self._random_election_timeout()
self.last_heartbeat = time.time()
self.heartbeat_interval = 0.1 # 100ms
# 线程和锁
self.lock = threading.RLock()
self.running = False
self.election_thread: Optional[threading.Thread] = None
self.heartbeat_thread: Optional[threading.Thread] = None
# 日志记录
self.logger = self._setup_logging()
# 状态机
self.state_machine: Dict[str, Any] = {}
self.state_machine_callbacks: List[Callable] = []
def _setup_logging(self) -> logging.Logger:
"""设置日志记录"""
logger = logging.getLogger(f'RaftNode-{self.node_id}')
logger.setLevel(logging.INFO)
handler = logging.StreamHandler()
formatter = logging.Formatter(
f'%(asctime)s - {self.node_id} - %(levelname)s - %(message)s'
)
handler.setFormatter(formatter)
logger.addHandler(handler)
return logger
def _random_election_timeout(self) -> float:
"""生成随机选举超时时间(150-300ms)"""
return random.uniform(0.15, 0.3)
def start(self):
"""启动Raft节点"""
with self.lock:
if self.running:
return
self.running = True
self.last_heartbeat = time.time()
# 启动选举定时器线程
self.election_thread = threading.Thread(target=self._election_timer)
self.election_thread.daemon = True
self.election_thread.start()
self.logger.info(f"Raft节点启动: {self.node_id}")
def stop(self):
"""停止Raft节点"""
with self.lock:
self.running = False
if self.heartbeat_thread:
self.heartbeat_thread.join(timeout=1.0)
self.logger.info(f"Raft节点停止: {self.node_id}")
def _election_timer(self):
"""选举定时器"""
while self.running:
time.sleep(0.01) # 10ms检查间隔
with self.lock:
if self.state == NodeState.LEADER:
continue
# 检查是否超时
if time.time() - self.last_heartbeat > self.election_timeout:
self._start_election()
def _start_election(self):
"""开始选举"""
self.logger.info(f"开始选举,任期: {self.current_term + 1}")
# 转换为候选人状态
self.state = NodeState.CANDIDATE
self.current_term += 1
self.voted_for = self.node_id
self.last_heartbeat = time.time()
self.election_timeout = self._random_election_timeout()
# 获取最后日志信息
last_log_index = len(self.log) - 1
last_log_term = self.log[last_log_index].term if self.log else 0
# 向所有其他节点发送投票请求
vote_request = VoteRequest(
term=self.current_term,
candidate_id=self.node_id,
last_log_index=last_log_index,
last_log_term=last_log_term
)
votes_received = 1 # 自己的票
for node_id in self.cluster_nodes:
if node_id == self.node_id:
continue
try:
response = self._send_vote_request(node_id, vote_request)
if response and response.vote_granted:
votes_received += 1
elif response and response.term > self.current_term:
# 发现更高的任期,转为跟随者
self._become_follower(response.term)
return
except Exception as e:
self.logger.error(f"发送投票请求失败: {node_id}, 错误: {e}")
# 检查是否获得多数票
majority = len(self.cluster_nodes) // 2 + 1
if votes_received >= majority:
self._become_leader()
else:
self._become_follower(self.current_term)
def _become_leader(self):
"""成为Leader"""
self.logger.info(f"成为Leader,任期: {self.current_term}")
self.state = NodeState.LEADER
# 初始化Leader状态
for node_id in self.cluster_nodes:
if node_id != self.node_id:
self.next_index[node_id] = len(self.log)
self.match_index[node_id] = 0
# 启动心跳线程
if self.heartbeat_thread:
self.heartbeat_thread.join(timeout=0.1)
self.heartbeat_thread = threading.Thread(target=self._send_heartbeats)
self.heartbeat_thread.daemon = True
self.heartbeat_thread.start()
# 发送空的AppendEntries作为心跳
self._send_append_entries()
def _become_follower(self, term: int):
"""成为Follower"""
if term > self.current_term:
self.current_term = term
self.voted_for = None
self.state = NodeState.FOLLOWER
self.last_heartbeat = time.time()
# 停止心跳线程
if self.heartbeat_thread:
self.heartbeat_thread = None
def _send_heartbeats(self):
"""发送心跳"""
while self.running and self.state == NodeState.LEADER:
self._send_append_entries()
time.sleep(self.heartbeat_interval)
def _send_append_entries(self):
"""发送AppendEntries RPC"""
for node_id in self.cluster_nodes:
if node_id == self.node_id:
continue
try:
# 准备AppendEntries请求
next_index = self.next_index.get(node_id, len(self.log))
prev_log_index = next_index - 1
prev_log_term = 0
if prev_log_index >= 0 and prev_log_index < len(self.log):
prev_log_term = self.log[prev_log_index].term
# 发送的日志条目
entries = self.log[next_index:] if next_index < len(self.log) else []
request = AppendEntriesRequest(
term=self.current_term,
leader_id=self.node_id,
prev_log_index=prev_log_index,
prev_log_term=prev_log_term,
entries=entries,
leader_commit=self.commit_index
)
response = self._send_append_entries_request(node_id, request)
if response:
if response.term > self.current_term:
# 发现更高的任期,转为跟随者
self._become_follower(response.term)
return
if response.success:
# 更新next_index和match_index
self.match_index[node_id] = prev_log_index + len(entries)
self.next_index[node_id] = self.match_index[node_id] + 1
else:
# 减少next_index并重试
self.next_index[node_id] = max(0, self.next_index[node_id] - 1)
except Exception as e:
self.logger.error(f"发送AppendEntries失败: {node_id}, 错误: {e}")
# 更新commit_index
self._update_commit_index()
def _update_commit_index(self):
"""更新commit_index"""
if self.state != NodeState.LEADER:
return
# 找到大多数节点都已复制的最大索引
for i in range(len(self.log) - 1, self.commit_index, -1):
if self.log[i].term == self.current_term:
count = 1 # Leader自己
for node_id in self.cluster_nodes:
if node_id != self.node_id and self.match_index.get(node_id, 0) >= i:
count += 1
majority = len(self.cluster_nodes) // 2 + 1
if count >= majority:
self.commit_index = i
self._apply_committed_entries()
break
def _apply_committed_entries(self):
"""应用已提交的日志条目到状态机"""
while self.last_applied < self.commit_index:
self.last_applied += 1
entry = self.log[self.last_applied]
# 应用到状态机
self._apply_to_state_machine(entry.command)
self.logger.info(f"应用日志条目: {self.last_applied}, 命令: {entry.command}")
def _apply_to_state_machine(self, command: Dict[str, Any]):
"""应用命令到状态机"""
operation = command.get('operation')
key = command.get('key')
value = command.get('value')
if operation == 'set':
self.state_machine[key] = value
elif operation == 'delete':
self.state_machine.pop(key, None)
# 调用回调函数
for callback in self.state_machine_callbacks:
try:
callback(command, self.state_machine)
except Exception as e:
self.logger.error(f"状态机回调异常: {e}")
def append_entry(self, command: Dict[str, Any]) -> bool:
"""追加日志条目(仅Leader可调用)"""
with self.lock:
if self.state != NodeState.LEADER:
return False
entry = LogEntry(
term=self.current_term,
index=len(self.log),
command=command
)
self.log.append(entry)
self.logger.info(f"追加日志条目: {entry.index}, 命令: {command}")
return True
def _send_vote_request(self, node_id: str, request: VoteRequest) -> Optional[VoteResponse]:
"""发送投票请求(模拟网络调用)"""
# 这里应该实际发送网络请求
# 简化实现,模拟网络调用
time.sleep(0.01) # 模拟网络延迟
# 模拟响应
return VoteResponse(
term=self.current_term,
vote_granted=random.random() > 0.3 # 70%概率同意投票
)
def _send_append_entries_request(self, node_id: str,
request: AppendEntriesRequest) -> Optional[AppendEntriesResponse]:
"""发送AppendEntries请求(模拟网络调用)"""
# 这里应该实际发送网络请求
# 简化实现,模拟网络调用
time.sleep(0.005) # 模拟网络延迟
# 模拟响应
return AppendEntriesResponse(
term=self.current_term,
success=random.random() > 0.1, # 90%概率成功
match_index=request.prev_log_index + len(request.entries)
)
def handle_vote_request(self, request: VoteRequest) -> VoteResponse:
"""处理投票请求"""
with self.lock:
# 如果请求的任期小于当前任期,拒绝投票
if request.term < self.current_term:
return VoteResponse(term=self.current_term, vote_granted=False)
# 如果请求的任期大于当前任期,更新任期并转为跟随者
if request.term > self.current_term:
self.current_term = request.term
self.voted_for = None
self._become_follower(request.term)
# 检查是否可以投票
vote_granted = False
if (self.voted_for is None or self.voted_for == request.candidate_id):
# 检查候选人的日志是否至少和自己一样新
last_log_index = len(self.log) - 1
last_log_term = self.log[last_log_index].term if self.log else 0
if (request.last_log_term > last_log_term or
(request.last_log_term == last_log_term and
request.last_log_index >= last_log_index)):
vote_granted = True
self.voted_for = request.candidate_id
self.last_heartbeat = time.time()
return VoteResponse(term=self.current_term, vote_granted=vote_granted)
def handle_append_entries(self, request: AppendEntriesRequest) -> AppendEntriesResponse:
"""处理AppendEntries请求"""
with self.lock:
# 如果请求的任期小于当前任期,拒绝
if request.term < self.current_term:
return AppendEntriesResponse(term=self.current_term, success=False)
# 更新任期并转为跟随者
if request.term >= self.current_term:
self.current_term = request.term
self._become_follower(request.term)
self.last_heartbeat = time.time()
# 检查前一个日志条目是否匹配
if (request.prev_log_index >= 0 and
(request.prev_log_index >= len(self.log) or
self.log[request.prev_log_index].term != request.prev_log_term)):
return AppendEntriesResponse(term=self.current_term, success=False)
# 删除冲突的日志条目
if request.entries:
# 找到第一个冲突的条目
conflict_index = request.prev_log_index + 1
for i, entry in enumerate(request.entries):
index = conflict_index + i
if index < len(self.log) and self.log[index].term != entry.term:
# 删除从这个位置开始的所有条目
self.log = self.log[:index]
break
# 追加新条目
for entry in request.entries:
if conflict_index >= len(self.log):
self.log.append(entry)
conflict_index += 1
# 更新commit_index
if request.leader_commit > self.commit_index:
self.commit_index = min(request.leader_commit, len(self.log) - 1)
self._apply_committed_entries()
return AppendEntriesResponse(
term=self.current_term,
success=True,
match_index=request.prev_log_index + len(request.entries)
)
def get_state(self) -> Dict[str, Any]:
"""获取节点状态"""
with self.lock:
return {
'node_id': self.node_id,
'state': self.state.value,
'current_term': self.current_term,
'voted_for': self.voted_for,
'log_length': len(self.log),
'commit_index': self.commit_index,
'last_applied': self.last_applied,
'state_machine': dict(self.state_machine)
}
def add_state_machine_callback(self, callback: Callable):
"""添加状态机回调函数"""
self.state_machine_callbacks.append(callback)
def main():
# 示例用法
cluster_nodes = ['node1', 'node2', 'node3']
# 创建Raft节点
nodes = []
for i, node_id in enumerate(cluster_nodes):
node = RaftNode(node_id, cluster_nodes, port=8000 + i)
nodes.append(node)
# 启动所有节点
for node in nodes:
node.start()
# 等待选举完成
time.sleep(2)
# 查找Leader
leader = None
for node in nodes:
state = node.get_state()
print(f"节点 {state['node_id']}: {state['state']}, 任期: {state['current_term']}")
if state['state'] == 'leader':
leader = node
if leader:
print(f"\nLeader: {leader.node_id}")
# 在Leader上追加一些日志条目
commands = [
{'operation': 'set', 'key': 'user:1', 'value': 'Alice'},
{'operation': 'set', 'key': 'user:2', 'value': 'Bob'},
{'operation': 'set', 'key': 'counter', 'value': 100},
{'operation': 'delete', 'key': 'user:1'}
]
for command in commands:
success = leader.append_entry(command)
print(f"追加命令: {command}, 结果: {success}")
time.sleep(0.5)
# 等待日志复制
time.sleep(2)
# 查看所有节点的状态机
print("\n所有节点的状态机:")
for node in nodes:
state = node.get_state()
print(f"节点 {state['node_id']}: {state['state_machine']}")
# 停止所有节点
for node in nodes:
node.stop()
if __name__ == "__main__":
main()
故障检测与恢复
1. 集群健康监控系统
#!/bin/bash
# scripts/cluster_health_monitor.sh
# 集群健康监控脚本
# 配置文件路径
CONFIG_FILE="${CONFIG_FILE:-/etc/distributed_db/cluster.conf}"
LOG_FILE="${LOG_FILE:-/var/log/cluster_health.log}"
ALERT_SCRIPT="${ALERT_SCRIPT:-/usr/local/bin/send_alert.sh}"
# 监控间隔(秒)
MONITOR_INTERVAL="${MONITOR_INTERVAL:-30}"
# 故障阈值
FAILURE_THRESHOLD="${FAILURE_THRESHOLD:-3}"
RESPONSE_TIMEOUT="${RESPONSE_TIMEOUT:-5}"
# 全局变量
declare -A NODE_FAILURE_COUNT
declare -A NODE_STATUS
declare -A SHARD_STATUS
# 日志函数
log() {
echo "[$(date '+%Y-%m-%d %H:%M:%S')] $1" | tee -a "$LOG_FILE"
}
error() {
echo "[$(date '+%Y-%m-%d %H:%M:%S')] ERROR: $1" | tee -a "$LOG_FILE" >&2
}
# 加载配置
load_config() {
if [[ ! -f "$CONFIG_FILE" ]]; then
error "配置文件不存在: $CONFIG_FILE"
exit 1
fi
source "$CONFIG_FILE"
log "加载配置文件: $CONFIG_FILE"
}
# 检查节点健康状态
check_node_health() {
local node_id="$1"
local host="$2"
local port="$3"
# TCP连接检查
if timeout "$RESPONSE_TIMEOUT" bash -c "</dev/tcp/$host/$port" 2>/dev/null; then
# 应用层健康检查
local health_response
health_response=$(curl -s --max-time "$RESPONSE_TIMEOUT" \
"http://$host:$((port + 1000))/health" 2>/dev/null)
if [[ "$health_response" == *"healthy"* ]]; then
NODE_STATUS["$node_id"]="healthy"
NODE_FAILURE_COUNT["$node_id"]=0
return 0
else
NODE_STATUS["$node_id"]="unhealthy"
((NODE_FAILURE_COUNT["$node_id"]++))
return 1
fi
else
NODE_STATUS["$node_id"]="unreachable"
((NODE_FAILURE_COUNT["$node_id"]++))
return 1
fi
}
# 检查分片状态
check_shard_health() {
local shard_id="$1"
local primary_node="$2"
local replica_nodes="$3"
local healthy_replicas=0
local total_replicas=0
# 检查主节点
if [[ "${NODE_STATUS[$primary_node]}" == "healthy" ]]; then
((healthy_replicas++))
fi
((total_replicas++))
# 检查副本节点
IFS=',' read -ra REPLICAS <<< "$replica_nodes"
for replica in "${REPLICAS[@]}"; do
if [[ "${NODE_STATUS[$replica]}" == "healthy" ]]; then
((healthy_replicas++))
fi
((total_replicas++))
done
# 计算健康比例
local health_ratio=$((healthy_replicas * 100 / total_replicas))
if [[ $health_ratio -ge 67 ]]; then # 至少2/3节点健康
SHARD_STATUS["$shard_id"]="healthy"
elif [[ $health_ratio -ge 50 ]]; then
SHARD_STATUS["$shard_id"]="degraded"
else
SHARD_STATUS["$shard_id"]="critical"
fi
log "分片 $shard_id 健康状态: ${SHARD_STATUS[$shard_id]} ($healthy_replicas/$total_replicas)"
}
# 故障恢复
handle_node_failure() {
local failed_node="$1"
log "处理节点故障: $failed_node"
# 发送告警
if [[ -x "$ALERT_SCRIPT" ]]; then
"$ALERT_SCRIPT" "NODE_FAILURE" "$failed_node" "节点 $failed_node 连续失败 ${NODE_FAILURE_COUNT[$failed_node]} 次"
fi
# 尝试自动恢复
attempt_node_recovery "$failed_node"
# 如果是主节点故障,触发故障转移
if is_primary_node "$failed_node"; then
trigger_failover "$failed_node"
fi
}
# 尝试节点恢复
attempt_node_recovery() {
local node_id="$1"
log "尝试恢复节点: $node_id"
# 重启节点服务
if systemctl is-active --quiet "distributed-db-$node_id"; then
systemctl restart "distributed-db-$node_id"
sleep 10
# 重新检查健康状态
local host port
get_node_info "$node_id" host port
if check_node_health "$node_id" "$host" "$port"; then
log "节点恢复成功: $node_id"
return 0
fi
fi
log "节点恢复失败: $node_id"
return 1
}
# 检查是否为主节点
is_primary_node() {
local node_id="$1"
# 从配置中查找该节点是否为某个分片的主节点
grep -q "primary.*$node_id" "$CONFIG_FILE"
}
# 触发故障转移
trigger_failover() {
local failed_primary="$1"
log "触发故障转移: $failed_primary"
# 查找受影响的分片
local affected_shards
    affected_shards=$(grep "primary.*$failed_primary" "$CONFIG_FILE" | \
        sed 's/.*shard_\([0-9]*\).*/\1/')
for shard_id in $affected_shards; do
log "为分片 $shard_id 执行故障转移"
# 选择新的主节点(选择第一个健康的副本)
local replica_nodes
replica_nodes=$(get_shard_replicas "$shard_id")
IFS=',' read -ra REPLICAS <<< "$replica_nodes"
for replica in "${REPLICAS[@]}"; do
if [[ "${NODE_STATUS[$replica]}" == "healthy" ]]; then
promote_to_primary "$replica" "$shard_id"
break
fi
done
done
}
# 提升副本为主节点
promote_to_primary() {
local new_primary="$1"
local shard_id="$2"
log "提升节点 $new_primary 为分片 $shard_id 的主节点"
# 调用管理API进行故障转移
local host port
get_node_info "$new_primary" host port
curl -X POST "http://$host:$((port + 1000))/admin/promote" \
-H "Content-Type: application/json" \
-d "{\"shard_id\": \"$shard_id\"}" \
--max-time 10 2>/dev/null
if [[ $? -eq 0 ]]; then
log "故障转移成功: 分片 $shard_id, 新主节点: $new_primary"
# 更新配置文件
update_config_primary "$shard_id" "$new_primary"
# 发送告警
if [[ -x "$ALERT_SCRIPT" ]]; then
"$ALERT_SCRIPT" "FAILOVER_SUCCESS" "$shard_id" \
"分片 $shard_id 故障转移成功,新主节点: $new_primary"
fi
else
error "故障转移失败: 分片 $shard_id, 目标节点: $new_primary"
fi
}
# 获取节点信息
get_node_info() {
local node_id="$1"
local -n host_ref="$2"
local -n port_ref="$3"
# 从配置文件解析节点信息
local node_line
node_line=$(grep "^$node_id=" "$CONFIG_FILE")
if [[ -n "$node_line" ]]; then
host_ref=$(echo "$node_line" | cut -d'=' -f2 | cut -d':' -f1)
port_ref=$(echo "$node_line" | cut -d'=' -f2 | cut -d':' -f2)
fi
}
# 获取分片副本节点
get_shard_replicas() {
local shard_id="$1"
grep "^shard_${shard_id}_replicas=" "$CONFIG_FILE" | cut -d'=' -f2
}
# 更新配置文件中的主节点
update_config_primary() {
local shard_id="$1"
local new_primary="$2"
sed -i "s/^shard_${shard_id}_primary=.*/shard_${shard_id}_primary=$new_primary/" "$CONFIG_FILE"
}
# 生成健康报告
generate_health_report() {
local report_file="/tmp/cluster_health_$(date +%Y%m%d_%H%M%S).json"
cat > "$report_file" << EOF
{
"timestamp": "$(date -Iseconds)",
"cluster_status": "$(get_cluster_status)",
"nodes": {
EOF
local first_node=true
for node_id in "${!NODE_STATUS[@]}"; do
if [[ "$first_node" == true ]]; then
first_node=false
else
echo "," >> "$report_file"
fi
cat >> "$report_file" << EOF
"$node_id": {
"status": "${NODE_STATUS[$node_id]}",
"failure_count": ${NODE_FAILURE_COUNT[$node_id]:-0}
}
EOF
done
cat >> "$report_file" << EOF
},
"shards": {
EOF
local first_shard=true
for shard_id in "${!SHARD_STATUS[@]}"; do
if [[ "$first_shard" == true ]]; then
first_shard=false
else
echo "," >> "$report_file"
fi
cat >> "$report_file" << EOF
"$shard_id": {
"status": "${SHARD_STATUS[$shard_id]}"
}
EOF
done
cat >> "$report_file" << EOF
}
}
EOF
echo "$report_file"
}
# 获取集群整体状态
get_cluster_status() {
local healthy_nodes=0
local total_nodes=0
local critical_shards=0
for status in "${NODE_STATUS[@]}"; do
if [[ "$status" == "healthy" ]]; then
((healthy_nodes++))
fi
((total_nodes++))
done
for status in "${SHARD_STATUS[@]}"; do
if [[ "$status" == "critical" ]]; then
((critical_shards++))
fi
done
if [[ $critical_shards -gt 0 ]]; then
echo "critical"
elif [[ $((healthy_nodes * 100 / total_nodes)) -lt 80 ]]; then
echo "degraded"
else
echo "healthy"
fi
}
# 主监控循环
main_monitor_loop() {
log "启动集群健康监控"
while true; do
log "开始健康检查循环"
# 检查所有节点
while IFS='=' read -r node_id node_addr; do
if [[ -n "$node_id" && "$node_id" != \#* ]]; then
local host port
host=$(echo "$node_addr" | cut -d':' -f1)
port=$(echo "$node_addr" | cut -d':' -f2)
if ! check_node_health "$node_id" "$host" "$port"; then
if [[ ${NODE_FAILURE_COUNT[$node_id]} -ge $FAILURE_THRESHOLD ]]; then
handle_node_failure "$node_id"
fi
fi
fi
done < <(grep "^[^#].*=" "$CONFIG_FILE" | grep -v "shard_")
# 检查所有分片
while IFS='=' read -r shard_key shard_primary; do
if [[ "$shard_key" == shard_*_primary ]]; then
local shard_id
shard_id=$(echo "$shard_key" | sed 's/shard_\(.*\)_primary/\1/')
local replica_nodes
replica_nodes=$(get_shard_replicas "$shard_id")
check_shard_health "$shard_id" "$shard_primary" "$replica_nodes"
fi
done < <(grep "^shard_.*_primary=" "$CONFIG_FILE")
# 生成健康报告
local report_file
report_file=$(generate_health_report)
log "健康报告已生成: $report_file"
# 等待下次检查
sleep "$MONITOR_INTERVAL"
done
}
# 信号处理
cleanup() {
log "收到退出信号,正在清理..."
exit 0
}
trap cleanup SIGTERM SIGINT
# 主函数
main() {
# 检查依赖
for cmd in curl timeout systemctl; do
if ! command -v "$cmd" &> /dev/null; then
error "缺少依赖命令: $cmd"
exit 1
fi
done
# 加载配置
load_config
# 创建日志目录
mkdir -p "$(dirname "$LOG_FILE")"
# 启动监控
main_monitor_loop
}
# 如果直接执行脚本
if [[ "${BASH_SOURCE[0]}" == "${0}" ]]; then
main "$@"
fi
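上面的脚本假定每个节点在「数据端口 + 1000」上提供管理接口:/health 返回包含 healthy 字样的响应,/admin/promote 接收提升为主节点的请求。下面是一个满足该约定的最小 Python 端点草图(仅为示意,生产环境中这些接口应由数据库节点进程自身提供):

```python
# 最小管理端点草图:与监控脚本的 /health、/admin/promote 约定相匹配(仅为示意)
import json
from http.server import BaseHTTPRequestHandler, HTTPServer

class AdminHandler(BaseHTTPRequestHandler):
    def _reply(self, code: int, body: bytes):
        self.send_response(code)
        self.send_header("Content-Length", str(len(body)))
        self.end_headers()
        self.wfile.write(body)

    def do_GET(self):
        if self.path == "/health":
            # 监控脚本通过响应中是否包含 "healthy" 判断节点健康
            self._reply(200, b"healthy")
        else:
            self._reply(404, b"not found")

    def do_POST(self):
        if self.path == "/admin/promote":
            length = int(self.headers.get("Content-Length", 0))
            payload = json.loads(self.rfile.read(length) or b"{}")
            # 真实实现中此处应切换复制角色,这里只回显分片 ID
            self._reply(200, json.dumps({"promoted_shard": payload.get("shard_id")}).encode())
        else:
            self._reply(404, b"not found")

if __name__ == "__main__":
    # 数据端口 3306 对应管理端口 4306,即脚本中的 port + 1000
    HTTPServer(("0.0.0.0", 4306), AdminHandler).serve_forever()
```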
性能优化与调优
1. 查询优化器
#!/usr/bin/env python3
# src/optimizer/query_optimizer.py
import re
import time
import logging
from typing import Dict, List, Any, Optional, Tuple
from dataclasses import dataclass
from enum import Enum
import sqlparse
from sqlparse.sql import Statement, Token
from sqlparse.tokens import Keyword, Name
class QueryType(Enum):
SELECT = "select"
INSERT = "insert"
UPDATE = "update"
DELETE = "delete"
JOIN = "join"
@dataclass
class QueryPlan:
query_id: str
original_query: str
optimized_query: str
execution_plan: List[Dict[str, Any]]
estimated_cost: float
target_shards: List[str]
optimization_hints: List[str]
@dataclass
class TableStats:
table_name: str
row_count: int
avg_row_size: int
index_info: Dict[str, Any]
partition_info: Dict[str, Any]
last_updated: float
class DistributedQueryOptimizer:
def __init__(self, shard_router, stats_collector):
self.shard_router = shard_router
self.stats_collector = stats_collector
self.logger = self._setup_logging()
# 查询缓存
self.query_cache: Dict[str, QueryPlan] = {}
self.cache_size_limit = 1000
# 统计信息缓存
self.table_stats: Dict[str, TableStats] = {}
self.stats_ttl = 3600 # 1小时
# 优化规则
self.optimization_rules = [
self._optimize_predicate_pushdown,
self._optimize_join_order,
self._optimize_index_selection,
self._optimize_partition_pruning,
self._optimize_aggregation_pushdown
]
def _setup_logging(self) -> logging.Logger:
"""设置日志记录"""
logger = logging.getLogger('QueryOptimizer')
logger.setLevel(logging.INFO)
handler = logging.StreamHandler()
formatter = logging.Formatter(
'%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
handler.setFormatter(formatter)
logger.addHandler(handler)
return logger
def optimize_query(self, query: str, query_params: Dict[str, Any] = None) -> QueryPlan:
"""优化查询"""
query_id = self._generate_query_id(query, query_params)
# 检查缓存
if query_id in self.query_cache:
self.logger.info(f"使用缓存的查询计划: {query_id}")
return self.query_cache[query_id]
# 解析查询
parsed_query = self._parse_query(query)
if not parsed_query:
return self._create_fallback_plan(query_id, query)
# 分析查询
query_analysis = self._analyze_query(parsed_query, query_params)
# 应用优化规则
optimized_query = query
optimization_hints = []
for rule in self.optimization_rules:
try:
result = rule(parsed_query, query_analysis, query_params)
if result:
optimized_query = result.get('query', optimized_query)
optimization_hints.extend(result.get('hints', []))
except Exception as e:
self.logger.error(f"优化规则执行失败: {rule.__name__}, 错误: {e}")
# 生成执行计划
execution_plan = self._generate_execution_plan(optimized_query, query_analysis)
# 估算成本
estimated_cost = self._estimate_query_cost(execution_plan, query_analysis)
# 确定目标分片
target_shards = self._determine_target_shards(query_analysis, query_params)
# 创建查询计划
query_plan = QueryPlan(
query_id=query_id,
original_query=query,
optimized_query=optimized_query,
execution_plan=execution_plan,
estimated_cost=estimated_cost,
target_shards=target_shards,
optimization_hints=optimization_hints
)
# 缓存查询计划
self._cache_query_plan(query_plan)
return query_plan
def _parse_query(self, query: str) -> Optional[Statement]:
"""解析SQL查询"""
try:
parsed = sqlparse.parse(query)
return parsed[0] if parsed else None
except Exception as e:
self.logger.error(f"查询解析失败: {e}")
return None
def _analyze_query(self, parsed_query: Statement,
query_params: Dict[str, Any] = None) -> Dict[str, Any]:
"""分析查询结构"""
analysis = {
'query_type': self._get_query_type(parsed_query),
'tables': self._extract_tables(parsed_query),
'columns': self._extract_columns(parsed_query),
'where_conditions': self._extract_where_conditions(parsed_query),
'join_conditions': self._extract_join_conditions(parsed_query),
'group_by': self._extract_group_by(parsed_query),
'order_by': self._extract_order_by(parsed_query),
'limit': self._extract_limit(parsed_query),
'aggregations': self._extract_aggregations(parsed_query)
}
return analysis
def _get_query_type(self, parsed_query: Statement) -> QueryType:
"""获取查询类型"""
first_token = parsed_query.token_first(skip_ws=True, skip_cm=True)
        # sqlparse 中 SELECT/INSERT 等的 ttype 为 Keyword.DML,用 in 判断可同时匹配 Keyword 及其子类型
        if first_token and first_token.ttype in Keyword:
keyword = first_token.value.upper()
if keyword == 'SELECT':
return QueryType.SELECT
elif keyword == 'INSERT':
return QueryType.INSERT
elif keyword == 'UPDATE':
return QueryType.UPDATE
elif keyword == 'DELETE':
return QueryType.DELETE
return QueryType.SELECT
def _extract_tables(self, parsed_query: Statement) -> List[str]:
"""提取表名"""
tables = []
# 简化实现,实际需要更复杂的解析逻辑
query_str = str(parsed_query).upper()
# 查找FROM子句中的表名
from_match = re.search(r'FROM\s+(\w+)', query_str)
if from_match:
tables.append(from_match.group(1).lower())
# 查找JOIN子句中的表名
join_matches = re.findall(r'JOIN\s+(\w+)', query_str)
for match in join_matches:
tables.append(match.lower())
return list(set(tables))
def _extract_columns(self, parsed_query: Statement) -> List[str]:
"""提取列名"""
columns = []
# 简化实现
query_str = str(parsed_query)
# 提取SELECT子句中的列名
select_match = re.search(r'SELECT\s+(.*?)\s+FROM', query_str, re.IGNORECASE | re.DOTALL)
if select_match:
select_clause = select_match.group(1)
# 简单分割,实际需要更复杂的解析
for col in select_clause.split(','):
col = col.strip()
if col and col != '*':
columns.append(col)
return columns
def _extract_where_conditions(self, parsed_query: Statement) -> List[Dict[str, Any]]:
"""提取WHERE条件"""
conditions = []
query_str = str(parsed_query)
where_match = re.search(r'WHERE\s+(.*?)(?:\s+GROUP\s+BY|\s+ORDER\s+BY|\s+LIMIT|$)',
query_str, re.IGNORECASE | re.DOTALL)
if where_match:
where_clause = where_match.group(1).strip()
# 简化解析,实际需要更复杂的逻辑
# 查找等值条件
eq_conditions = re.findall(r'(\w+)\s*=\s*[\'"]?([^\'"\s]+)[\'"]?', where_clause)
for column, value in eq_conditions:
conditions.append({
'column': column,
'operator': '=',
'value': value,
'type': 'equality'
})
# 查找范围条件
range_conditions = re.findall(r'(\w+)\s*(>|<|>=|<=)\s*[\'"]?([^\'"\s]+)[\'"]?', where_clause)
for column, operator, value in range_conditions:
conditions.append({
'column': column,
'operator': operator,
'value': value,
'type': 'range'
})
return conditions
def _extract_join_conditions(self, parsed_query: Statement) -> List[Dict[str, Any]]:
"""提取JOIN条件"""
joins = []
query_str = str(parsed_query)
join_matches = re.findall(r'(\w+\s+)?JOIN\s+(\w+)\s+ON\s+(.*?)(?:\s+(?:INNER|LEFT|RIGHT|FULL)\s+JOIN|\s+WHERE|\s+GROUP\s+BY|\s+ORDER\s+BY|$)',
query_str, re.IGNORECASE)
for join_type, table, condition in join_matches:
joins.append({
'type': join_type.strip() if join_type else 'INNER',
'table': table,
'condition': condition.strip()
})
return joins
def _extract_group_by(self, parsed_query: Statement) -> List[str]:
"""提取GROUP BY列"""
query_str = str(parsed_query)
group_match = re.search(r'GROUP\s+BY\s+(.*?)(?:\s+ORDER\s+BY|\s+LIMIT|$)',
query_str, re.IGNORECASE)
if group_match:
group_clause = group_match.group(1).strip()
return [col.strip() for col in group_clause.split(',')]
return []
def _extract_order_by(self, parsed_query: Statement) -> List[Dict[str, str]]:
"""提取ORDER BY列"""
query_str = str(parsed_query)
order_match = re.search(r'ORDER\s+BY\s+(.*?)(?:\s+LIMIT|$)',
query_str, re.IGNORECASE)
order_columns = []
if order_match:
order_clause = order_match.group(1).strip()
for item in order_clause.split(','):
item = item.strip()
if ' DESC' in item.upper():
column = item.replace(' DESC', '').replace(' desc', '').strip()
order_columns.append({'column': column, 'direction': 'DESC'})
else:
column = item.replace(' ASC', '').replace(' asc', '').strip()
order_columns.append({'column': column, 'direction': 'ASC'})
return order_columns
def _extract_limit(self, parsed_query: Statement) -> Optional[int]:
"""提取LIMIT值"""
query_str = str(parsed_query)
limit_match = re.search(r'LIMIT\s+(\d+)', query_str, re.IGNORECASE)
if limit_match:
return int(limit_match.group(1))
return None
def _extract_aggregations(self, parsed_query: Statement) -> List[Dict[str, str]]:
"""提取聚合函数"""
aggregations = []
query_str = str(parsed_query)
# 查找聚合函数
agg_matches = re.findall(r'(COUNT|SUM|AVG|MIN|MAX)\s*\(\s*([^)]+)\s*\)',
query_str, re.IGNORECASE)
for func, column in agg_matches:
aggregations.append({
'function': func.upper(),
'column': column.strip()
})
return aggregations
def _optimize_predicate_pushdown(self, parsed_query: Statement,
analysis: Dict[str, Any],
query_params: Dict[str, Any] = None) -> Optional[Dict[str, Any]]:
"""谓词下推优化"""
hints = []
# 检查WHERE条件是否可以下推到分片
where_conditions = analysis.get('where_conditions', [])
tables = analysis.get('tables', [])
for condition in where_conditions:
column = condition['column']
# 检查是否为分片键
for table in tables:
if table in self.shard_router.shard_keys:
shard_keys = self.shard_router.shard_keys[table]
if column in shard_keys:
hints.append(f"谓词下推: {column} 条件可以下推到分片级别")
return {'hints': hints} if hints else None
def _optimize_join_order(self, parsed_query: Statement,
analysis: Dict[str, Any],
query_params: Dict[str, Any] = None) -> Optional[Dict[str, Any]]:
"""JOIN顺序优化"""
hints = []
joins = analysis.get('join_conditions', [])
if len(joins) > 1:
# 获取表统计信息
tables = analysis.get('tables', [])
table_sizes = {}
for table in tables:
stats = self._get_table_stats(table)
if stats:
table_sizes[table] = stats.row_count
# 建议小表在前的JOIN顺序
if table_sizes:
sorted_tables = sorted(table_sizes.items(), key=lambda x: x[1])
hints.append(f"建议JOIN顺序: {' -> '.join([t[0] for t in sorted_tables])}")
return {'hints': hints} if hints else None
def _optimize_index_selection(self, parsed_query: Statement,
analysis: Dict[str, Any],
query_params: Dict[str, Any] = None) -> Optional[Dict[str, Any]]:
"""索引选择优化"""
hints = []
where_conditions = analysis.get('where_conditions', [])
tables = analysis.get('tables', [])
for table in tables:
stats = self._get_table_stats(table)
if not stats:
continue
# 检查WHERE条件中的列是否有索引
for condition in where_conditions:
column = condition['column']
if column in stats.index_info:
index_info = stats.index_info[column]
hints.append(f"使用索引: {table}.{column} ({index_info.get('type', 'unknown')})")
else:
hints.append(f"建议创建索引: {table}.{column}")
return {'hints': hints} if hints else None
def _optimize_partition_pruning(self, parsed_query: Statement,
analysis: Dict[str, Any],
query_params: Dict[str, Any] = None) -> Optional[Dict[str, Any]]:
"""分区裁剪优化"""
hints = []
where_conditions = analysis.get('where_conditions', [])
tables = analysis.get('tables', [])
for table in tables:
stats = self._get_table_stats(table)
if not stats or not stats.partition_info:
continue
partition_column = stats.partition_info.get('column')
if not partition_column:
continue
# 检查WHERE条件中是否包含分区列
for condition in where_conditions:
if condition['column'] == partition_column:
hints.append(f"分区裁剪: 基于 {partition_column} 条件可以裁剪分区")
return {'hints': hints} if hints else None
def _optimize_aggregation_pushdown(self, parsed_query: Statement,
analysis: Dict[str, Any],
query_params: Dict[str, Any] = None) -> Optional[Dict[str, Any]]:
"""聚合下推优化"""
hints = []
aggregations = analysis.get('aggregations', [])
group_by = analysis.get('group_by', [])
if aggregations:
# 检查是否可以将聚合下推到分片
pushdown_possible = True
for agg in aggregations:
func = agg['function']
if func not in ['COUNT', 'SUM', 'MIN', 'MAX']:
pushdown_possible = False
break
if pushdown_possible:
hints.append("聚合下推: 可以将聚合操作下推到各个分片并在协调节点合并结果")
return {'hints': hints} if hints else None
def _generate_execution_plan(self, query: str, analysis: Dict[str, Any]) -> List[Dict[str, Any]]:
"""生成执行计划"""
plan = []
query_type = analysis.get('query_type', QueryType.SELECT)
tables = analysis.get('tables', [])
joins = analysis.get('join_conditions', [])
# 表扫描步骤
for table in tables:
step = {
'step_type': 'table_scan',
'table': table,
'estimated_rows': self._estimate_table_rows(table, analysis),
'cost': 1.0
}
plan.append(step)
# JOIN步骤
for join in joins:
step = {
'step_type': 'join',
'join_type': join['type'],
'table': join['table'],
'condition': join['condition'],
'estimated_rows': 1000, # 简化估算
'cost': 2.0
}
plan.append(step)
# 聚合步骤
if analysis.get('aggregations') or analysis.get('group_by'):
step = {
'step_type': 'aggregation',
'estimated_rows': 100, # 简化估算
'cost': 1.5
}
plan.append(step)
# 排序步骤
if analysis.get('order_by'):
step = {
'step_type': 'sort',
'columns': analysis['order_by'],
'estimated_rows': 1000, # 简化估算
'cost': 2.5
}
plan.append(step)
return plan
def _estimate_query_cost(self, execution_plan: List[Dict[str, Any]],
analysis: Dict[str, Any]) -> float:
"""估算查询成本"""
total_cost = 0.0
for step in execution_plan:
step_cost = step.get('cost', 1.0)
estimated_rows = step.get('estimated_rows', 1000)
# 基于行数调整成本
row_factor = min(estimated_rows / 1000.0, 10.0) # 最大10倍
total_cost += step_cost * row_factor
return total_cost
def _determine_target_shards(self, analysis: Dict[str, Any],
query_params: Dict[str, Any] = None) -> List[str]:
"""确定目标分片"""
tables = analysis.get('tables', [])
where_conditions = analysis.get('where_conditions', [])
target_shards = set()
for table in tables:
# 构建查询参数
table_params = {}
for condition in where_conditions:
table_params[condition['column']] = condition['value']
# 如果有查询参数,合并
if query_params:
table_params.update(query_params)
# 路由到分片
shards = self.shard_router.route_query(table, table_params)
target_shards.update(shards)
return list(target_shards)
def _get_table_stats(self, table_name: str) -> Optional[TableStats]:
"""获取表统计信息"""
if table_name in self.table_stats:
stats = self.table_stats[table_name]
if time.time() - stats.last_updated < self.stats_ttl:
return stats
# 从统计收集器获取最新统计信息
try:
stats_data = self.stats_collector.get_table_stats(table_name)
if stats_data:
stats = TableStats(
table_name=table_name,
row_count=stats_data.get('row_count', 0),
avg_row_size=stats_data.get('avg_row_size', 0),
index_info=stats_data.get('index_info', {}),
partition_info=stats_data.get('partition_info', {}),
last_updated=time.time()
)
self.table_stats[table_name] = stats
return stats
except Exception as e:
self.logger.error(f"获取表统计信息失败: {table_name}, 错误: {e}")
return None
def _estimate_table_rows(self, table_name: str, analysis: Dict[str, Any]) -> int:
"""估算表行数"""
stats = self._get_table_stats(table_name)
if stats:
base_rows = stats.row_count
# 根据WHERE条件调整估算
where_conditions = analysis.get('where_conditions', [])
selectivity = 1.0
for condition in where_conditions:
if condition['type'] == 'equality':
selectivity *= 0.1 # 等值条件选择性10%
elif condition['type'] == 'range':
selectivity *= 0.3 # 范围条件选择性30%
return int(base_rows * selectivity)
return 1000 # 默认估算
def _generate_query_id(self, query: str, query_params: Dict[str, Any] = None) -> str:
"""生成查询ID"""
import hashlib
content = query
if query_params:
content += str(sorted(query_params.items()))
return hashlib.md5(content.encode()).hexdigest()[:16]
def _create_fallback_plan(self, query_id: str, query: str) -> QueryPlan:
"""创建回退查询计划"""
return QueryPlan(
query_id=query_id,
original_query=query,
optimized_query=query,
execution_plan=[{
'step_type': 'fallback',
'estimated_rows': 1000,
'cost': 10.0
}],
estimated_cost=10.0,
target_shards=[],
optimization_hints=['查询解析失败,使用回退计划']
)
def _cache_query_plan(self, query_plan: QueryPlan):
"""缓存查询计划"""
if len(self.query_cache) >= self.cache_size_limit:
# 移除最旧的缓存项
oldest_key = next(iter(self.query_cache))
del self.query_cache[oldest_key]
self.query_cache[query_plan.query_id] = query_plan
def get_cache_stats(self) -> Dict[str, Any]:
"""获取缓存统计信息"""
return {
'cache_size': len(self.query_cache),
'cache_limit': self.cache_size_limit,
'hit_rate': 0.85 # 简化实现
}
def main():
# 示例用法
from sharding.shard_router import ShardRouter
class MockStatsCollector:
def get_table_stats(self, table_name):
return {
'row_count': 100000,
'avg_row_size': 256,
'index_info': {
'user_id': {'type': 'btree'},
'email': {'type': 'hash'}
},
'partition_info': {
'column': 'created_date',
'type': 'range'
}
}
# 创建优化器
router = ShardRouter('config/distributed_database.yaml')
stats_collector = MockStatsCollector()
optimizer = DistributedQueryOptimizer(router, stats_collector)
# 测试查询优化
test_queries = [
"SELECT * FROM users WHERE user_id = 12345",
"SELECT u.name, o.total FROM users u JOIN orders o ON u.user_id = o.user_id WHERE u.user_id = 12345",
"SELECT category_id, COUNT(*) FROM products GROUP BY category_id ORDER BY COUNT(*) DESC LIMIT 10"
]
for query in test_queries:
print(f"\n原始查询: {query}")
plan = optimizer.optimize_query(query)
print(f"查询ID: {plan.query_id}")
print(f"优化后查询: {plan.optimized_query}")
print(f"目标分片: {plan.target_shards}")
print(f"估算成本: {plan.estimated_cost:.2f}")
print("优化建议:")
for hint in plan.optimization_hints:
print(f" - {hint}")
print("执行计划:")
for i, step in enumerate(plan.execution_plan):
print(f" {i+1}. {step['step_type']}: 估算行数={step['estimated_rows']}, 成本={step['cost']}")
if __name__ == "__main__":
main()
总结
分布式数据库架构设计是一个复杂的系统工程,需要综合考虑数据分片、一致性保证、故障恢复、性能优化等多个方面。本文提供的完整方案包括:
核心架构要素
- 数据分片策略
  - 哈希分片:适用于均匀分布的数据
  - 范围分片:适用于有序数据和范围查询
  - 一致性哈希:提供良好的扩展性
  - 目录分片:灵活的分片映射管理
- 一致性协议
  - Raft共识算法:保证强一致性
  - 两阶段提交:分布式事务协调
  - 最终一致性:高可用性场景
- 故障检测与恢复
  - 健康监控:实时检测节点状态
  - 自动故障转移:最小化服务中断
  - 数据恢复:保证数据完整性
- 性能优化
  - 查询优化器:智能查询路由和优化
  - 缓存策略:减少网络开销
  - 负载均衡:均匀分布查询负载
最佳实践
- 架构设计
  - 合理选择分片策略
  - 设计容错机制
  - 考虑扩展性需求
- 运维管理
  - 完善的监控体系
  - 自动化运维工具
  - 定期性能调优
- 数据管理
  - 合理的数据分布
  - 及时的统计信息更新
  - 有效的缓存策略
通过本文提供的架构设计和实现方案,可以构建一个高可用、高性能、可扩展的分布式数据库系统,满足现代应用的复杂需求。