import cv2
from cv2.typing import Scalar
import numpy as np
import time

# Number of CUDA-capable devices OpenCV can use. OpenCV returns 0 when it was
# built without CUDA and -1 when the CUDA driver is missing or incompatible,
# so any value below 1 means the cv2.cuda calls in this script cannot work.
has_cuda = cv2.cuda.getCudaEnabledDeviceCount()
print(f"has cuda: {has_cuda}, needs to be 1")


# `< 1` (not `== 0`) so the -1 "driver missing/incompatible" case is reported too.
# NOTE(review): execution continues after this warning; the GPU calls will fail later.
if has_cuda < 1:
    print("you dont have cuda installed stupid fuck")

# Starting input clips. get_next_left() builds follow-up clip names with the
# same "<prefix>_NNNN.mp4" pattern.
current_left = "assets/vids/Left_0009.mp4"
current_right = "assets/vids/Right_0009.mp4"

# NOTE(review): not read by the visible code; start() writes "final.avi".
final = "Final.mp4"

# Frame to start playback from, and how many extra frames the left clip is
# advanced (presumably to synchronise the two cameras).
CLIP_START_FROM = 0
CLIP_OFFSET = 1

# 1000 / 25 -> 40 ms, i.e. one frame period at 25 fps.
delay = int(1000 / 25)

# Capture handles; start() rebinds these module globals.
clip_left = None
clip_right = None

# Running timing total and processed-frame counter, updated by start().
total_time_taken = 0
index = 0

# Clip-rollover counters (start() keeps its own local copies).
current_video_index_left = 0
current_video_index_right = 0

# Module-level GpuMat objects, default-constructed (no size or data given).
# NOTE(review): start() creates its own local frame1_gpu/frame2_gpu and
# final_frame_gpu is never read, so these look unused by the visible code.
frame1_gpu, frame2_gpu, final_frame_gpu = (
    cv2.cuda.GpuMat(),
    cv2.cuda.GpuMat(),
    cv2.cuda.GpuMat(),
)


# Vertical alignment helper used before stitching the two camera views.
def shift_image_vertically(image, shift):
    """
    Shift an image along the y-axis (axis 0), keeping its shape unchanged.

    Rows pushed out of the frame are dropped; the vacated rows are filled
    with zeros.

    Parameters:
    - image: The input image (2D or 3D array).
    - shift: The number of pixels to shift. Positive values shift downward,
        negative values shift upward.

    Returns:
    - Shifted image with the same shape and dtype as ``image`` (``image``
      itself when ``shift`` is 0).
    """
    if shift == 0:
        return image

    rows = abs(shift)
    # Only axis 0 is padded; the spec is built from image.ndim so that both
    # 2D and 3D inputs work, and from `shift` itself rather than a global.
    other_axes = ((0, 0),) * (image.ndim - 1)

    if shift > 0:
        # Zeros on top, then drop the same number of rows from the bottom.
        padded = np.pad(image, ((rows, 0),) + other_axes, mode="constant")
        return padded[:-rows]

    # Zeros at the bottom, then drop the same number of rows from the top.
    padded = np.pad(image, ((0, rows),) + other_axes, mode="constant")
    return padded[rows:]


# Default vertical offset in pixels, its magnitude, and a matching np.pad spec
# for a 3-axis image: pad `ab` rows at the top of axis 0, nothing elsewhere.
shift = 10
ab = shift if shift >= 0 else -shift
pad_width = ((ab, 0),) + ((0, 0),) * 2


def stitch_images(image_left, image_right, right_x_shift, right_y_shift):
    """Place ``image_left`` and a vertically shifted ``image_right`` side by side.

    ``image_right`` is passed through shift_image_vertically with
    ``right_y_shift`` before stacking. ``right_x_shift`` is accepted but not
    used. np.hstack requires the two arrays to agree on every axis except
    the width axis.
    """
    shifted_right = shift_image_vertically(image_right, right_y_shift)
    pair = (image_left, shifted_right)
    return np.hstack(pair)


def init_undistort(image, width, height, map1, map2, cof, focal_length=3000):
    """Build undistortion remap tables for frames shaped like ``image``.

    A pinhole camera matrix is synthesised with ``focal_length`` on both axes
    and the principal point at the centre of ``image``. The same matrix is
    used as the new camera matrix, with no rectification rotation.

    Parameters:
    - image: sample frame; only its shape (rows, cols) is used.
    - width, height: size of the output image the maps are built for.
    - map1, map2: ignored (kept for call compatibility); the maps are returned.
    - cof: at least three distortion coefficients, used as the first three
      entries of OpenCV's (k1, k2, p1, p2, k3) vector; p2 and k3 are 0.
    - focal_length: focal length in pixels (default 3000, the previous
      hard-coded value).

    Returns:
    - (map1, map2): CV_32FC1 maps suitable for cv2.remap / cv2.cuda.remap.
    """
    rows, cols = image.shape[:2]
    camera_matrix = np.array(
        [[focal_length, 0, cols / 2], [0, focal_length, rows / 2], [0, 0, 1]],
        dtype="double",
    )
    # NOTE(review): cof[2] fills the p1 (tangential) slot, not k3 - confirm intended.
    dist_cof = np.array([cof[0], cof[1], cof[2], 0, 0])

    map_x, map_y = cv2.initUndistortRectifyMap(
        camera_matrix,
        dist_cof,
        None,  # no rectification rotation
        camera_matrix,  # keep the same intrinsics for the output view
        (width, height),
        cv2.CV_32FC1,  # == 5, the value previously hard-coded
    )
    return map_x, map_y


def maps_to_cuda_maps(map1, map2):
    """Upload two host remap tables to the GPU.

    Returns a ``(map1_gpu, map2_gpu)`` tuple of ``cv2.cuda.GpuMat`` objects,
    each constructed from the matching host array, in argument order.
    """
    return tuple(cv2.cuda.GpuMat(host_map) for host_map in (map1, map2))


def get_from_maps_cuda(image, map1, map2):
    """Remap a GPU image through precomputed maps.

    ``image``, ``map1`` and ``map2`` go straight to cv2.cuda.remap with bicubic
    interpolation and a constant border; the remapped result is returned.
    """
    remap_options = dict(
        interpolation=cv2.INTER_CUBIC,
        borderMode=cv2.BORDER_CONSTANT,
    )
    return cv2.cuda.remap(image, map1, map2, **remap_options)


def open_new_clip(path):
    """Create and return a ``cv2.VideoCapture`` for ``path``.

    Check ``isOpened()`` on the result; construction does not report failure.
    """
    capture = cv2.VideoCapture(path)
    return capture


def get_next_left(file_prefix, current_index, directory="assets/vids"):
    """Open the numbered clip ``<directory>/<file_prefix>_<NNNN>.mp4``.

    Despite the name, this is used for both sides (e.g. "Left" and "Right").

    Parameters:
    - file_prefix: clip name prefix.
    - current_index: clip number, zero-padded to four digits.
    - directory: folder holding the clips (defaults to the previous
      hard-coded "assets/vids").

    Returns:
    - The capture from open_new_clip; check ``isOpened()`` before reading.
    """
    new_path = f"{directory}/{file_prefix}_{current_index:04d}.mp4"
    print(f"trying to open: {new_path}")
    return open_new_clip(new_path)


def _clip_index_from_path(path, default=0):
    """Return NNNN from a ``.../<prefix>_NNNN.<ext>`` path, else ``default``."""
    stem = path.rsplit("/", 1)[-1].rsplit(".", 1)[0]
    digits = stem.rsplit("_", 1)[-1]
    return int(digits) if digits.isdigit() else default


def _read_frame_with_rollover(clip, file_prefix, clip_index):
    """Read one frame, moving on to the next numbered clip when ``clip`` ends.

    Returns ``(clip, clip_index, frame)``. ``clip`` and ``clip_index`` may refer
    to a newly opened capture. ``frame`` is None when no further frame can be read.
    """
    ok, frame = clip.read()
    if ok:
        return clip, clip_index, frame
    # Free the finished capture before replacing it so handles do not leak.
    clip.release()
    clip_index += 1
    clip = get_next_left(file_prefix, clip_index)
    # A VideoCapture object is always truthy; isOpened() is the real check.
    if not clip.isOpened():
        return clip, clip_index, None
    ok, frame = clip.read()
    return clip, clip_index, (frame if ok else None)


def start():
    """Undistort, stitch, record and display the left/right clip pair.

    Reads frames from ``current_left``/``current_right`` and rolls over to the
    next numbered clip when one ends. Each side is undistorted on the GPU, the
    two sides are stitched side by side, and the top rows are cropped. Each
    frame is written to ``final.avi`` and shown in a window. ESC stops the loop.

    Updates module globals: clip_left, clip_right, final_clip, index,
    total_time_taken.

    Raises:
    - RuntimeError: if the first frame of either starting clip cannot be read.
    """
    global clip_left, clip_right, final_clip, index, total_time_taken

    # Continue numbering from the clips opened below (e.g. _0009 -> _0010).
    # The counters previously started at 0, so the first rollover jumped to _0001.
    current_video_index_left = _clip_index_from_path(current_left)
    current_video_index_right = _clip_index_from_path(current_right)

    # Host -> GPU upload buffers, reused every frame.
    frame1_gpu = cv2.cuda.GpuMat()
    frame2_gpu = cv2.cuda.GpuMat()

    clip_left = cv2.VideoCapture(current_left)
    clip_right = cv2.VideoCapture(current_right)

    # Frame count of the starting left clip only (used for the progress print).
    length = int(clip_left.get(cv2.CAP_PROP_FRAME_COUNT))
    index = 0

    width = 4000
    height = 4000
    crop_top = 50  # rows removed from the top of every stitched frame

    # The first frame of each clip is only used to size the undistortion maps.
    ok_left, first_left = clip_left.read()
    ok_right, first_right = clip_right.read()
    if not (ok_left and ok_right):
        clip_left.release()
        clip_right.release()
        raise RuntimeError(
            f"could not read a first frame from {current_left} / {current_right}"
        )

    cof_left = [-0.199, -0.1300, -0.0150]
    map1_left, map2_left = init_undistort(
        first_left, width, height, None, None, cof_left
    )
    map1_left_gpu, map2_left_gpu = maps_to_cuda_maps(map1_left, map2_left)

    cof_right = [-0.4190, 0.0780, -0.0460]
    map1_right, map2_right = init_undistort(
        first_right, width, height, None, None, cof_right
    )
    map1_right_gpu, map2_right_gpu = maps_to_cuda_maps(map1_right, map2_right)

    # Rewind both clips. The left one starts CLIP_OFFSET frames later
    # (presumably to synchronise the two cameras).
    clip_left.set(cv2.CAP_PROP_POS_FRAMES, CLIP_START_FROM + CLIP_OFFSET)
    clip_right.set(cv2.CAP_PROP_POS_FRAMES, CLIP_START_FROM)

    video_format = cv2.VideoWriter_fourcc("M", "J", "P", "G")
    # Must equal the written frame size: two `width`-wide frames side by side,
    # `height` minus the cropped rows (8000 x 3950 with the values above).
    final_clip = cv2.VideoWriter(
        "final.avi", video_format, 25, (width * 2, height - crop_top)
    )

    frame_delay_ms = int(1000 / 25)  # waitKey delay: one frame period at 25 fps

    showvid = True
    savevid = True
    try:
        while True:
            frame_start = time.time()

            clip_left, current_video_index_left, frame1 = _read_frame_with_rollover(
                clip_left, "Left", current_video_index_left
            )
            if frame1 is None:
                break
            clip_right, current_video_index_right, frame2 = _read_frame_with_rollover(
                clip_right, "Right", current_video_index_right
            )
            if frame2 is None:
                break

            # Count only frames that are actually processed.
            index += 1

            frame1_gpu.upload(frame1)
            frame2_gpu.upload(frame2)

            # Remap results go to new objects so the upload buffers stay reusable.
            frame1 = get_from_maps_cuda(
                frame1_gpu, map1_left_gpu, map2_left_gpu
            ).download()
            frame2 = get_from_maps_cuda(
                frame2_gpu, map1_right_gpu, map2_right_gpu
            ).download()

            final_frame = stitch_images(frame1, frame2, 0, 10)[crop_top:]

            if savevid:
                final_clip.write(final_frame)

            time_taken_frame = time.time() - frame_start

            k = 0
            if showvid:
                cv2.imshow("vid", final_frame)
                k = cv2.waitKey(frame_delay_ms)

            total_time_taken += time_taken_frame
            print(f"avg: {total_time_taken / index}")
            print(f"index {index} of {length}")
            print(k)
            # Compare only the low byte: some backends set extra bits in the key code.
            if (k & 0xFF) == 27:
                break
    finally:
        # Always free the captures and finalise the output file, even on error.
        print("save video")
        clip_left.release()
        clip_right.release()
        final_clip.release()
        if showvid:
            cv2.destroyAllWindows()

    print("saved")
    print(f"total time taken: {total_time_taken}")
    if index:
        print(f"avg time: {total_time_taken / index}")


start()
# The pipeline runs whenever this file is executed, including on import,
# because this call is at module level.
# NOTE(review): consider guarding with `if __name__ == "__main__":`.
