"""
批量分析长条形磁铁摆动视频
逐帧检测磁铁中心坐标，输出到 Excel（原点在视频左下角，单位：像素）
"""

import cv2
import numpy as np
import openpyxl
import os
import glob
import sys


# ─────────────────────── 可调参数 ────────────────────────
# 红色N极 HSV范围
RED_LOWER1 = np.array([0,   100, 80])
RED_UPPER1 = np.array([10,  255, 255])
RED_LOWER2 = np.array([165, 100, 80])
RED_UPPER2 = np.array([180, 255, 255])

# 轴线扫描参数
SCAN_RANGE     = 500     # 从N极中心沿轴线最大扫描距离（像素）
MAGNET_GRAY_THR = 130    # 灰度低于此值 = 磁铁，高于 = 背景
STRIP_PAD      = 5       # 垂直采样条带额外宽度（像素）

# 形态学核大小（去噪）
MORPH_K = 7

# 最小有效轮廓面积（像素²），过滤噪点
MIN_AREA = 300

# 红色N极的最小长宽比（用于排除非磁铁的红色噪点）
RED_MIN_ASPECT = 1.5

# 每个视频保存多少张预览图（均匀抽帧），0 表示不保存
PREVIEW_COUNT = 10
# ──────────────────────────────────────────────────────────


def detect_magnet(frame: np.ndarray, return_debug=False):
    """
    在单帧中检测长条形磁铁的中心（图像坐标系，Y 向下）。
    return_debug=False: 返回 (cx, cy) 或 None
    return_debug=True : 返回 ((cx, cy), box_pts, debug_mask) 或 (None, None, None)

    算法：
      1. 检测红色 N 极 → 定位参考 + 长轴方向
      2. 沿长轴采样灰度剖面，找到磁铁两端（暗→亮跳变）
      3. 两端中点 = 矩形轮廓几何中心
    """
    h_f, w_f = frame.shape[:2]
    hsv  = cv2.cvtColor(frame, cv2.COLOR_BGR2HSV)
    gray = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)
    k    = cv2.getStructuringElement(cv2.MORPH_RECT, (MORPH_K, MORPH_K))

    _fail = (None, None, None) if return_debug else None

    # ── 1. 检测红色 N 极 ──────────────────────────────────────────
    mask_r1  = cv2.inRange(hsv, RED_LOWER1, RED_UPPER1)
    mask_r2  = cv2.inRange(hsv, RED_LOWER2, RED_UPPER2)
    mask_red = cv2.bitwise_or(mask_r1, mask_r2)
    mask_red = cv2.morphologyEx(mask_red, cv2.MORPH_CLOSE, k)
    mask_red = cv2.morphologyEx(mask_red, cv2.MORPH_OPEN,  k)

    contours_r, _ = cv2.findContours(
        mask_red, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    contours_r = [c for c in contours_r if cv2.contourArea(c) >= MIN_AREA]
    if not contours_r:
        return _fail

    best_red  = max(contours_r, key=cv2.contourArea)
    red_rect  = cv2.minAreaRect(best_red)
    cx_r, cy_r = red_rect[0]
    w_r, h_r   = red_rect[1]

    if min(w_r, h_r) == 0:
        return _fail
    if max(w_r, h_r) / min(w_r, h_r) < RED_MIN_ASPECT:
        return _fail

    # ── 2. 计算 N 极长轴方向 ─────────────────────────────────────
    angle_r = red_rect[2]
    if w_r >= h_r:
        angle_rad = np.deg2rad(angle_r)
    else:
        angle_rad = np.deg2rad(angle_r + 90.0)
    short_half = min(w_r, h_r) / 2.0

    dx  = float(np.cos(angle_rad))
    dy  = float(np.sin(angle_rad))
    pdx = -dy                          # 垂直方向
    pdy =  dx

    # ── 3. 沿轴线采样灰度剖面（向量化） ─────────────────────────
    half_w  = int(short_half) + STRIP_PAD
    scan    = np.arange(-SCAN_RANGE, SCAN_RANGE + 1, dtype=float)
    perp    = np.arange(-half_w, half_w + 1, dtype=float)

    sx = np.round(cx_r + dx * scan[:, None] + pdx * perp[None, :]).astype(int)
    sy = np.round(cy_r + dy * scan[:, None] + pdy * perp[None, :]).astype(int)
    np.clip(sx, 0, w_f - 1, out=sx)
    np.clip(sy, 0, h_f - 1, out=sy)

    samples     = gray[sy, sx]
    profile     = samples.mean(axis=1)          # 灰度均值剖面
    profile_var = samples.var(axis=1)           # 灰度方差剖面（纹理）
    red_profile = mask_red[sy, sx].mean(axis=1) # 红色掩膜剖面

    # ── 4. 分别确定 N 极端和 S 极端 ─────────────────────────────
    center_idx = SCAN_RANGE
    RED_EDGE_THR = 30   # 红色掩膜均值低于此 → 不再是红色
    TEX_WIN = 30        # 纹理检测窗口（像素数）

    if profile[center_idx] >= MAGNET_GRAY_THR:
        return _fail

    # 4a. 红色掩膜两个方向的消失点
    red_pos_end = len(scan) - 1
    for i in range(center_idx + 1, len(scan)):
        if red_profile[i] < RED_EDGE_THR:
            red_pos_end = i
            break

    red_neg_end = 0
    for i in range(center_idx - 1, -1, -1):
        if red_profile[i] < RED_EDGE_THR:
            red_neg_end = i
            break

    # 4b. 判断 S 极方向：在红色消失后，扫描一个 N 极长度的窗口
    #     S 极 = 纯黑，暗像素占比接近 1.0
    #     收纳盒/背景 = 有亮色间隙，暗像素占比 < S 极
    n_pole_len = int(max(w_r, h_r))  # N 极的完整长度（像素）
    SKIP = 8                         # 跳过红色消失点附近的过渡区

    def _dark_fraction(start, direction):
        idxs = [start + direction * (j + SKIP)
                for j in range(n_pole_len)
                if 0 <= start + direction * (j + SKIP) < len(profile)]
        if not idxs:
            return 0.0
        return float((profile[np.array(idxs)] < MAGNET_GRAY_THR).mean())

    frac_pos = _dark_fraction(red_pos_end, +1)  # +方向暗像素占比
    frac_neg = _dark_fraction(red_neg_end, -1)  # -方向暗像素占比

    # 暗像素占比更高的一侧 = S 极（纯黑均匀）
    if frac_neg >= frac_pos:
        s_scan_dir = -1          # S 极在 neg 方向
        n_end_idx  = red_pos_end # N 极自由端在 pos 方向
    else:
        s_scan_dir = +1
        n_end_idx  = red_neg_end

    # 4c. S 极端：沿 S 极方向做灰度梯度扫描
    rough_s = center_idx
    if s_scan_dir == -1:
        for i in range(center_idx - 1, -1, -1):
            if profile[i] >= MAGNET_GRAY_THR:
                rough_s = i
                break
    else:
        for i in range(center_idx + 1, len(scan)):
            if profile[i] >= MAGNET_GRAY_THR:
                rough_s = i
                break

    win  = 30
    grad = np.diff(profile)
    g1 = max(0, rough_s - win)
    g2 = min(len(grad), rough_s + win)
    s_end_idx = g1 + int(np.argmax(np.abs(grad[g1:g2])))
    if s_scan_dir == -1:
        s_end_idx += 1   # diff 偏移校正

    left_idx  = min(n_end_idx, s_end_idx)
    right_idx = max(n_end_idx, s_end_idx)

    left_d  = scan[left_idx]
    right_d = scan[right_idx]
    mid_d   = (left_d + right_d) / 2.0
    mag_len = right_d - left_d

    if mag_len < 50:                       # 磁铁长度至少50px，否则视为噪点
        return _fail

    cx = int(round(cx_r + dx * mid_d))
    cy = int(round(cy_r + dy * mid_d))

    # ── 5. 构造包围框 ────────────────────────────────────────────
    box_angle = float(np.rad2deg(np.arctan2(dy, dx)))
    box_rect  = ((cx_r + dx * mid_d, cy_r + dy * mid_d),
                 (float(mag_len), (short_half + STRIP_PAD) * 2),
                 box_angle)
    box_pts = cv2.boxPoints(box_rect).astype(np.int32)

    if return_debug:
        debug_mask = np.zeros((h_f, w_f), dtype=np.uint8)
        d_rng  = np.arange(left_d, right_d + 1, dtype=float)
        mx = np.round(cx_r + dx * d_rng[:, None] + pdx * perp[None, :]).astype(int)
        my = np.round(cy_r + dy * d_rng[:, None] + pdy * perp[None, :]).astype(int)
        valid = (mx >= 0) & (mx < w_f) & (my >= 0) & (my < h_f)
        debug_mask[my[valid], mx[valid]] = 255
        return (cx, cy), box_pts, debug_mask
    return cx, cy


def draw_preview(frame: np.ndarray, center, box_pts, mask,
                 frame_idx: int, t: float, height: int) -> np.ndarray:
    """在帧上叠加检测结果，返回标注后的图像。"""
    vis = frame.copy()

    if center is not None:
        cx, cy = center
        x_out = cx
        y_out = height - 1 - cy

        # 画旋转包围框（绿色）
        cv2.drawContours(vis, [box_pts], 0, (0, 255, 0), 3)

        # 画十字准线（黄色）
        arm = 40
        cv2.line(vis, (cx - arm, cy), (cx + arm, cy), (0, 255, 255), 3)
        cv2.line(vis, (cx, cy - arm), (cx, cy + arm), (0, 255, 255), 3)

        # 画中心点（红色实心圆）
        cv2.circle(vis, (cx, cy), 10, (0, 0, 255), -1)

        # 文字信息
        label = f"Frame {frame_idx}  t={t:.3f}s"
        coord = f"X={x_out}px  Y={y_out}px (origin:bottom-left)"
        font  = cv2.FONT_HERSHEY_SIMPLEX
        cv2.putText(vis, label, (20, 60),  font, 1.8, (0, 0, 0),   6)
        cv2.putText(vis, label, (20, 60),  font, 1.8, (255,255,255), 3)
        cv2.putText(vis, coord, (20, 120), font, 1.8, (0, 0, 0),   6)
        cv2.putText(vis, coord, (20, 120), font, 1.8, (0, 255, 255), 3)
    else:
        font = cv2.FONT_HERSHEY_SIMPLEX
        msg  = f"Frame {frame_idx}  t={t:.3f}s  [未检测到磁铁]"
        cv2.putText(vis, msg, (20, 60), font, 1.8, (0, 0, 0),   6)
        cv2.putText(vis, msg, (20, 60), font, 1.8, (0, 80, 255), 3)

    return vis


def process_video(video_path: str, ws, preview_dir: str = None):
    """逐帧处理一个视频，结果写入 openpyxl worksheet。
    preview_dir: 若非 None，则在该目录保存均匀抽帧的预览图。
    """
    cap = cv2.VideoCapture(video_path)
    if not cap.isOpened():
        print(f"  [错误] 无法打开视频: {video_path}")
        return

    fps    = cap.get(cv2.CAP_PROP_FPS)
    height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
    total  = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))

    # 计算需要保存预览的帧索引集合（均匀分布）
    preview_indices = set()
    if preview_dir and PREVIEW_COUNT > 0 and total > 0:
        os.makedirs(preview_dir, exist_ok=True)
        step = max(1, total // PREVIEW_COUNT)
        preview_indices = set(range(0, total, step))
        preview_indices = set(list(sorted(preview_indices))[:PREVIEW_COUNT])

    ws.append(['时间(s)', 'X(像素)', 'Y(像素)'])

    frame_idx = 0
    found_count = 0

    while True:
        ret, frame = cap.read()
        if not ret:
            break

        t = round(frame_idx / fps, 6)
        need_debug = (preview_dir is not None) and (frame_idx in preview_indices)

        if need_debug:
            center, box_pts, mask = detect_magnet(frame, return_debug=True)
            result = center
        else:
            result = detect_magnet(frame)

        if result is not None:
            cx, cy = result
            x = cx
            y = height - 1 - cy   # 原点翻转到左下角
            ws.append([t, x, y])
            found_count += 1
        else:
            ws.append([t, None, None])

        # 保存预览图
        if need_debug:
            vis = draw_preview(frame, center, box_pts if result else None,
                               mask if result else None,
                               frame_idx, t, height)
            img_name = f"{frame_idx:06d}_t{t:.3f}s.jpg"
            ok, buf = cv2.imencode('.jpg', vis, [cv2.IMWRITE_JPEG_QUALITY, 90])
            if ok:
                with open(os.path.join(preview_dir, img_name), 'wb') as f:
                    f.write(buf.tobytes())

        frame_idx += 1

        if frame_idx % 200 == 0:
            pct = frame_idx / total * 100 if total > 0 else 0
            print(f"  进度: {frame_idx}/{total} ({pct:.1f}%)  "
                  f"已检测到磁铁: {found_count} 帧", end='\r')

    cap.release()
    print(f"\n  完成: 共 {frame_idx} 帧，检测到磁铁 {found_count} 帧")


def main():
    video_dir   = os.path.join(os.path.dirname(__file__), '视频')
    output_path = os.path.join(os.path.dirname(__file__), '磁铁坐标.xlsx')

    exts = ('*.mp4', '*.avi', '*.mov', '*.mkv', '*.MP4', '*.AVI', '*.MOV')
    video_files = []
    for ext in exts:
        video_files.extend(glob.glob(os.path.join(video_dir, ext)))
    video_files = sorted(set(video_files))

    if not video_files:
        print(f"在 {video_dir} 中未找到视频文件")
        sys.exit(1)

    print(f"找到 {len(video_files)} 个视频文件")

    wb = openpyxl.Workbook()
    wb.remove(wb.active)   # 删除默认空 sheet

    preview_root = os.path.join(os.path.dirname(__file__), '检测预览')

    for vf in video_files:
        name = os.path.splitext(os.path.basename(vf))[0]
        # Excel 工作表名最长 31 字符，且不含特殊符号
        sheet_name = name[:31].replace('/', '_').replace('\\', '_').replace('?', '_') \
                               .replace('*', '_').replace('[', '_').replace(']', '_') \
                               .replace(':', '_')
        print(f"\n处理视频: {name}")
        ws = wb.create_sheet(title=sheet_name)
        preview_dir = os.path.join(preview_root, name) if PREVIEW_COUNT > 0 else None
        process_video(vf, ws, preview_dir=preview_dir)

    wb.save(output_path)
    print(f"\nExcel 已保存至: {output_path}")
    if PREVIEW_COUNT > 0:
        print(f"预览图已保存至: {preview_root}")


if __name__ == '__main__':
    main()
