"""
使用DP (Dynamic Programming) 动态规划算法进行钻井参数优化设计
第十五届中国石油工程设计大赛方案设计类赛题 - 优化版本
"""

import torch
import torch.nn as nn
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from typing import List, Tuple, Dict, Optional
from dataclasses import dataclass
import warnings
import logging

# Suppress every Python warning process-wide (this also hides deprecation and
# runtime warnings from numpy/torch/pandas).
warnings.filterwarnings('ignore')
# Root logging configuration: INFO level, "time - level - message" lines.
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

# Matplotlib: use the SimHei font so Chinese text can be rendered, and disable
# the Unicode minus glyph (axes.unicode_minus=False) so an ASCII hyphen is used.
plt.rcParams['font.sans-serif'] = ['SimHei']
plt.rcParams['axes.unicode_minus'] = False
# Fixed seeds: np.random drives the synthetic training-data sampling and
# torch's RNG drives network weight initialisation, making runs reproducible.
np.random.seed(42)
torch.manual_seed(42)


# ==================== 1. 数据生成 ====================

class DrillingDataGenerator:
    """Synthetic drilling physics: formations, bit catalogue, ROP, wear and cost.

    Holds two static lookup tables and a set of closed-form empirical formulas
    that serve as the simulator for the optimizer and as the label source for
    the ROP network.
    """

    def __init__(self):
        # Formation intervals [depth_start, depth_end) with rock properties:
        # higher 'drillability' raises ROP, higher 'abrasiveness' raises bit wear.
        self.formations = [
            {'name': '表层', 'depth_start': 0, 'depth_end': 500, 'drillability': 0.8, 'abrasiveness': 0.3},
            {'name': '泥岩层', 'depth_start': 500, 'depth_end': 1500, 'drillability': 0.6, 'abrasiveness': 0.5},
            {'name': '砂岩层', 'depth_start': 1500, 'depth_end': 2500, 'drillability': 0.5, 'abrasiveness': 0.7},
            {'name': '石灰岩层', 'depth_start': 2500, 'depth_end': 3500, 'drillability': 0.3, 'abrasiveness': 0.8},
            {'name': '目标层', 'depth_start': 3500, 'depth_end': 4000, 'drillability': 0.4, 'abrasiveness': 0.6},
        ]
        # Bit catalogue: purchase cost, operating limits (max WOB / max RPM),
        # an ROP multiplier ('efficiency') and a wear divisor ('life_factor').
        self.bit_types = [
            {'name': 'PDC钻头', 'cost': 50000, 'max_wob': 20, 'max_rpm': 300, 'efficiency': 1.2, 'life_factor': 1.0},
            {'name': '牙轮钻头', 'cost': 30000, 'max_wob': 25, 'max_rpm': 200, 'efficiency': 0.9, 'life_factor': 0.8},
            {'name': '金刚石钻头', 'cost': 80000, 'max_wob': 15, 'max_rpm': 250, 'efficiency': 1.5, 'life_factor': 1.3},
        ]

    def get_formation_at_depth(self, depth: float) -> Dict:
        """Return the formation whose half-open interval contains *depth*.

        Depths matching no interval (at/after the last depth_end, or negative)
        fall back to the last formation in the table.
        """
        containing = (
            layer for layer in self.formations
            if layer['depth_start'] <= depth < layer['depth_end']
        )
        return next(containing, self.formations[-1])

    def calculate_rop(self, wob: float, rpm: float, flow_rate: float,
                      depth: float, bit_type_idx: int) -> float:
        """Empirical power-law ROP model, bounded to [0.5, 50.0].

        Each control is scaled by a reference value (WOB 20, RPM 200, flow 35)
        and clipped to [0.1, 2.0]; the product is attenuated by exp(-depth/2000).
        """
        rock = self.get_formation_at_depth(depth)
        tool = self.bit_types[bit_type_idx]
        coeff = 0.5 * rock['drillability'] * tool['efficiency']
        exp_wob, exp_rpm, exp_flow = 0.6, 0.5, 0.3
        attenuation = np.exp(-depth / 2000)
        scaled_wob = np.clip(wob / 20.0, 0.1, 2.0)
        scaled_rpm = np.clip(rpm / 200.0, 0.1, 2.0)
        scaled_flow = np.clip(flow_rate / 35.0, 0.1, 2.0)
        raw_rop = (coeff * (scaled_wob ** exp_wob) * (scaled_rpm ** exp_rpm)
                   * (scaled_flow ** exp_flow) * attenuation * 30)
        return np.clip(raw_rop, 0.5, 50.0)

    def calculate_bit_wear(self, wob: float, rpm: float, depth: float,
                           bit_type_idx: int, hours: float) -> float:
        """Wear accumulated over *hours*: linear in WOB, RPM and abrasiveness,
        divided by the bit's life factor, scaled by 0.01."""
        abrasiveness = self.get_formation_at_depth(depth)['abrasiveness']
        life = self.bit_types[bit_type_idx]['life_factor']
        rate_per_hour = (wob / 20) * (rpm / 200) * abrasiveness / life
        return rate_per_hour * hours * 0.01

    def calculate_drilling_cost(self, rop: float, bit_cost: float,
                                duration: float, rig_day_rate: float = 50000) -> float:
        """Cost per unit of footage for one run (rig time + bit), or inf if
        no footage is made."""
        footage = rop * duration
        if footage <= 0:
            return float('inf')
        rig_spend = rig_day_rate * (duration / 24)
        return (rig_spend + bit_cost) / footage


# ==================== 2. 状态定义（关键优化）====================

@dataclass(frozen=True)
class DrillingState:
    """Immutable planning state of the well.

    ``current_bit_idx`` records which bit type is installed so a transition
    can detect a bit swap; ``-1`` means no bit has been run yet. The class is
    frozen, so instances are hashable and compare by value.
    """
    depth: float
    bit_wear: float
    remaining_budget: float
    current_bit_idx: int  # -1 = drilling not started (no bit installed yet)

    def is_terminal(self, target_depth: float) -> bool:
        """True once *target_depth* is reached or the budget is exhausted."""
        reached_target = self.depth >= target_depth
        out_of_budget = self.remaining_budget <= 0
        return reached_target or out_of_budget


@dataclass
class DrillingAction:
    """One discrete drilling run: operating parameters, bit choice and run length.

    Built by DynamicProgrammingOptimizer.get_valid_actions, applied by its
    transition(), and stored as the values of its policy table.
    """
    wob: float  # weight on bit; unit not stated here (checked against bit 'max_wob')
    rpm: float  # rotary speed (checked against bit 'max_rpm')
    flow_rate: float  # drilling-fluid flow rate; unit not stated (ROP model scales by 35)
    bit_type_idx: int  # index into DrillingDataGenerator.bit_types
    duration: float  # run length in hours (cost model divides by 24 to get rig days)


# ==================== 3. 神经网络预测器 ====================

class ROPPredictor(nn.Module):
    """Small MLP surrogate that maps drilling features to rate of penetration.

    Feature order (input_dim = 6): wob, rpm, flow, depth, formation
    drillability, bit efficiency. Raw features have very different scales
    (depth up to thousands vs. efficiency around 1), so forward() standardises
    them with per-feature mean/std buffers fitted by train_model(). Before
    training the buffers are the identity transform (mean 0, std 1).
    """

    def __init__(self, input_dim: int = 6, hidden_dim: int = 64):
        super().__init__()
        self.network = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 1)
        )
        # Input standardisation statistics; identity until train_model() fits them.
        self.register_buffer('feature_mean', torch.zeros(input_dim))
        self.register_buffer('feature_std', torch.ones(input_dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Standardise *x* (shape (batch, input_dim)) and return predictions (batch, 1)."""
        return self.network((x - self.feature_mean) / self.feature_std)

    def predict(self, wob, rpm, flow, depth, drillability, efficiency) -> float:
        """Predict ROP for a single feature vector.

        Always evaluates in eval mode (Dropout disabled) so the result is
        deterministic even if the model has not been trained yet; the previous
        train/eval mode is restored afterwards.
        """
        features = torch.FloatTensor([[wob, rpm, flow, depth, drillability, efficiency]])
        was_training = self.training
        self.eval()
        try:
            with torch.no_grad():
                return self.forward(features).item()
        finally:
            self.train(was_training)

    def train_model(self, data_gen: 'DrillingDataGenerator', epochs: int = 50,
                    n_samples: int = 3000, lr: float = 0.005):
        """Fit the network on samples labelled by ``data_gen.calculate_rop``.

        Args:
            data_gen: physics model providing formations, bit types and ROP labels.
            epochs: number of full-batch Adam steps.
            n_samples: number of random operating points to sample (>= 1).
            lr: Adam learning rate.

        Raises:
            ValueError: if *n_samples* < 1.

        Leaves the model in eval mode.
        """
        if n_samples < 1:
            raise ValueError("n_samples must be >= 1")
        logger.info("训练ROP预测模型...")
        X, y = [], []
        for _ in range(n_samples):
            wob = np.random.uniform(5, 25)
            rpm = np.random.uniform(100, 300)
            flow = np.random.uniform(20, 50)
            depth = np.random.uniform(0, 4000)
            bit_idx = np.random.randint(0, 3)
            formation = data_gen.get_formation_at_depth(depth)
            bit = data_gen.bit_types[bit_idx]
            X.append([wob, rpm, flow, depth, formation['drillability'], bit['efficiency']])
            y.append(data_gen.calculate_rop(wob, rpm, flow, depth, bit_idx))

        X, y = torch.FloatTensor(X), torch.FloatTensor(y).unsqueeze(1)

        # Fit standardisation; constant columns keep std 1 to avoid division by zero.
        std = X.std(dim=0, unbiased=False)
        std = torch.where(std > 1e-8, std, torch.ones_like(std))
        self.feature_mean.copy_(X.mean(dim=0))
        self.feature_std.copy_(std)

        optimizer = torch.optim.Adam(self.parameters(), lr=lr)
        criterion = nn.MSELoss()

        self.train()
        for epoch in range(epochs):
            optimizer.zero_grad()
            loss = criterion(self.forward(X), y)
            loss.backward()
            optimizer.step()
            if (epoch + 1) % 20 == 0:
                logger.info(f"Epoch {epoch + 1}/{epochs}, Loss: {loss.item():.4f}")
        self.eval()


# ==================== 4. DP优化器（核心修复）====================

class DynamicProgrammingOptimizer:
    """Tabular value-iteration optimizer over a discretised drilling state space.

    Grid keys are (depth index, bit-wear index, remaining-budget index,
    current bit index). Actions are discrete combinations of WOB, RPM, flow
    rate, bit type and run duration, evaluated with the data generator's
    physics model or, optionally, the ROP network.

    NOTE(review): value_iteration evaluates every valid action of every grid
    state on every sweep in pure Python, so a full run is slow. Transitions do
    not depend on the value table and could be precomputed if this matters.
    """

    def __init__(self, data_gen: 'DrillingDataGenerator', target_depth: float = 4000,
                 total_budget: float = 2000000, use_nn_approx: bool = False,
                 rig_day_rate: float = 50000):
        """
        Args:
            data_gen: physics model providing formations, bit types, ROP and wear.
            target_depth: depth at which an episode ends successfully.
            total_budget: starting budget; also the top of the budget grid.
            use_nn_approx: when True, transition() uses ``self.rop_predictor``
                (if assigned) instead of the physics ROP model by default.
            rig_day_rate: rig cost per 24 h, used both for action screening
                and for the cost charged in transition().
        """
        self.data_gen = data_gen
        self.target_depth = target_depth
        self.total_budget = total_budget
        self.use_nn_approx = use_nn_approx
        self.rig_day_rate = rig_day_rate
        # Optional ROP network; assigned by the caller after construction.
        self.rop_predictor = None

        # Discrete action grid
        self.wob_options = np.linspace(8, 22, 4)
        self.rpm_options = np.linspace(120, 280, 4)
        self.flow_options = np.linspace(25, 45, 3)
        self.duration_options = [12, 24]

        # State-grid resolution
        self.depth_step = 100
        self.wear_step = 0.1
        self.budget_step = 100000

        self.value_table: Dict[Tuple, float] = {}
        self.policy_table: Dict[Tuple, 'DrillingAction'] = {}

        self._build_state_mappings()

    def _build_state_mappings(self):
        """Build the grid coordinate arrays for depth, wear and budget."""
        self.depth_states = np.arange(0, self.target_depth + self.depth_step, self.depth_step)
        self.wear_states = np.linspace(0, 1, int(1 / self.wear_step) + 1)
        self.budget_states = np.arange(0, self.total_budget + self.budget_step, self.budget_step)

    @staticmethod
    def _grid_index(value: float, step: float, n_cells: int) -> int:
        """Floor ``value / step`` and clamp it to the valid range [0, n_cells - 1]."""
        return min(max(int(value / step), 0), n_cells - 1)

    def _rig_cost(self, duration: float) -> float:
        """Rig cost for a run of *duration* hours at the configured day rate."""
        return self.rig_day_rate * (duration / 24)

    def get_nearest_state_key(self, state: 'DrillingState') -> Tuple:
        """Map a continuous state to its grid key in O(1).

        Each coordinate is floored onto its grid and clamped to the valid index
        range. Clamping matters for over-budget states: a negative remaining
        budget maps to the zero-budget cell instead of a negative index that
        has no entry in the value table.
        """
        d_idx = self._grid_index(state.depth, self.depth_step, len(self.depth_states))
        w_idx = self._grid_index(state.bit_wear, self.wear_step, len(self.wear_states))
        b_idx = self._grid_index(state.remaining_budget, self.budget_step, len(self.budget_states))
        return (d_idx, w_idx, b_idx, state.current_bit_idx)

    def get_valid_actions(self, state: 'DrillingState') -> List['DrillingAction']:
        """Enumerate the discrete actions allowed from *state*.

        A bit change is forced when no bit is installed yet
        (``current_bit_idx == -1``) or the current bit's wear is >= 0.8. In
        that case the currently installed bit type is excluded.

        The following actions are skipped:
        - WOB/RPM combinations above the bit's limits;
        - actions whose estimated cost exceeds 120 % of the remaining budget.

        The estimated cost is rig time plus the bit price, the latter only when
        the action switches bit type. This matches the cost charged in
        transition().
        """
        actions = []
        need_change = (state.current_bit_idx == -1 or state.bit_wear >= 0.8)

        for bit_idx, bit in enumerate(self.data_gen.bit_types):
            switches_bit = bit_idx != state.current_bit_idx
            # A worn (or absent) bit must be replaced by a different type.
            if need_change and not switches_bit:
                continue
            bit_charge = bit['cost'] if switches_bit else 0

            for wob in self.wob_options:
                if wob > bit['max_wob']:
                    continue
                for rpm in self.rpm_options:
                    if rpm > bit['max_rpm']:
                        continue
                    for flow in self.flow_options:
                        for duration in self.duration_options:
                            est_cost = self._rig_cost(duration) + bit_charge
                            # 1.2 slack: a run may overshoot the remaining budget by up to 20 %.
                            if est_cost > state.remaining_budget * 1.2:
                                continue
                            actions.append(DrillingAction(wob, rpm, flow, bit_idx, duration))
        return actions

    def transition(self, state: 'DrillingState', action: 'DrillingAction',
                   use_nn: Optional[bool] = None) -> Tuple['DrillingState', float]:
        """Apply *action* to *state* for one run and return ``(next_state, reward)``.

        Args:
            use_nn: use the ROP network when True (and ``self.rop_predictor`` is
                set). When None (default), follow ``self.use_nn_approx``.

        Cost model: rig time for the run, plus the bit price only when the
        action installs a different bit type (including the first bit). On a
        bit change, wear restarts from this run's own increment.

        Reward components:
        - depth actually gained (capped at the target);
        - minus 0.001 per unit of cost;
        - plus 5000 on reaching the target;
        - plus 0.0001 per unit of positive leftover budget when the target is reached.
        """
        if use_nn is None:
            use_nn = self.use_nn_approx

        bit = self.data_gen.bit_types[action.bit_type_idx]
        bit_changed = (action.bit_type_idx != state.current_bit_idx)

        if use_nn and self.rop_predictor is not None:
            formation = self.data_gen.get_formation_at_depth(state.depth)
            rop = self.rop_predictor.predict(
                action.wob, action.rpm, action.flow_rate,
                state.depth, formation['drillability'], bit['efficiency']
            )
            rop = np.clip(rop, 0.5, 50.0)
        else:
            rop = self.data_gen.calculate_rop(
                action.wob, action.rpm, action.flow_rate,
                state.depth, action.bit_type_idx
            )

        new_depth = min(state.depth + rop * action.duration, self.target_depth)
        # Only the depth actually gained counts; overshoot past the target is not rewarded.
        progress = new_depth - state.depth

        wear_increment = self.data_gen.calculate_bit_wear(
            action.wob, action.rpm, state.depth,
            action.bit_type_idx, action.duration
        )
        if bit_changed:
            new_bit_wear = min(wear_increment, 1.0)  # fresh bit starts from zero wear
        else:
            new_bit_wear = min(state.bit_wear + wear_increment, 1.0)

        # Rig time every run; bit price once, when the bit is bought/installed.
        actual_cost = self._rig_cost(action.duration)
        if bit_changed:
            actual_cost += bit['cost']

        new_budget = state.remaining_budget - actual_cost

        next_state = DrillingState(
            depth=new_depth,
            bit_wear=new_bit_wear,
            remaining_budget=new_budget,
            current_bit_idx=action.bit_type_idx
        )

        reward = progress - actual_cost * 0.001
        if new_depth >= self.target_depth:
            reward += 5000
            if new_budget > 0:
                reward += new_budget * 0.0001

        return next_state, reward

    def value_iteration(self, max_iterations: int = 200, tolerance: float = 1e-4) -> None:
        """Run in-place (Gauss-Seidel style) value iteration over the state grid.

        Initialisation: every key gets 0 on the last depth index (target
        reached) and -1000 elsewhere.

        Each sweep updates all other depth indices with the best one-step
        lookahead ``reward + gamma * V(next)`` (gamma = 0.95). The argmax
        action is stored in ``policy_table``. States without any valid action
        keep their value.

        Iteration stops when the largest change in a sweep is below
        *tolerance*, or after *max_iterations* sweeps.
        """
        logger.info("开始值迭代...")
        n_depth, n_wear = len(self.depth_states), len(self.wear_states)
        n_budget, n_bits = len(self.budget_states), len(self.data_gen.bit_types)

        total_states = n_depth * n_wear * n_budget * (n_bits + 1)
        logger.info(f"状态空间: {total_states}")

        # Initialise: terminal depth row is 0, everything else pessimistic.
        for d_idx in range(n_depth):
            for w_idx in range(n_wear):
                for b_idx in range(n_budget):
                    for bit_idx in range(-1, n_bits):
                        key = (d_idx, w_idx, b_idx, bit_idx)
                        self.value_table[key] = 0 if d_idx >= n_depth - 1 else -1000

        gamma = 0.95

        for iteration in range(max_iterations):
            delta = 0
            for d_idx in range(n_depth - 1):
                for w_idx in range(n_wear):
                    for b_idx in range(n_budget):
                        for bit_idx in range(-1, n_bits):
                            state_key = (d_idx, w_idx, b_idx, bit_idx)
                            state = DrillingState(
                                depth=self.depth_states[d_idx],
                                bit_wear=self.wear_states[w_idx],
                                remaining_budget=self.budget_states[b_idx],
                                current_bit_idx=bit_idx
                            )

                            actions = self.get_valid_actions(state)
                            if not actions:
                                continue

                            best_value, best_action = -float('inf'), None
                            for action in actions:
                                next_state, reward = self.transition(state, action)
                                next_key = self.get_nearest_state_key(next_state)
                                next_value = self.value_table.get(next_key, 0)
                                total_value = reward + gamma * next_value

                                if total_value > best_value:
                                    best_value = total_value
                                    best_action = action

                            if best_value > -float('inf'):
                                old_value = self.value_table.get(state_key, -1000)
                                self.value_table[state_key] = best_value
                                self.policy_table[state_key] = best_action
                                delta = max(delta, abs(best_value - old_value))

            if iteration % 20 == 0:
                logger.info(f"迭代 {iteration}: delta={delta:.4f}")

            if delta < tolerance:
                logger.info(f"收敛于第{iteration + 1}次迭代")
                break

        logger.info("值迭代完成")

    def extract_policy(self, initial_state: 'DrillingState') -> List['DrillingAction']:
        """Roll out the tabular policy from *initial_state* on the continuous state.

        If a grid cell has no policy entry, the valid action with the highest
        bit efficiency is used instead (first such action in enumeration order).

        The rollout stops at a terminal state, when no action is valid, or
        after 200 steps.
        """
        logger.info("提取策略...")
        policy_path = []
        current_state = initial_state
        max_steps = 200
        step = 0

        while not current_state.is_terminal(self.target_depth) and step < max_steps:
            state_key = self.get_nearest_state_key(current_state)
            action = self.policy_table.get(state_key)

            if action is None:
                logger.warning(f"步骤{step}: 无策略，使用启发式")
                actions = self.get_valid_actions(current_state)
                if not actions:
                    break
                action = max(actions, key=lambda a: self.data_gen.bit_types[a.bit_type_idx]['efficiency'])

            policy_path.append(action)
            current_state, _ = self.transition(current_state, action)
            step += 1

        logger.info(f"策略提取完成: {len(policy_path)}步, 深度{current_state.depth:.0f}m")
        return policy_path


# ==================== 5. 主程序 ====================

def main():
    """Run the full pipeline: build the model, optimise, replay and report.

    Steps:
    1. Create the physics data generator and optionally train the ROP network.
    2. Run value iteration.
    3. Extract the policy from the initial state and replay it step by step.
    4. Print a summary and write the per-step plan to
       ``optimized_drilling_plan.csv``.

    Returns:
        (optimizer, simulation_results, df_results): the fitted optimizer, the
        per-step result dicts, and the same records as a DataFrame.

    Raises:
        Any exception from the pipeline is logged and re-raised.
    """
    print("=" * 70)
    print("DP动态规划钻井参数优化系统 - 优化版本")
    print("第十五届中国石油工程设计大赛")
    print("=" * 70)

    TARGET_DEPTH = 4000
    TOTAL_BUDGET = 2000000
    USE_NEURAL_NETWORK = False

    try:
        logger.info("[1/5] 初始化...")
        data_gen = DrillingDataGenerator()

        rop_predictor = None
        if USE_NEURAL_NETWORK:
            logger.info("[2/5] 训练神经网络...")
            rop_predictor = ROPPredictor()
            rop_predictor.train_model(data_gen, epochs=100)
        else:
            logger.info("[2/5] 使用物理模型")

        logger.info("[3/5] 初始化优化器...")
        optimizer = DynamicProgrammingOptimizer(
            data_gen, target_depth=TARGET_DEPTH,
            total_budget=TOTAL_BUDGET, use_nn_approx=USE_NEURAL_NETWORK
        )
        optimizer.rop_predictor = rop_predictor

        logger.info("[4/5] 执行优化...")
        optimizer.value_iteration(max_iterations=100)

        logger.info("[5/5] 模拟执行...")
        initial_state = DrillingState(0, 0, TOTAL_BUDGET, -1)
        policy_path = optimizer.extract_policy(initial_state)

        # Step-by-step replay for the report
        current_state = initial_state
        simulation_results = []

        for i, action in enumerate(policy_path):
            next_state, _ = optimizer.transition(current_state, action)
            # NOTE: the reported ROP always comes from the physics model, even
            # if the optimizer was configured to use the neural network.
            rop = data_gen.calculate_rop(
                action.wob, action.rpm, action.flow_rate,
                current_state.depth, action.bit_type_idx
            )
            bit = data_gen.bit_types[action.bit_type_idx]
            # Report what the transition actually did: depth gained (capped at
            # the target) and the amount deducted from the budget. This way the
            # per-step columns add up to the summary totals below.
            footage = next_state.depth - current_state.depth
            cost = current_state.remaining_budget - next_state.remaining_budget

            simulation_results.append({
                'step': i + 1,
                'depth_start': current_state.depth,
                'depth_end': next_state.depth,
                'wob': action.wob,
                'rpm': action.rpm,
                'flow_rate': action.flow_rate,
                'bit_type': bit['name'],
                'bit_changed': (action.bit_type_idx != current_state.current_bit_idx),
                'duration': action.duration,
                'rop': rop,
                'footage': footage,
                'cost': cost,
                'bit_wear': next_state.bit_wear,
                'remaining_budget': next_state.remaining_budget
            })

            current_state = next_state
            if current_state.depth >= TARGET_DEPTH:
                break

        # Summary
        total_cost = TOTAL_BUDGET - current_state.remaining_budget
        total_time = sum(r['duration'] for r in simulation_results)
        avg_rop = current_state.depth / total_time if total_time > 0 else 0

        print(f"\n{'=' * 70}")
        print("优化结果汇总:")
        print(f"{'=' * 70}")
        print(f"最终深度: {current_state.depth:.0f}m / {TARGET_DEPTH}m")
        print(f"总成本: {total_cost:.0f}元 / {TOTAL_BUDGET}元")
        print(f"总时间: {total_time:.1f}小时 ({total_time / 24:.1f}天)")
        print(f"平均机械钻速: {avg_rop:.2f}m/h")
        print(f"钻头更换次数: {sum(r['bit_changed'] for r in simulation_results)}")

        df_results = pd.DataFrame(simulation_results)
        df_results.to_csv('optimized_drilling_plan.csv', index=False, encoding='utf-8-sig')
        print(f"\n详细方案已保存至: optimized_drilling_plan.csv")

        return optimizer, simulation_results, df_results

    except Exception as e:
        logger.error(f"运行出错: {e}")
        raise


if __name__ == "__main__":
    # Script entry point. The results are bound to module-level names so they
    # can be inspected afterwards in an interactive session.
    optimizer, results, df = main()