import numpy as np
import pandas as pd
import rasterio
from rasterio.transform import from_origin
import matplotlib.pyplot as plt
from matplotlib.colors import LinearSegmentedColormap, BoundaryNorm

# =========================================================
# 0. 参数设置
# =========================================================
RES = 0.5
WEST, EAST = -180.0, 180.0
SOUTH, NORTH = -90.0, 90.0

WIDTH = int((EAST - WEST) / RES)    # 720
HEIGHT = int((NORTH - SOUTH) / RES) # 360

CRS = "EPSG:4326"
NODATA_FLOAT = -9999.0
NODATA_INT = 0

OUTPUT_CSV = "grassland_n2o_effect_0p5deg.csv"
OUTPUT_TIF_LNRR = "grassland_n2o_effect_lnrr_0p5deg.tif"
OUTPUT_TIF_PCT = "grassland_n2o_change_pct_0p5deg.tif"
OUTPUT_TIF_CLASS = "grassland_n2o_class8_0p5deg.tif"
OUTPUT_LEGEND = "arcgis_8class_color_table.csv"
OUTPUT_PNG = "global_n2o_change_pct_map.png"

transform = from_origin(WEST, NORTH, RES, RES)

# =========================================================
# 1. 构建 0.5° 全球格网中心点
# 注意：纬度按从北到南排列，以便直接写栅格
# =========================================================
lon_centers = np.arange(WEST + RES / 2, EAST, RES)          # -179.75 ... 179.75
lat_centers = np.arange(NORTH - RES / 2, SOUTH, -RES)       # 89.75 ... -89.75

lon2d, lat2d = np.meshgrid(lon_centers, lat_centers)
abs_lat = np.abs(lat2d)

# =========================================================
# 2. 工具函数
# =========================================================
def clip01(x):
    return np.clip(x, 0.0, 1.0)

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

# =========================================================
# 3. 近似全球草地掩膜（模拟图1橘色区域的大体空间格局）
# 说明：
# 这不是正式草地产品，只是趋势预测用掩膜。
# 若你有真实草地栅格，请用真实掩膜替换这里。
# =========================================================
grass_mask = np.ones_like(lat2d, dtype=bool)

# 排除极地
grass_mask &= (abs_lat <= 75)

# 排除典型湿润热带森林核心区
amazon = (lat2d >= -15) & (lat2d <= 8) & (lon2d >= -80) & (lon2d <= -48)
congo = (lat2d >= -8) & (lat2d <= 6) & (lon2d >= 10) & (lon2d <= 32)
se_asia = (lat2d >= -8) & (lat2d <= 20) & (lon2d >= 95) & (lon2d <= 135)
new_guinea = (lat2d >= -10) & (lat2d <= 2) & (lon2d >= 135) & (lon2d <= 155)

# 排除典型极端荒漠核心区（保留草原/稀疏草地外围）
sahara_core = (lat2d >= 18) & (lat2d <= 30) & (lon2d >= -15) & (lon2d <= 30)
arabia_core = (lat2d >= 16) & (lat2d <= 30) & (lon2d >= 35) & (lon2d <= 58)

grass_mask &= ~(amazon | congo | se_asia | new_guinea | sahara_core | arabia_core)

# 加一点“草地带”约束，让分布更接近草地而非全部陆地
belt_mask = (
    ((abs_lat >= 50) & (abs_lat <= 75)) |   # 高纬寒冷草地
    ((abs_lat >= 25) & (abs_lat < 50))  |   # 温带草地
    ((abs_lat >= 5)  & (abs_lat < 25))      # 热带/亚热带稀树草原
)
grass_mask &= belt_mask

# =========================================================
# 4. 环境控制因子：水热条件 & 冻融强度
# 为了保证“颜色过渡自然”，这里使用连续平滑函数
# =========================================================

# ---- 4.1 温度适宜度 temp_score ----
# 峰值大致在中低纬（草地活动、生物过程较活跃）
temp_score = 1.0 - ((abs_lat - 28.0) / 38.0) ** 2
temp_score = clip01(temp_score)

# 热带过湿林区外围略降，防止赤道带全部变成极强负值
equator_penalty = 0.18 * np.exp(-(abs_lat / 10.0) ** 2)
temp_score = clip01(temp_score - equator_penalty)

# 青藏高原降温修正
tibetan = (lat2d >= 27) & (lat2d <= 40) & (lon2d >= 75) & (lon2d <= 105)
temp_score = np.where(tibetan, temp_score - 0.18, temp_score)
temp_score = clip01(temp_score)

# ---- 4.2 水分适宜度 moisture_score ----
# 先构造平滑的纬向/经向变化
moisture_score = (
    0.56
    + 0.16 * np.cos(np.radians(lat2d * 1.4))
    + 0.10 * np.sin(np.radians((lon2d + 20.0) * 0.8))
)

# 区域修正：干旱区降低，湿润区稍提高
central_asia = (lat2d >= 35) & (lat2d <= 50) & (lon2d >= 55) & (lon2d <= 100)
west_na = (lat2d >= 30) & (lat2d <= 50) & (lon2d >= -125) & (lon2d <= -105)
patagonia = (lat2d >= -55) & (lat2d <= -35) & (lon2d >= -75) & (lon2d <= -60)
sahel = (lat2d >= 8) & (lat2d <= 18) & (lon2d >= -20) & (lon2d <= 35)
europe_wet = (lat2d >= 42) & (lat2d <= 58) & (lon2d >= -10) & (lon2d <= 30)
east_us = (lat2d >= 30) & (lat2d <= 48) & (lon2d >= -95) & (lon2d <= -75)

moisture_score = np.where(sahara_core, moisture_score - 0.35, moisture_score)
moisture_score = np.where(arabia_core, moisture_score - 0.28, moisture_score)
moisture_score = np.where(central_asia, moisture_score - 0.18, moisture_score)
moisture_score = np.where(west_na, moisture_score - 0.12, moisture_score)
moisture_score = np.where(patagonia, moisture_score - 0.10, moisture_score)
moisture_score = np.where(sahel, moisture_score - 0.08, moisture_score)
moisture_score = np.where(europe_wet, moisture_score + 0.08, moisture_score)
moisture_score = np.where(east_us, moisture_score + 0.06, moisture_score)

moisture_score = clip01(moisture_score)

# ---- 4.3 水热综合 hydrothermal_score ----
# 几何平均较稳健，避免单因子极值过度放大
hydrothermal_score = np.sqrt(temp_score * moisture_score)
hydrothermal_score = clip01(hydrothermal_score)

# ---- 4.4 冻融强度 freeze_strength ----
# 高纬冻融连续增强
freeze_strength = sigmoid((abs_lat - 55.0) / 4.5)

# 地形寒冷区增强：青藏高原、安第斯
andes = (lat2d >= -55) & (lat2d <= -10) & (lon2d >= -80) & (lon2d <= -65)
freeze_strength = np.where(tibetan, np.maximum(freeze_strength, 0.72), freeze_strength)
freeze_strength = np.where(andes, np.maximum(freeze_strength, 0.60), freeze_strength)

freeze_strength = clip01(freeze_strength)

# =========================================================
# 5. 按机制生成 lnRR
# 你的要求：
#   水热好 -> 负值（-0.25 ~ 0）
#   寒冷/冻融/条件差 -> 正值（0 ~ 0.15）
#
# 为了避免“硬边界”，这里采用平滑混合：
#   neg_component : 代表水热驱动的负效应
#   pos_component : 代表冻融/冷干胁迫驱动的正效应
#   transition_w  : 由环境条件连续控制
# =========================================================

# 负效应分量：水热越好，负值越大（越接近 -0.25）
neg_component = -0.25 * (hydrothermal_score ** 1.20)

# 正效应分量：冻融越强、且水热越差，正值越大（越接近 0.15）
cold_dry_index = clip01(0.70 * freeze_strength + 0.30 * (1.0 - hydrothermal_score))
pos_component = 0.15 * (cold_dry_index ** 1.10)

# 平滑过渡权重
transition_w = sigmoid((cold_dry_index - 0.52) * 10.0)

# 混合得到 lnRR
effect_lnrr = (1.0 - transition_w) * neg_component + transition_w * pos_component

# 严格限制在要求区间内
effect_lnrr = np.clip(effect_lnrr, -0.25, 0.15)

# 非草地区域设为 nodata
effect_lnrr = np.where(grass_mask, effect_lnrr, NODATA_FLOAT)

# =========================================================
# 6. 转换为 (e^lnRR - 1) * 100
# =========================================================
n2o_change_pct = np.where(
    grass_mask,
    (np.exp(effect_lnrr) - 1.0) * 100.0,
    NODATA_FLOAT
)

# 理论范围
pct_min = (np.exp(-0.25) - 1.0) * 100.0   # 约 -22.12
pct_max = (np.exp(0.15) - 1.0) * 100.0    # 约 16.18

# =========================================================
# 7. 八级分类
# 这里使用固定的 8 等间隔梯度，适合 ArcGIS 分级着色
# =========================================================
class_edges = np.linspace(pct_min, pct_max, 9)  # 9个边界 -> 8个等级

effect_class_id = np.full_like(lat2d, NODATA_INT, dtype=np.int16)

valid_vals = grass_mask & (n2o_change_pct != NODATA_FLOAT)
effect_class_id[valid_vals] = np.digitize(
    n2o_change_pct[valid_vals],
    bins=class_edges[1:-1],
    right=False
) + 1

# 防止极少数边界值越界
effect_class_id = np.where(effect_class_id > 8, 8, effect_class_id)
effect_class_id = np.where(grass_mask, effect_class_id, NODATA_INT)

# =========================================================
# 8. 输出 CSV（仅草地像元）
# =========================================================
df = pd.DataFrame({
    "lon": lon2d[grass_mask].ravel(),
    "lat": lat2d[grass_mask].ravel(),
    "grassland_mask": grass_mask[grass_mask].astype(int).ravel(),
    "temp_score": temp_score[grass_mask].ravel(),
    "moisture_score": moisture_score[grass_mask].ravel(),
    "hydrothermal_score": hydrothermal_score[grass_mask].ravel(),
    "freeze_strength": freeze_strength[grass_mask].ravel(),
    "effect_lnrr": effect_lnrr[grass_mask].ravel(),
    "n2o_change_pct": n2o_change_pct[grass_mask].ravel(),
    "effect_class_id": effect_class_id[grass_mask].ravel()
})

# 给出每个像元的理论范围边界，便于检查
df["xmin"] = df["lon"] - RES / 2
df["xmax"] = df["lon"] + RES / 2
df["ymin"] = df["lat"] - RES / 2
df["ymax"] = df["lat"] + RES / 2

# 保留精度
for col in ["lon", "lat", "xmin", "xmax", "ymin", "ymax"]:
    df[col] = df[col].round(2)

for col in ["temp_score", "moisture_score", "hydrothermal_score", "freeze_strength"]:
    df[col] = df[col].round(4)

df["effect_lnrr"] = df["effect_lnrr"].round(4)
df["n2o_change_pct"] = df["n2o_change_pct"].round(2)

df.to_csv(OUTPUT_CSV, index=False, encoding="utf-8-sig")

# =========================================================
# 9. 输出 GeoTIFF
# =========================================================
def write_tif(path, arr, dtype, nodata):
    with rasterio.open(
        path,
        "w",
        driver="GTiff",
        height=HEIGHT,
        width=WIDTH,
        count=1,
        dtype=dtype,
        crs=CRS,
        transform=transform,
        nodata=nodata,
        compress="lzw"
    ) as dst:
        dst.write(arr.astype(dtype), 1)

write_tif(OUTPUT_TIF_LNRR, effect_lnrr, "float32", NODATA_FLOAT)
write_tif(OUTPUT_TIF_PCT, n2o_change_pct, "float32", NODATA_FLOAT)
write_tif(OUTPUT_TIF_CLASS, effect_class_id, "int16", NODATA_INT)

# =========================================================
# 10. 输出 ArcGIS 8级颜色表
# 建议：负值(蓝) -> 正值(红)
# =========================================================
# 颜色从深蓝到深红，过渡较自然
class_colors = {
    1: "#2166AC",
    2: "#4393C3",
    3: "#92C5DE",
    4: "#D1E5F0",
    5: "#FDDBC7",
    6: "#F4A582",
    7: "#D6604D",
    8: "#B2182B"
}

legend_df = pd.DataFrame({
    "class_id": list(range(1, 9)),
    "lower_pct": class_edges[:-1],
    "upper_pct": class_edges[1:],
    "color_hex": [class_colors[i] for i in range(1, 9)]
})

legend_df["lower_pct"] = legend_df["lower_pct"].round(2)
legend_df["upper_pct"] = legend_df["upper_pct"].round(2)
legend_df.to_csv(OUTPUT_LEGEND, index=False, encoding="utf-8-sig")

# =========================================================
# 11. 可选：输出 PNG 预览图
# =========================================================
# 连续色带
continuous_colors = [
    "#2166AC", "#4393C3", "#92C5DE", "#D1E5F0",
    "#F7F7F7",
    "#FDDBC7", "#F4A582", "#D6604D", "#B2182B"
]
cmap_cont = LinearSegmentedColormap.from_list("n2o_div", continuous_colors, N=256)

# 分类色带
class_color_list = [class_colors[i] for i in range(1, 9)]
cmap_class = LinearSegmentedColormap.from_list("class8", class_color_list, N=8)
norm_class = BoundaryNorm(class_edges, cmap_class.N)

plot_arr = np.where(grass_mask, n2o_change_pct, np.nan)

fig = plt.figure(figsize=(16, 8))
ax = plt.gca()

im = ax.imshow(
    plot_arr,
    extent=[WEST, EAST, SOUTH, NORTH],
    origin="upper",
    cmap=cmap_cont,
    vmin=pct_min,
    vmax=pct_max,
    interpolation="bilinear"
)

ax.set_title("Predicted Global Grassland Grazing Effect on N$_2$O Emission\n((e$^{lnRR}$ - 1) × 100, 0.5°)", fontsize=14)
ax.set_xlabel("Longitude")
ax.set_ylabel("Latitude")

cb = plt.colorbar(im, ax=ax, shrink=0.82, pad=0.02)
cb.set_label("N$_2$O change (%)")

plt.tight_layout()
plt.savefig(OUTPUT_PNG, dpi=300, bbox_inches="tight")
plt.close()

# =========================================================
# 12. 打印结果
# =========================================================
print("完成：")
print(f"CSV: {OUTPUT_CSV}")
print(f"TIF: {OUTPUT_TIF_LNRR}")
print(f"TIF: {OUTPUT_TIF_PCT}")
print(f"TIF: {OUTPUT_TIF_CLASS}")
print(f"Legend: {OUTPUT_LEGEND}")
print(f"PNG: {OUTPUT_PNG}")
print()
print("8级分类边界（%）:")
for i in range(8):
    print(f"Class {i+1}: {class_edges[i]:.2f} ~ {class_edges[i+1]:.2f}")
print()
print(f"草地像元数: {len(df)}")
print(df.head(10))
