import numpy as np
import pandas as pd
import pandapower as pp
from pandapower.networks import case33bw
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.optim as optim
from collections import deque
import random
import os

# Configuration parameters shared by the network model, the agent and the
# training loop in this script.
CONFIG = {
    # Network parameters
    'BASE_KV': 12.66,  # base voltage (kV)
    'BASE_MVA': 10.0,  # base power (MVA)
    'V_MAX': 1.05,     # maximum allowed voltage (p.u.)
    'V_MIN': 0.95,     # minimum allowed voltage (p.u.)

    # PV parameters
    'PV_CAPACITY': 0.5,      # rated PV capacity (MW)
    'PV_NODES': [6, 12, 18, 22, 25, 32],  # PV node numbers (1-based; code subtracts 1 to get the bus index)

    # SVC parameters
    'SVC_CAPACITY': 0.8,    # SVC capacity (MVar)
    'SVC_NODES': [10, 30],  # SVC node numbers (1-based; code subtracts 1 to get the bus index)
    'SVC_MIN_MVAR': -0.8,   # minimum SVC reactive output (MVar)
    'SVC_MAX_MVAR': 0.8,    # maximum SVC reactive output (MVar)

    # Training parameters
    'EPISODES': 1000,   # number of training episodes
    'T_MAX': 24,        # time steps per episode (hours)
                        # NOTE(review): the training loop in this file uses a literal
                        # 48 half-hour steps and does not read this key.
    'BATCH_SIZE': 64,   # mini-batch size
    'MEMORY_SIZE': 100000,  # replay buffer capacity
    'GAMMA': 0.99,      # discount factor
    'LR': 1e-4,         # learning rate (used for both actor and critic optimizers)
    'TAU': 0.001,       # soft-update coefficient
    'NOISE_SCALE': 0.1, # exploration noise scale

    # Checkpoint and logging parameters
    'SAVE_INTERVAL': 2,  # save the model every this many episodes
    'LOG_INTERVAL': 100,    # log print interval
                            # NOTE(review): not referenced by any code visible in this file.
    'SAVE_DIR': 'saved_models_final'  # directory for saved models and figures
}

# Use the CONFIG dict as the active configuration
config = CONFIG

# Configure matplotlib for Chinese text: use the SimHei font and keep the
# minus sign rendering correctly with that font.
plt.rcParams['font.sans-serif'] = ['SimHei']
plt.rcParams['axes.unicode_minus'] = False

# Initialise the grid model (pandapower case33bw test network)
net = case33bw()

# Make sure all bus indices / bus references are plain integers
net.bus.index = net.bus.index.astype(int)
net.load.bus = net.load.bus.astype(int)
if len(net.gen) > 0:
    net.gen.bus = net.gen.bus.astype(int)
net.ext_grid.bus = net.ext_grid.bus.astype(int)

# Create the tap-changing transformer between bus 0 and bus 1
pp.create_transformer_from_parameters(net,
    hv_bus=0,  # high-voltage side connected to bus 0
    lv_bus=1,  # low-voltage side connected to bus 1
    sn_mva=10.0,  # rated power
    vn_hv_kv=12.66,  # rated voltage, high-voltage side
    vn_lv_kv=12.66,  # rated voltage, low-voltage side
    vk_percent=6.0,  # short-circuit impedance
    vkr_percent=1.0,  # short-circuit loss
    pfe_kw=10.0,  # iron loss
    i0_percent=0.1,  # no-load current
    tap_neutral=0,  # neutral tap position
    tap_min=-2,  # lowest tap position
    tap_max=2,  # highest tap position
    tap_step_percent=2.5,  # voltage change per tap step (%)
    tap_side='hv',  # tap changer is on the high-voltage side
    tap_phase_shifter=False,  # not a phase-shifting transformer
    name='Trafo1',
    tap_pos=0,  # initial tap position
    in_service=True  # in service
)

# Add PV units at the configured nodes
pv_nodes = config['PV_NODES']
for pv_bus in pv_nodes:
    bus_idx = int(pv_bus - 1)  # 1-based node number -> 0-based integer bus index
    pp.create_sgen(
        net,
        bus=bus_idx,
        p_mw=float(config['PV_CAPACITY']),  # make sure the power is a float
        q_mvar=0.0,
        name=f'PV_{pv_bus}',
        type='PV',
        in_service=True,
        scaling=1.0
    )

# Add SVCs at the configured nodes (modelled as static generators).
# They are created after the PV units, so they follow them in net.sgen.
svc_nodes = config['SVC_NODES']
for svc_bus in svc_nodes:
    bus_idx = int(svc_bus - 1)  # 1-based node number -> 0-based integer bus index
    pp.create_sgen(
        net,
        bus=bus_idx,
        p_mw=0.0,
        q_mvar=0.0,  # initial reactive power is zero
        name=f'SVC_{svc_bus}',
        type='SVC',
        min_q_mvar=float(config['SVC_MIN_MVAR']),  # lower reactive-power limit
        max_q_mvar=float(config['SVC_MAX_MVAR']),   # upper reactive-power limit
        controllable=True,
        in_service=True,
        scaling=1.0
    )

# Performance tracker
class PerformanceTracker:
    """Accumulate per-episode training metrics and render them as PNG charts.

    Every ``add_episode_data`` call appends exactly one value to each metric
    list, so all lists have the same length and share the episode ordering.
    """

    def __init__(self):
        self.episode_rewards = []
        self.voltage_violations = []
        self.avg_voltages = []
        self.pv_curtailments = []
        self.svc_adjustments = []
        self.network_losses = []  # network loss recorded per episode

    def add_episode_data(self, episode_reward, voltage_violations, avg_voltage,
                        pv_curtailments, svc_adjustments, network_loss=0.0):
        """Append the metrics of one finished episode.

        Args:
            episode_reward: Total reward collected in the episode.
            voltage_violations: Count of voltage-limit violations.
            avg_voltage: Mean bus voltage reported for the episode.
            pv_curtailments: Number of PV set-point changes.
            svc_adjustments: Number of SVC set-point changes.
            network_loss: Network loss for the episode (defaults to 0.0).
        """
        self.episode_rewards.append(episode_reward)
        self.voltage_violations.append(voltage_violations)
        self.avg_voltages.append(avg_voltage)
        self.pv_curtailments.append(pv_curtailments)
        self.svc_adjustments.append(svc_adjustments)
        self.network_losses.append(network_loss)

    @staticmethod
    def _ensure_dir(directory):
        """Create ``directory`` if it is non-empty and does not exist yet."""
        if directory:
            os.makedirs(directory, exist_ok=True)

    @staticmethod
    def _save_line_chart(values, fmt, label, title, ylabel, out_path, reference=None):
        """Draw one metric-vs-episode line chart and save it as a PNG.

        ``reference`` is an optional ``(y, label)`` pair drawn as a dashed red
        horizontal line. The figure is closed even if drawing or saving fails.
        """
        plt.figure(figsize=(10, 6))
        try:
            plt.plot(values, fmt, label=label, linewidth=2)
            if reference is not None:
                ref_y, ref_label = reference
                plt.axhline(y=ref_y, color='r', linestyle='--', label=ref_label)
            plt.title(title, fontsize=14)
            plt.xlabel('训练回合', fontsize=12)
            plt.ylabel(ylabel, fontsize=12)
            plt.grid(True, linestyle='--', alpha=0.7)
            plt.legend(fontsize=10)
            plt.savefig(out_path, dpi=600, bbox_inches='tight', format='png')
        finally:
            plt.close()

    def plot_performance(self, save_path=None):
        """Save one PNG per tracked metric.

        Files are named ``<base>_<metric>.png`` where ``<base>`` is the file
        name of ``save_path`` without extension, placed in the directory part
        of ``save_path``. Without ``save_path`` the files go to the current
        directory with base name ``performance``. A missing output directory
        is created.
        """
        save_dir = os.path.dirname(save_path) if save_path else '.'
        base_name = os.path.splitext(os.path.basename(save_path))[0] if save_path else 'performance'
        self._ensure_dir(save_dir)

        # (values, line format, legend label, title, y-label, file suffix, reference line)
        charts = (
            (self.episode_rewards, 'b-', '奖励值', '累积奖励', '奖励值',
             'rewards', None),
            (self.voltage_violations, 'r-', '越限次数', '电压越限统计', '越限次数',
             'violations', None),
            (self.avg_voltages, 'g-', '平均电压', '母线平均电压', '电压 (标幺值)',
             'voltages', (1.0, '标称电压')),
            (self.pv_curtailments, 'm-', '调节次数', '光伏功率调节统计', '调节次数',
             'pv_curtailments', None),
            (self.svc_adjustments, 'c-', '调节次数', 'SVC无功调节统计', '调节次数',
             'svc_adjustments', None),
            (self.network_losses, 'orange', '网络损耗', '配电网网损统计', '网损 (MW)',
             'network_losses', None),
        )
        for values, fmt, label, title, ylabel, suffix, reference in charts:
            out_path = os.path.join(save_dir, f'{base_name}_{suffix}.png')
            self._save_line_chart(values, fmt, label, title, ylabel, out_path, reference)

    def plot_voltage_profile(self, net, save_path=None):
        """Plot the bus-voltage profile stored in ``net.res_bus.vm_pu``.

        Buses are numbered from 1 on the x axis. The voltage limits from
        ``config`` and the nominal 1.0 p.u. level are drawn as reference
        lines, and the configured PV / SVC nodes are highlighted. When
        ``save_path`` is given the figure is written next to it with a
        ``.png`` extension (the directory is created if missing).
        """
        voltages = net.res_bus.vm_pu.values
        bus_indices = range(1, len(voltages) + 1)

        plt.figure(figsize=(12, 6))
        try:
            plt.plot(bus_indices, voltages, 'b-', label='节点电压', linewidth=2)
            plt.plot(bus_indices, voltages, 'bo', markersize=6)

            # Upper / lower limit and nominal-voltage reference lines
            plt.axhline(y=config['V_MAX'], color='r', linestyle='--', label='电压上限')
            plt.axhline(y=config['V_MIN'], color='r', linestyle='--', label='电压下限')
            plt.axhline(y=1.0, color='g', linestyle='--', label='标称电压')

            # Highlight PV and SVC nodes; one plot call per group yields a
            # single legend entry for the whole group.
            marker_groups = (
                (config['PV_NODES'], 'r^', '光伏节点'),
                (config['SVC_NODES'], 'ys', 'SVC节点'),
            )
            for nodes, fmt, label in marker_groups:
                if nodes:
                    # Node numbers are 1-based, the voltage array is 0-based.
                    node_voltages = [voltages[node - 1] for node in nodes]
                    plt.plot(nodes, node_voltages, fmt, markersize=12, label=label)

            plt.grid(True, linestyle='--', alpha=0.7)
            plt.xlabel('节点编号', fontsize=12)
            plt.ylabel('电压幅值 (标幺值)', fontsize=12)
            plt.title('33节点配电网电压分布图', fontsize=14, pad=10)
            plt.legend(loc='best', fontsize=10)
            plt.ylim(0.9, 1.1)

            # Hide the top and right frame lines
            ax = plt.gca()
            ax.spines['top'].set_visible(False)
            ax.spines['right'].set_visible(False)

            if save_path:
                # Always save with a .png extension
                save_path = os.path.splitext(save_path)[0] + '.png'
                self._ensure_dir(os.path.dirname(save_path))
                plt.savefig(save_path, dpi=600, bbox_inches='tight', format='png')
        finally:
            plt.close()

# Actor network
class Actor(nn.Module):
    """Policy network: state vector -> action vector squashed by tanh.

    Two hidden ReLU layers (400 and 300 units) are followed by a linear
    output layer and a ``Tanh``, so every action component lies in [-1, 1].
    """

    def __init__(self, state_dim, action_dim):
        super().__init__()
        first_hidden, second_hidden = 400, 300
        stack = [
            nn.Linear(state_dim, first_hidden),
            nn.ReLU(),
            nn.Linear(first_hidden, second_hidden),
            nn.ReLU(),
            nn.Linear(second_hidden, action_dim),
            nn.Tanh(),
        ]
        # Kept under the attribute name ``net`` so state_dict keys stay stable.
        self.net = nn.Sequential(*stack)

    def forward(self, state):
        """Return the tanh-bounded action for ``state``."""
        bounded_action = self.net(state)
        return bounded_action

# Critic network
class Critic(nn.Module):
    """Q-network: (state, action) -> scalar value estimate.

    State and action are each projected to 400 features by their own linear
    layer; the projections are summed, passed through ReLU and then through
    a 400 -> 300 -> 1 head.
    """

    def __init__(self, state_dim, action_dim):
        super().__init__()
        embed_width = 400
        self.state_net = nn.Linear(state_dim, embed_width)
        self.action_net = nn.Linear(action_dim, embed_width)
        head_layers = (
            nn.Linear(embed_width, 300),
            nn.ReLU(),
            nn.Linear(300, 1),
        )
        self.net = nn.Sequential(*head_layers)

    def forward(self, state, action):
        """Return the value estimate for each (state, action) pair."""
        fused = self.state_net(state) + self.action_net(action)
        return self.net(torch.relu(fused))

# DDPG agent
class DDPGAgent:
    """DDPG agent with actor/critic networks, target copies and a replay buffer.

    Hyper-parameters (learning rate, buffer size, batch size, discount,
    soft-update rate, exploration noise) are read from the module-level
    ``config`` dict.

    Args:
        state_dim: Length of the state vector.
        action_dim: Length of the action vector (PV actions first, then SVC).
        action_bounds: Optional dict with ``'pv'`` and ``'svc'`` entries, each
            holding ``'min'`` and ``'max'``. Defaults to the limits in
            ``config``.
    """

    def __init__(self, state_dim, action_dim, action_bounds=None):
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        # Actor and its target copy
        self.actor = Actor(state_dim, action_dim).to(self.device)
        self.actor_target = Actor(state_dim, action_dim).to(self.device)
        self.actor_target.load_state_dict(self.actor.state_dict())
        self.actor_optimizer = optim.Adam(self.actor.parameters(), lr=config['LR'])

        # Critic and its target copy
        self.critic = Critic(state_dim, action_dim).to(self.device)
        self.critic_target = Critic(state_dim, action_dim).to(self.device)
        self.critic_target.load_state_dict(self.critic.state_dict())
        self.critic_optimizer = optim.Adam(self.critic.parameters(), lr=config['LR'])

        self.memory = deque(maxlen=config['MEMORY_SIZE'])
        self.batch_size = config['BATCH_SIZE']
        self.gamma = config['GAMMA']
        # Read from config instead of repeating the literal here, so that
        # editing CONFIG actually changes the agent.
        self.tau = config['TAU']

        # Physical action ranges
        self.action_bounds = action_bounds if action_bounds else {
            'pv': {'min': 0.0, 'max': config['PV_CAPACITY']},
            'svc': {'min': config['SVC_MIN_MVAR'], 'max': config['SVC_MAX_MVAR']}
        }

        # Exploration noise (standard deviation of the Gaussian noise)
        self.noise_scale = config['NOISE_SCALE']

        # Number of leading action components that control PV units; the
        # remaining components control SVCs.
        self.num_pv = len(config['PV_NODES'])

    def select_action(self, state):
        """Return a noisy action for ``state`` in physical units.

        The actor output gets Gaussian exploration noise, then the first
        ``num_pv`` components are clipped to [0, 1] and scaled by the PV
        maximum, and the rest are clipped to [-1, 1], scaled by the larger
        SVC bound magnitude and finally limited to the SVC range.

        Returns:
            1-D numpy array: PV set-points followed by SVC set-points.
        """
        with torch.no_grad():
            state_t = torch.as_tensor(
                np.asarray(state, dtype=np.float32), device=self.device
            ).unsqueeze(0)
            # squeeze(0) drops only the batch axis, so a one-component action
            # still comes back as a 1-D array.
            raw_action = self.actor(state_t).cpu().numpy().squeeze(0)

        # Add exploration noise
        raw_action = raw_action + np.random.normal(0, self.noise_scale, size=raw_action.shape)

        pv_bounds = self.action_bounds['pv']
        svc_bounds = self.action_bounds['svc']

        # Map the normalised outputs onto the physical control ranges
        pv_actions = np.clip(raw_action[:self.num_pv], 0, 1) * pv_bounds['max']
        svc_span = max(abs(svc_bounds['min']), abs(svc_bounds['max']))
        svc_actions = np.clip(raw_action[self.num_pv:], -1, 1) * svc_span
        # With asymmetric bounds the symmetric scaling above could overshoot
        # one side, so enforce the configured range explicitly.
        svc_actions = np.clip(svc_actions, svc_bounds['min'], svc_bounds['max'])

        return np.concatenate([pv_actions, svc_actions])

    def _batch_tensor(self, values, as_column=False):
        """Stack a tuple of per-sample values into one float32 tensor."""
        # Going through a single ndarray avoids the slow element-wise path
        # torch takes for a sequence of ndarrays.
        tensor = torch.as_tensor(np.asarray(values, dtype=np.float32), device=self.device)
        return tensor.unsqueeze(1) if as_column else tensor

    def train(self):
        """Run one DDPG update from a random replay batch.

        Does nothing until the buffer holds at least ``batch_size``
        transitions of the form ``(state, action, reward, next_state, done)``.
        """
        if len(self.memory) < self.batch_size:
            return

        batch = random.sample(self.memory, self.batch_size)
        states, actions, rewards, next_states, dones = zip(*batch)

        states = self._batch_tensor(states)
        actions = self._batch_tensor(actions)
        rewards = self._batch_tensor(rewards, as_column=True)
        next_states = self._batch_tensor(next_states)
        dones = self._batch_tensor(dones, as_column=True)

        # Critic update: the TD target is a constant, so build it without
        # tracking gradients.
        # NOTE(review): actions stored by the caller are whatever it appends to
        # ``memory``; if those are the physical-unit outputs of select_action,
        # they are on a different scale than the raw actor outputs fed to the
        # critic here and in the actor update. Confirm this is intended.
        with torch.no_grad():
            next_actions = self.actor_target(next_states)
            target_q = self.critic_target(next_states, next_actions)
            target_q = rewards + (1 - dones) * self.gamma * target_q

        current_q = self.critic(states, actions)
        critic_loss = nn.MSELoss()(current_q, target_q)

        self.critic_optimizer.zero_grad()
        critic_loss.backward()
        self.critic_optimizer.step()

        # Actor update: ascend the critic's value of the actor's own actions
        actor_loss = -self.critic(states, self.actor(states)).mean()

        self.actor_optimizer.zero_grad()
        actor_loss.backward()
        self.actor_optimizer.step()

        # Soft-update the target networks
        self._soft_update(self.actor_target, self.actor)
        self._soft_update(self.critic_target, self.critic)

    def _soft_update(self, target, source):
        """Blend ``source`` parameters into ``target``: t <- (1 - tau) * t + tau * s."""
        with torch.no_grad():
            for target_param, param in zip(target.parameters(), source.parameters()):
                target_param.data.copy_(target_param.data * (1.0 - self.tau) + param.data * self.tau)

# Create the module-level performance tracker that the training loop fills
# with one entry per episode.
performance_tracker = PerformanceTracker()

# Build the state vector
def get_state(net):
    """Run a power flow on ``net`` and return the observation vector.

    The vector is the concatenation of all bus voltage magnitudes
    (``res_bus.vm_pu``), the ``p_mw`` of the first ``len(PV_NODES)`` sgen
    rows and the ``q_mvar`` of the following ``len(SVC_NODES)`` sgen rows,
    as float64. Transformer diagnostics are printed before the power flow
    and again if anything fails; the exception is then re-raised.
    """
    trafo_columns = ["hv_bus", "lv_bus", "tap_pos"]
    try:
        # Transformer diagnostics before solving
        print("\n运行潮流前变压器状态:")
        print(net.trafo[trafo_columns].to_string())
        print("\n数据类型:")
        print(net.trafo.dtypes[trafo_columns])

        # Bring the network tables to the current pandapower format
        pp.convert_format(net)

        # Solve the power flow
        pp.runpp(net, calculate_voltage_angles=True, init='flat', numba=False, max_iteration=100)

        bus_vm = net.res_bus['vm_pu'].to_numpy().astype(np.float64)

        # sgen rows are addressed by label: PV rows first, SVC rows right after
        pv_count = len(config['PV_NODES'])
        svc_count = len(config['SVC_NODES'])
        pv_power = np.array(
            [float(net.sgen.loc[row, 'p_mw']) for row in range(pv_count)],
            dtype=np.float64,
        )
        svc_power = np.array(
            [float(net.sgen.loc[row, 'q_mvar']) for row in range(pv_count, pv_count + svc_count)],
            dtype=np.float64,
        )

        return np.concatenate([bus_vm, pv_power, svc_power]).astype(np.float64)
    except Exception as err:
        print(f"\n潮流计算错误: {str(err)}")
        print("错误发生时的变压器状态:")
        print(net.trafo[trafo_columns].to_string())
        print("\n数据类型:")
        print(net.trafo.dtypes[trafo_columns])
        raise

# Network loss calculation
def calculate_network_loss(net):
    """Return the total active-power loss of ``net`` (lines plus transformers).

    A power flow is run first when no line results are present. Any error is
    reported on stdout and 0.0 is returned instead of raising.
    """
    try:
        line_results_missing = not hasattr(net, 'res_line') or len(net.res_line) == 0
        if line_results_missing:
            pp.runpp(net, numba=False)

        # Sum of per-line active losses
        line_part = net.res_line.pl_mw.sum()

        # Transformer losses only count when transformer results exist
        trafo_results_present = hasattr(net, 'res_trafo') and len(net.res_trafo) > 0
        trafo_part = net.res_trafo.pl_mw.sum() if trafo_results_present else 0.0

        return line_part + trafo_part
    except Exception as err:
        print(f"计算网络损耗时出错: {str(err)}")
        return 0.0

# Reward calculation
def compute_reward(state, prev_state=None, net=None):
    """Return ``(reward, network_loss)`` for the given state.

    The reward is the negative sum of: the squared deviation of the first 33
    state entries (bus voltages) from 1.0, ten times the number of those
    entries outside [V_MIN, V_MAX], and - when ``net`` is given - five times
    the network loss. ``prev_state`` is converted to an array but does not
    influence the result.
    """
    observation = np.asarray(state, dtype=np.float64)
    if prev_state is not None:
        prev_state = np.asarray(prev_state, dtype=np.float64)

    # Only the voltage part (first 33 entries) enters the voltage terms
    bus_vm = observation[:33]
    squared_deviation = np.sum((bus_vm - 1.0) ** 2, dtype=np.float64)

    # Count entries above the upper or below the lower limit
    upper_limit = float(config['V_MAX'])
    lower_limit = float(config['V_MIN'])
    out_of_band = float(
        np.count_nonzero(bus_vm > upper_limit) + np.count_nonzero(bus_vm < lower_limit)
    )

    # Loss penalty is only applied when a network model is supplied
    network_loss = 0.0
    loss_penalty = 0.0
    if net is not None:
        try:
            network_loss = calculate_network_loss(net)
            loss_penalty = 5.0 * network_loss
        except Exception as err:
            print(f"计算网损时出错: {str(err)}")
            # Fall back to a large fixed penalty when the loss is unavailable
            loss_penalty = 100.0

    total = -squared_deviation - 10.0 * out_of_band - loss_penalty
    return float(total), network_loss

# Training setup: problem dimensions and the agent
state_dim = 41  # 33 bus voltages + 6 PV active powers + 2 SVC reactive powers
action_dim = 8  # continuous control of 6 PV units and 2 SVCs

agent = DDPGAgent(state_dim, action_dim)

# Training loop: every episode rebuilds the network from scratch, runs a
# fixed number of control steps and records the episode metrics.
for episode in range(config['EPISODES']):
    net = case33bw()

    # Make sure all bus indices / bus references are plain integers
    net.bus.index = net.bus.index.astype(int)
    net.load.bus = net.load.bus.astype(int)
    if len(net.gen) > 0:
        net.gen.bus = net.gen.bus.astype(int)
    net.ext_grid.bus = net.ext_grid.bus.astype(int)

    episode_reward = 0
    voltage_violations = 0
    pv_curtailments = 0
    svc_adjustments = 0
    prev_state = None

    # Create the tap-changing transformer between bus 0 and bus 1
    pp.create_transformer_from_parameters(net,
        hv_bus=0,  # high-voltage side connected to bus 0
        lv_bus=1,  # low-voltage side connected to bus 1
        sn_mva=10.0,  # rated power
        vn_hv_kv=12.66,  # rated voltage, high-voltage side
        vn_lv_kv=12.66,  # rated voltage, low-voltage side
        vk_percent=6.0,  # short-circuit impedance
        vkr_percent=1.0,  # short-circuit loss
        pfe_kw=10.0,  # iron loss
        i0_percent=0.1,  # no-load current
        tap_neutral=0,  # neutral tap position
        tap_min=-2,  # lowest tap position
        tap_max=2,  # highest tap position
        tap_step_percent=2.5,  # voltage change per tap step (%)
        tap_side='hv',  # tap changer is on the high-voltage side
        tap_phase_shifter=False,  # not a phase-shifting transformer
        name='Trafo1',
        tap_pos=0,  # initial tap position
        in_service=True  # in service
    )

    # Print the initial network state
    print(f"\nEpisode {episode} 初始化状态:")
    print("变压器信息:")
    print(net.trafo[["hv_bus", "lv_bus", "tap_pos", "tap_neutral", "tap_min", "tap_max", "tap_step_percent"]].to_string())
    print("\n变压器数据类型:")
    print(net.trafo.dtypes[["hv_bus", "lv_bus", "tap_pos"]])

    # Add PV units at the configured nodes
    for pv_bus in config['PV_NODES']:
        bus_idx = int(pv_bus - 1)  # 1-based node number -> 0-based integer bus index
        pp.create_sgen(
            net,
            bus=bus_idx,
            p_mw=float(config['PV_CAPACITY']),  # make sure the power is a float
            q_mvar=0.0,
            name=f'PV_{pv_bus}',
            type='PV',
            in_service=True,
            scaling=1.0
        )

    # Add SVCs at the configured nodes (modelled as static generators).
    # The reactive limits are passed here as well so the per-episode network
    # matches the one built at module start-up.
    for svc_bus in config['SVC_NODES']:
        bus_idx = int(svc_bus - 1)  # 1-based node number -> 0-based integer bus index
        pp.create_sgen(
            net,
            bus=bus_idx,
            p_mw=0.0,
            q_mvar=0.0,  # initial reactive power is zero
            name=f'SVC_{svc_bus}',
            type='SVC',
            min_q_mvar=float(config['SVC_MIN_MVAR']),
            max_q_mvar=float(config['SVC_MAX_MVAR']),
            controllable=True,
            in_service=True,
            scaling=1.0
        )

    # Long-time-scale control results (placeholders; the original comments
    # say these should come from LONG_DDPG.py)
    tap_pos = 0  # should be replaced by the actual value from LONG_DDPG
    scb_states = [False, False]  # should be replaced by the actual value from LONG_DDPG

    # Apply the long-time-scale control results
    # Set the transformer tap position
    if len(net.trafo.index) > 0:  # only when a transformer exists
        tap_min = net.trafo.loc[net.trafo.index[0], 'tap_min']
        tap_max = net.trafo.loc[net.trafo.index[0], 'tap_max']
        tap_pos = int(np.clip(tap_pos, tap_min, tap_max))  # keep within the allowed range
        for idx in net.trafo.index:
            net.trafo.loc[idx, 'tap_pos'] = tap_pos  # same tap position for every transformer

    # Set the shunt capacitor states (only if at least two shunts exist)
    if len(net.shunt) > 0 and len(net.shunt.index) >= 2:
        net.shunt.loc[net.shunt.index[0], 'in_service'] = scb_states[0]
        net.shunt.loc[net.shunt.index[1], 'in_service'] = scb_states[1]

    try:
        state = get_state(net)
    except Exception as e:
        print(f"初始潮流计算错误：{e}")
        print("跳过当前episode并记录网损为最大值")
        # Record the episode as failed with sentinel metric values
        performance_tracker.add_episode_data(
            -1000.0,  # very low reward
            100,      # high violation count
            0.0,      # invalid average voltage
            0,        # no PV adjustments
            0,        # no SVC adjustments
            10.0      # high loss value marks the failure
        )
        continue  # skip this episode

    # 24 hours at one step per 30 minutes. Kept in one place so the loop
    # bound and the terminal-step test cannot drift apart.
    steps_per_episode = 48

    for t in range(steps_per_episode):
        action = agent.select_action(state)

        # Split the action: first 6 entries control PV, last 2 control SVC
        pv_actions = action[:6]
        svc_actions = action[6:]

        # Apply the PV actions
        for i, pv_action in enumerate(pv_actions):
            current_p = float(net.sgen.loc[i, 'p_mw'])
            if abs(pv_action - current_p) > 1e-6:  # tolerance for float comparison
                net.sgen.loc[i, 'p_mw'] = float(pv_action)
                pv_curtailments += 1

        # Apply the SVC actions
        for i, svc_action in enumerate(svc_actions):
            svc_idx = len(config['PV_NODES']) + i  # SVC rows follow the PV rows in sgen
            current_q = float(net.sgen.loc[svc_idx, 'q_mvar'])
            if abs(svc_action - current_q) > 1e-6:  # tolerance for float comparison
                # Keep the SVC reactive output within its limits
                q_mvar = np.clip(float(svc_action),
                                float(config['SVC_MIN_MVAR']),
                                float(config['SVC_MAX_MVAR']))
                net.sgen.loc[svc_idx, 'q_mvar'] = q_mvar
                svc_adjustments += 1

        # Get the new state and the reward
        try:
            next_state = get_state(net)

            # Reward (including the loss penalty) and the network loss
            reward, network_loss = compute_reward(next_state, state, net)
            episode_reward += reward

            # Network loss of the current time step
            current_network_loss = network_loss
        except Exception as e:
            print(f"时间步 {t} 潮流计算错误：{e}")
            # Use a large penalty as the reward
            reward = -100.0
            next_state = state  # fall back to the previous state
            current_network_loss = 10.0  # high loss value
            episode_reward += reward

        # Count voltage violations
        try:
            voltage_violations += np.sum((next_state[:33] > config['V_MAX']) | (next_state[:33] < config['V_MIN']))
        except Exception as e:
            print(f"计算电压越限时出错: {str(e)}")
            voltage_violations += 10  # assume 10 violations on error

        # Reject transitions with a non-finite reward BEFORE they reach the
        # replay buffer; a single NaN/inf sample would corrupt every batch
        # that later draws it. The state is deliberately not advanced.
        if np.isnan(reward) or np.isinf(reward):
            print(f"警告: 时间步 {t} 的奖励值异常: {reward}")
            continue

        # Store the transition
        done = (t == steps_per_episode - 1)
        agent.memory.append((state, action, reward, next_state, done))

        # Train the agent
        agent.train()

        state = next_state
        prev_state = state

        # Print status every 6 steps (3 hours)
        if t % 6 == 0:
            print(f"Episode {episode}, Time Step {t}:")
            print(f"  Reward = {reward:.4f}")
            print(f"  Avg Voltage = {np.mean(state[:33]):.4f}")
            print(f"  Network Loss = {current_network_loss:.4f} MW")
            print(f"  Voltage Violations = {np.sum((next_state[:33] > config['V_MAX']) | (next_state[:33] < config['V_MIN']))}")
            print(f"  PV Curtailments = {pv_curtailments}")
            print(f"  SVC Adjustments = {svc_adjustments}\n")

    # Final network loss of the episode
    try:
        # Run one last power flow so the results reflect the final set-points
        pp.runpp(net, numba=False)
        final_network_loss = calculate_network_loss(net)
        print(f"\n最终网络损耗计算成功: {final_network_loss:.4f} MW")
    except Exception as e:
        print(f"\n计算最终网络损耗时出错: {str(e)}")
        final_network_loss = 0.0  # fall back to 0 on error

    # Record the episode metrics
    performance_tracker.add_episode_data(
        episode_reward,
        voltage_violations,
        float(np.mean(state[:33])),
        pv_curtailments,
        svc_adjustments,
        final_network_loss
    )

    print(f"Episode {episode} 总结:")
    print(f"  网络损耗 = {final_network_loss:.4f} MW (主要评判标准)")
    print(f"  总奖励 = {episode_reward:.4f}")
    print(f"  电压越限次数 = {voltage_violations}")
    print(f"  PV调节次数 = {pv_curtailments}")
    print(f"  SVC调节次数 = {svc_adjustments}")

    # Save the model and the performance figures every SAVE_INTERVAL episodes
    if (episode + 1) % config['SAVE_INTERVAL'] == 0:
        save_dir = config['SAVE_DIR']
        os.makedirs(save_dir, exist_ok=True)

        # Save the actor and critic weights together with config and metrics
        model_path = os.path.join(save_dir, f'ddpg_model_episode_{episode+1}')
        torch.save({
            'actor': agent.actor.state_dict(),
            'critic': agent.critic.state_dict(),
            'config': config,
            'performance': {
                'rewards': performance_tracker.episode_rewards,
                'violations': performance_tracker.voltage_violations,
                'voltages': performance_tracker.avg_voltages,
                'pv_curtailments': performance_tracker.pv_curtailments,
                'svc_adjustments': performance_tracker.svc_adjustments,
                'network_losses': performance_tracker.network_losses
            }
        }, f"{model_path}.pth")

        # Plot and save the performance metrics
        performance_path = os.path.join(save_dir, f'performance_episode_{episode+1}')
        performance_tracker.plot_performance(performance_path)

        # Run a power flow to get the latest bus results
        pp.runpp(net, numba=False)

        # Plot and save the voltage profile
        voltage_path = os.path.join(save_dir, f'voltage_profile_episode_{episode+1}')
        performance_tracker.plot_voltage_profile(net, voltage_path)

        print(f"\nEpisode {episode+1} 的模型和性能指标已保存：")
        print(f"模型保存路径：{model_path}.pth")
        print(f"性能指标图保存路径：{performance_path}")
        print(f"电压分布图保存路径：{voltage_path}")
        print(f"最终网络损耗：{final_network_loss:.4f} MW (主要评判标准)")

        # Report whether the loss is below the 0.2 MW threshold assumed to be good
        if final_network_loss < 0.2:
            print(f"\n恭喜！本次训练取得了较低的网络损耗 {final_network_loss:.4f} MW")
        else:
            print(f"\n网络损耗仍有优化空间，当前值: {final_network_loss:.4f} MW")