#!/usr/bin/env python3
import argparse
import pathlib
import requests
import shutil
import subprocess
import sys
import uuid
from datetime import datetime
from argparse import ArgumentDefaultsHelpFormatter, RawTextHelpFormatter
import re


class SmartFormatter(ArgumentDefaultsHelpFormatter, RawTextHelpFormatter):
    """argparse help formatter combining two stdlib formatters.

    ``ArgumentDefaultsHelpFormatter`` appends "(default: ...)" to option help,
    and ``RawTextHelpFormatter`` keeps the line breaks and indentation written
    in help strings (the multi-line ``--flags`` help text depends on this).
    """


# Client version string, shown by ``--version``.
VERSION = "1.10.18"

# Remembered client-id / region are stored as small files next to this script.
_SCRIPT_DIR = pathlib.Path(__file__).parent
CLIENT_ID_FILE = _SCRIPT_DIR / ".client-id"
REGION_FILE = _SCRIPT_DIR / ".region"

# Allow-list of region names accepted as an API subdomain (on their own or as
# the suffix of a dev-/test- prefixed subdomain).
AWS_REGIONS = {
    # us-*, ca-*
    "us-east-1", "us-east-2", "us-west-1", "us-west-2",
    "ca-central-1",
    # af-*
    "af-south-1",
    # ap-*
    "ap-east-1", "ap-south-1", "ap-south-2",
    "ap-southeast-1", "ap-southeast-2", "ap-southeast-3", "ap-southeast-4",
    "ap-northeast-1", "ap-northeast-2", "ap-northeast-3",
    # eu-*
    "eu-central-1", "eu-central-2", "eu-north-1",
    "eu-south-1", "eu-south-2",
    "eu-west-1", "eu-west-2", "eu-west-3",
    # il-*, me-*
    "il-central-1", "me-central-1", "me-south-1",
    # sa-*
    "sa-east-1",
}


def format_time(timestr: str) -> str:
    """Render an ISO 8601 timestamp as ``YYYY-MM-DD HH:MM:SS`` in UTC.

    A trailing ``Z`` is accepted as UTC. Timestamps that carry a UTC offset
    are converted to UTC before formatting; timestamps without an offset are
    formatted as-is (assumed to already be UTC).

    Returns:
        ``""`` for empty/``None`` input. A value that cannot be parsed is
        returned unchanged so the caller can still display it.
    """
    if not timestr:
        return ""
    from datetime import timezone  # local import: the module only imports ``datetime``

    try:
        dt = datetime.fromisoformat(timestr.replace("Z", "+00:00"))
        if dt.tzinfo is not None:
            # The output is labelled UTC, so normalise any non-UTC offset.
            dt = dt.astimezone(timezone.utc)
        return dt.strftime("%Y-%m-%d %H:%M:%S")
    except (AttributeError, TypeError, ValueError, OverflowError):
        # AttributeError/TypeError: non-string input; ValueError: not ISO 8601;
        # OverflowError: UTC conversion out of the representable range.
        return timestr  # fallback: show raw


def print_executions_table(title: str, executions: list[dict]):
    """Print executions as a fixed-width table headed by ``title``.

    Each entry is a mapping read for ``execution_id``, ``startTime`` and
    ``stopTime``. A missing or ``None`` value renders as a blank cell instead
    of raising (presumably unfinished executions have no stop time — confirm
    against the API). Prints ``"<title>: none"`` for an empty/``None`` list.
    """
    if not executions:
        print(f"{title}: none")
        return
    print(f"{title}:")
    header = (
        f"{'Execution ID':<36}    {'Start Time (UTC)':<19}    {'Stop Time (UTC)':<19}"
    )
    print(header)
    print("-" * len(header))
    for exe in executions:
        execution_id = exe.get("execution_id") or ""
        start_time = format_time(exe.get("startTime"))
        stop_time = format_time(exe.get("stopTime"))
        # str(): format_time may hand back a raw non-string value unchanged.
        print(
            f"{str(execution_id):<36}    {str(start_time):<19}    {str(stop_time):<19}"
        )
    print()



def parse_s3_path(s3_path: str):
    """
    Validate an S3 path before it is handed to AWS Step Functions / SSM.

    The SSM parameter uses allowedPattern: "^s3://[a-zA-Z0-9\\-_/]+$".
    Characters outside that set (e.g. `.` or spaces) lead to hanging step
    functions and confusing errors later on, so such paths are rejected here.

    ``re.fullmatch`` is used because with ``re.match`` the ``$`` anchor also
    matches just before a trailing newline, which would let e.g.
    ``"s3://bucket\\n"`` through.

    Returns a tuple (is_valid: bool, message_or_path: str)
    """
    pattern = r"^s3://[a-zA-Z0-9\-_/]+$"
    if not re.fullmatch(pattern, s3_path):
        return (
            False,
            f"s3-path '{s3_path}' is not valid. It must match the pattern {pattern}",
        )
    return True, s3_path



def _list_s3_names(prefix: str, suffix: str) -> list[str]:
    """Return names listed by ``aws s3 ls <prefix>`` that end with ``suffix``.

    The name is the last whitespace-separated token of each output line, and
    blank lines are skipped. Object names containing spaces would be truncated
    by this parsing.

    Raises:
        subprocess.CalledProcessError: if ``aws s3 ls`` exits non-zero.
    """
    output = subprocess.check_output(
        ["aws", "s3", "ls", prefix],
        text=True,
        stderr=subprocess.PIPE,
    )
    names = []
    for line in output.splitlines():
        tokens = line.split()
        if tokens and tokens[-1].endswith(suffix):
            names.append(tokens[-1])
    return names


def validate_framewise_trinsics(s3_path: str, trinsics_name="extrinsics") -> str:
    """Check that every topbot image has a matching per-frame calibration file.

    Used for the --framewise-extrinsics / --framewise-intrinsics flags. For each
    ``<s3_path>/topbot/<frame>.tiff`` there must be a
    ``<s3_path>/<trinsics_name>/<frame>.yaml``.

    Args:
        s3_path: Dataset root such as ``s3://bucket/prefix``.
        trinsics_name: Sub-folder to check, e.g. ``"extrinsics"`` or ``"intrinsics"``.

    Returns:
        An error message if validation fails, otherwise ``""``. The check is
        best-effort. It is skipped (returns ``""``) when the ``aws`` CLI is not
        on PATH. Unexpected errors are printed as a warning and also return ``""``.
    """
    if not shutil.which("aws"):
        return ""  # Skip validation if AWS CLI not available
    base = s3_path.rstrip("/")
    topbot_prefix = base + "/topbot/"
    trinsics_prefix = base + f"/{trinsics_name}/"
    try:
        try:
            topbot_files = _list_s3_names(topbot_prefix, ".tiff")
        except subprocess.CalledProcessError:
            return f"Topbot folder not found at {topbot_prefix}"
        if not topbot_files:
            return f"No topbot images found in {topbot_prefix}"
        try:
            # Set membership keeps the per-frame lookup O(1) for large datasets.
            trinsics_files = set(_list_s3_names(trinsics_prefix, ".yaml"))
        except subprocess.CalledProcessError:
            # Folder doesn't exist or is empty
            return f"{trinsics_name.capitalize()} folder not found at {trinsics_prefix}"
        # Strip only the trailing ".tiff" to obtain the frame id.
        frame_ids = (name[: -len(".tiff")] for name in topbot_files)
        missing = [fid for fid in frame_ids if f"{fid}.yaml" not in trinsics_files]
        if missing:
            return f"Missing {trinsics_name} files for {len(missing)} frame(s): {', '.join(missing)}"
        return ""
    except Exception as e:
        # Pre-flight check only: an unexpected failure must not block the job.
        print(f"Warning: could not validate framewise {trinsics_name}: {e}")
        return ""


def validate_client_id(client_id: str):
    """Check that ``client_id`` is a canonical (lower-case, hyphenated) UUID v4.

    Returns:
        ``(True, client_id)`` when valid, else ``(False, message)``. The
        message distinguishes a non-UUID, a UUID of another version/variant,
        and a v4 UUID written in a non-canonical form (upper case, braces, ...).
    """
    try:
        val = uuid.UUID(client_id)
    except ValueError:
        return False, f"{client_id} is not a valid UUID v4"
    # ``version`` is None unless the variant is RFC 4122, so this also rejects
    # other variants (as the original forced-version comparison did).
    if val.version != 4:
        return False, f"{client_id} is a valid UUID but not version 4"
    if str(val) != client_id:
        return False, f"{client_id} is not in canonical UUID v4 format"
    return True, client_id


def validate_target_subdomain(subdomain: str):
    """Check that ``subdomain`` names a supported API deployment.

    Accepted forms:
      * a region in ``AWS_REGIONS`` (e.g. ``us-east-1``);
      * a ``dev-`` or ``test-`` prefixed name whose last three ``-``-separated
        parts form such a region (e.g. ``dev-us-east-1``).

    The value is interpolated into the API hostname
    (``https://cloud.<subdomain>.nodarsensor.net``). Prefixed names are
    therefore restricted to ASCII letters, digits and hyphens. Characters
    such as ``/``, ``@`` or ``#`` would otherwise move the request to a
    different host.

    Returns:
        ``(True, subdomain)`` or ``(False, message)``.
    """
    if subdomain in AWS_REGIONS:
        return True, subdomain
    if subdomain.startswith(("dev-", "test-")) and re.fullmatch(
        r"[A-Za-z0-9-]+", subdomain
    ):
        suffix = "-".join(subdomain.split("-")[-3:])
        if suffix in AWS_REGIONS:
            return True, subdomain
    return False, f"{subdomain} is not a valid AWS region or dev/test prefixed region"


def _save_value(file_path: pathlib.Path, value: str) -> None:
    """Best-effort write of a validated setting; warn instead of failing."""
    try:
        file_path.write_text(value)
    except OSError as e:
        print(f"Warning: could not save setting to {file_path}: {e}")


def get_or_prompt(file_path: pathlib.Path, arg_value: str, prompt: str, validate_func):
    """Resolve a setting from an explicit argument, a cached file, or a prompt.

    The sources are tried in that order. ``validate_func(value)`` must return
    ``(ok, result)``, where ``result`` is the accepted value when ``ok`` is
    true and an error message otherwise. Explicit and prompted values are
    saved to ``file_path`` once they validate, so later runs can omit them.
    Failing to save only prints a warning.

    Returns:
        ``(value, None)`` on success, or ``(None, error_message)``.
    """
    if arg_value:
        ok, result = validate_func(arg_value)
        if not ok:
            return None, result
        _save_value(file_path, result)
        return result, None
    if file_path.exists():
        try:
            stored = file_path.read_text().strip()
        except OSError as e:
            return None, f"could not read {file_path}: {e}"
        ok, result = validate_func(stored)
        if not ok:
            # Name the file: otherwise the user cannot tell where the bad value came from.
            return None, f"{result} (read from {file_path}; pass a valid value explicitly to replace it)"
        return result, None
    try:
        value = input(f"{prompt}: ").strip()
    except EOFError:
        # stdin closed / non-interactive run: report instead of a traceback.
        return None, f"no value given and no input available for '{prompt}'"
    ok, result = validate_func(value)
    if not ok:
        return None, result
    _save_value(file_path, result)
    return result, None


def get_client_and_region(client_id, region):
    """Resolve the customer ID, then the region, via ``get_or_prompt``.

    Resolution stops at the first failure, which is printed as ``Error: ...``.

    Returns:
        ``(client_id, region, 0)`` on success, or ``(None, None, 1)``.
    """
    settings = (
        (CLIENT_ID_FILE, client_id, "Enter client-id (UUID v4)", validate_client_id),
        (REGION_FILE, region, "Enter region", validate_target_subdomain),
    )
    resolved = []
    for path, given, question, checker in settings:
        value, problem = get_or_prompt(path, given, question, checker)
        if not value:
            print(f"Error: {problem}")
            return None, None, 1
        resolved.append(value)
    return resolved[0], resolved[1], 0


def build_base_url(subdomain: str) -> str:
    """Return the HTTPS root of the cloud API for ``subdomain`` (no trailing slash)."""
    return "https://cloud.{}.nodarsensor.net".format(subdomain)


def print_error(url, response):
    """Report a failed HTTP request to the user.

    Prints the status code, then the most useful part of the body:
      * the ``body`` field when the JSON payload is a dict containing it;
      * the whole JSON payload otherwise;
      * the raw text when decoding or printing the JSON fails.

    A plain-text ``400`` / ``Bad Request`` reply also gets a hint that the
    UUID or region is likely wrong.
    """
    code = response.status_code
    print(f"Error: request to {url} failed with status {code}")
    try:
        payload = response.json()
        has_body = isinstance(payload, dict) and "body" in payload
        print("\n\t", payload["body"] if has_body else payload)
    except Exception:
        print("\n\t", response.text)
    if code == 400 and response.text == "Bad Request":
        print("\nIt is likely that either your UUID or region is incorrect.")


# Processing flags accepted by ``start --flags``, grouped by which matcher
# supports them; read by ``validate_flags`` and ``generate_flags_help``.
FLAGS = dict(
    common=[
        "--save-left-disparity",
        "--save-right-disparity",
        "--save-left-rectified",
        "--save-right-rectified",
        "--save-left-valid-pixel-map",
        "--save-right-valid-pixel-map",
        "--save-details",
        "--save-pc",
        "--disable-autocal",
        "--framewise-extrinsics",
        "--framewise-intrinsics",
    ],
    ground_truth_only=[],
    hammerhead_only=[
        "--save-left-confidence-map",
        "--save-right-confidence-map",
    ],
    unsupported=[],
)


def generate_flags_help() -> str:
    """Build the multi-line help text for ``start --flags`` from ``FLAGS``.

    Each non-empty group becomes a 4-space-indented title line followed by
    its flags, one per line at 8 spaces. Empty groups are omitted, and the
    ``unsupported`` group is never listed.
    """
    indent = "\n        "
    groups = (
        ("Common", FLAGS["common"]),
        ("Ground Truth Only", FLAGS["ground_truth_only"]),
        (
            "Hammerhead Only (since ground-truth does not generate confidence maps)",
            FLAGS["hammerhead_only"],
        ),
    )
    blocks = []
    for title, members in groups:
        if not members:
            continue
        blocks.append(f"    {title}:{indent}{indent.join(members)}")
    return "Processing flags (space-separated).\n" + "\n".join(blocks) + "\n"


def validate_flags(matcher, flags):
    """Check a space-separated ``flags`` string against the chosen ``matcher``.

    Flags are examined left to right, and the first problem found is reported:
      * a flag listed as unsupported;
      * a matcher-specific flag used with another matcher;
      * a flag that appears in no ``FLAGS`` group.

    Returns:
        An error message, or None when every flag is acceptable.
    """
    known_flags = {name for group in FLAGS.values() for name in group}
    for token in flags.split():
        if token in FLAGS["unsupported"]:
            return f"Flag '{token}' is not implemented"
        if token in FLAGS["ground_truth_only"] and matcher != "ground-truth":
            return f"Flag '{token}' is only supported with ground-truth matcher, not {matcher}"
        if token in FLAGS["hammerhead_only"] and matcher != "hammerhead":
            return f"Flag '{token}' is only supported with hammerhead matcher, not {matcher}"
        if token not in known_flags:
            return f"Unknown flag: '{token}'"
    return None  # No error


def _parse_max_disp(max_disp):
    """Validate a ``--max-disp`` value.

    Returns:
        ``(value, None)`` for an integer in ``[0, 608]`` that is a multiple of
        32, or ``(None, message)`` describing the problem.
    """
    try:
        value = int(max_disp)
    except (TypeError, ValueError):
        return None, f"--max-disp must be an integer (got '{max_disp}')"
    if value < 0:
        return None, f"--max-disp must not be negative (got {value})."
    if value > 608:
        return None, f"--max-disp must be at most 608 (got {value})."
    if value % 32 != 0:
        # Range is already checked, so both suggestions lie within [0, 608].
        down = (value // 32) * 32
        return None, (
            f"--max-disp must be a multiple of 32 (got {value}). "
            f"Try {down} or {down + 32}."
        )
    return value, None


def _check_framewise_files(s3_path, flag_list):
    """Run the S3 pre-flight check for each requested ``--framewise-*`` flag.

    Prints progress and any error. Returns False as soon as a check fails.
    """
    for kind in ("extrinsics", "intrinsics"):
        if f"--framewise-{kind}" not in flag_list:
            continue
        print(f"Validating framewise {kind}...")
        error = validate_framewise_trinsics(s3_path, trinsics_name=kind)
        if error:
            print(f"Error: {error}")
            return False
        print(f"All topbot images have corresponding {kind} files")
    return True


def start(
    client_id, region, s3_path, start_frame, frame_count, matcher, flags, pixel_format, max_disp, split_network
):
    """Validate job parameters locally, then ask the cloud API to start the job.

    Local checks run first:
      * client-id / region resolution;
      * ``--max-disp``;
      * flag/matcher compatibility;
      * for ``s3://`` paths, path syntax and (for ``--framewise-*`` flags)
        presence of per-frame calibration files.

    The job description is then sent as HTTP headers in a POST to
    ``/start-execution``.

    Returns:
        Process exit code: 0 on success, 1 on any validation or request failure.
    """
    client_id, region, exit_code = get_client_and_region(client_id, region)
    if exit_code != 0:
        return exit_code
    max_disp_int, disp_error = _parse_max_disp(max_disp)
    if disp_error:
        print(f"Error: {disp_error}")
        return 1
    # Validate flags are compatible with matcher
    flag_error = validate_flags(matcher, flags)
    if flag_error:
        print(f"Error: {flag_error}")
        return 1
    if s3_path.startswith("s3://"):
        path_ok, path_detail = parse_s3_path(s3_path)
        if not path_ok:
            print(
                "ERROR: An S3 path may only contain letters, numbers, hyphens, underscores, and slashes -- no `.` or spaces"
            )
            # Return an exit code like every other validation instead of raising.
            print(f"Error: {path_detail}")
            return 1
        # Compare whole tokens rather than substrings of the flags string.
        if not _check_framewise_files(s3_path, flags.split()):
            return 1
    url = f"{build_base_url(region)}/start-execution"
    headers = {
        "Customer-ID": client_id,
        "S3-Path": s3_path,
        "Start-Frame": start_frame,
        "Frame-Count": frame_count,
        "Matcher": matcher,
        "Flags": flags,
        "Pixel-Format": pixel_format,
        # Send the normalised integer, not the raw text (e.g. " 64" or "064").
        "Max-Disp": str(max_disp_int),
        "Split-Network": "1" if split_network else "0",
    }
    try:
        # (connect, read) timeout so an unresponsive endpoint cannot hang the CLI.
        response = requests.post(url, headers=headers, timeout=(10, 120))
    except Exception as e:
        print(f"Error: could not connect to {url} - {e}")
        return 1
    if response.ok:
        try:
            data = response.json()
            if msg := data.get("message", ""):
                print(msg)
            print(f"Started process with execution ID: {data.get('executionId')}")
        except Exception:
            print("Warning: response body was not valid JSON")
    else:
        print_error(url, response)
    return 0 if response.ok else 1


def stop(client_id, region, execution_id):
    """Ask the cloud API (POST ``/stop-execution``) to stop ``execution_id``.

    Returns:
        Process exit code: 0 if the request was accepted, 1 otherwise.
    """
    client_id, region, exit_code = get_client_and_region(client_id, region)
    if exit_code != 0:
        return exit_code
    url = f"{build_base_url(region)}/stop-execution"
    headers = {
        "Customer-ID": client_id,
        "Execution-ID": execution_id,
    }
    try:
        # (connect, read) timeout so an unresponsive endpoint cannot hang the CLI.
        response = requests.post(url, headers=headers, timeout=(10, 120))
    except Exception as e:
        print(f"Error: could not connect to {url} - {e}")
        return 1
    if response.ok:
        print(f"Submitted stop request for process with execution ID: {execution_id}")
    else:
        print_error(url, response)
    return 0 if response.ok else 1


def principal(client_id, region):
    """Fetch (GET ``/get-principal``) and print the role ARN needing bucket access.

    Prints the ``ec2RoleArn`` field of the JSON response.

    Returns:
        Process exit code: 0 if the request succeeded, 1 otherwise.
    """
    client_id, region, exit_code = get_client_and_region(client_id, region)
    if exit_code != 0:
        return exit_code
    url = f"{build_base_url(region)}/get-principal"
    headers = {
        "Customer-ID": client_id,
    }
    try:
        # (connect, read) timeout so an unresponsive endpoint cannot hang the CLI.
        response = requests.get(url, headers=headers, timeout=(10, 120))
    except Exception as e:
        print(f"Error: could not connect to {url} - {e}")
        return 1
    if response.ok:
        try:
            data = response.json()
            print("Roles that need access to your S3 bucket:\n")
            print(f"EC2_ROLE_ARN: {data.get('ec2RoleArn')}")
        except Exception:
            print("Warning: response body was not valid JSON")
    else:
        print_error(url, response)
    return 0 if response.ok else 1


def status(client_id, region, execution_id):
    """Fetch (GET ``/get-status``) and print the status of ``execution_id``.

    For a ``FAILED`` execution, the ``error`` and ``cause`` fields are also
    printed when present.

    Returns:
        Process exit code: 0 if the request succeeded, 1 otherwise.
    """
    client_id, region, exit_code = get_client_and_region(client_id, region)
    if exit_code != 0:
        return exit_code
    url = f"{build_base_url(region)}/get-status"
    headers = {
        "Customer-ID": client_id,
        "Execution-ID": execution_id,
    }
    try:
        # (connect, read) timeout so an unresponsive endpoint cannot hang the CLI.
        response = requests.get(url, headers=headers, timeout=(10, 120))
    except Exception as e:
        print(f"Error: could not connect to {url} - {e}")
        return 1
    if response.ok:
        try:
            data = response.json()
            # Named ``state`` so it does not shadow this function's own name.
            state = data.get("status")
            print(f"Status of {execution_id}: {state}")

            # Show error details if available
            if state == "FAILED":
                error = data.get("error")
                cause = data.get("cause")
                if error:
                    print("\nPlease report this error to support@nodarsensor.com")
                    print("It is likely an internal AWS cloud error.\n")
                    print(f"Error: {error}")
                if cause:
                    print(f"Cause: {cause}")
        except Exception:
            print("Warning: response body was not valid JSON")
    else:
        print_error(url, response)
    return 0 if response.ok else 1


def cloud_version(client_id, region):
    """
    Query the API /cloud-version endpoint and print the deployed version.

    Prints the ``VERSION`` field of the JSON response.

    Returns:
        Process exit code: 0 if the request succeeded, 1 otherwise.
    """
    client_id, region, exit_code = get_client_and_region(client_id, region)
    if exit_code != 0:
        return exit_code
    url = f"{build_base_url(region)}/cloud-version"
    headers = {"Customer-ID": client_id}
    try:
        # (connect, read) timeout so an unresponsive endpoint cannot hang the CLI.
        response = requests.get(url, headers=headers, timeout=(10, 120))
    except Exception as e:
        print(f"Error: could not connect to {url} - {e}")
        return 1
    if response.ok:
        try:
            data = response.json()
            print(f"Cloud API version: {data.get('VERSION')}")
        except Exception:
            print("Warning: response body was not valid JSON")
    else:
        print_error(url, response)

    return 0 if response.ok else 1


def list_executions(client_id, region, endpoint, label):
    """Fetch executions from ``/<endpoint>`` and print them as a table.

    Args:
        endpoint: API path, e.g. ``get-running``.
        label: Word used in the table title, e.g. ``running``.

    Returns:
        Process exit code: 0 if the HTTP request succeeded, 1 otherwise.
    """
    client_id, region, exit_code = get_client_and_region(client_id, region)
    if exit_code != 0:
        return exit_code
    url = f"{build_base_url(region)}/{endpoint}"
    headers = {"Customer-ID": client_id}
    try:
        # (connect, read) timeout so an unresponsive endpoint cannot hang the CLI.
        response = requests.get(url, headers=headers, timeout=(10, 120))
    except Exception as e:
        print(f"Error: could not connect to {url} - {e}")
        return 1

    if not response.ok:
        print_error(url, response)
        return 1
    try:
        executions = response.json()
    except ValueError:  # requests' JSON decode errors subclass ValueError
        print("Warning: response body was not valid JSON")
        return 0
    # Rendering is kept separate from decoding, so a shape mismatch is not
    # misreported as invalid JSON.
    try:
        print_executions_table(f"{label.capitalize()} executions", executions)
    except (KeyError, TypeError, AttributeError) as e:
        print(f"Warning: unexpected response format ({e!r})")
    return 0


def add_common_args(p):
    """Register the --client-id and --region options shared by every subcommand."""
    for option, settings in (
        ("--client-id", {"metavar": "ID", "help": "Customer UUID v4"}),
        ("--region", {"help": "AWS Region (e.g. us-east-1)"}),
    ):
        p.add_argument(option, **settings)


# Execution-listing subcommands and their help text. Each one calls the
# ``get-<command>`` endpoint and uses the command name as the table label.
_LIST_COMMANDS = {
    "running": "List executions that are running",
    "succeeded": "List executions that succeeded",
    "timed-out": "List executions that timed-out",
    "aborted": "List executions that were aborted",
    "failed": "List executions that failed",
}


def _add_execution_id_args(p):
    """Accept an execution ID either positionally or via --execution-id."""
    p.add_argument("execution_id", nargs="?", help="Execution ID")
    p.add_argument(
        "--execution-id", dest="execution_id_opt", metavar="XID", help="Execution ID"
    )


def _build_parser():
    """Create the top-level parser and register every subcommand, in help order."""
    parser = argparse.ArgumentParser(
        description="Nodar Cloud CLI",
        formatter_class=SmartFormatter,
    )
    parser.add_argument(
        "-v", "--version", action="version", version=f"%(prog)s {VERSION}"
    )
    commands = parser.add_subparsers(dest="command", required=True)

    add_common_args(
        commands.add_parser("cloud-version", help="Show the version of the cloud API")
    )

    start_cmd = commands.add_parser(
        "start", help="Start execution", formatter_class=SmartFormatter
    )
    add_common_args(start_cmd)
    # The S3 path may be given positionally or via --s3-path.
    start_cmd.add_argument("s3_path", nargs="?", help="S3 path like s3://bucket/prefix")
    start_cmd.add_argument(
        "--s3-path",
        dest="s3_path_opt",
        metavar="S3_PATH",
        help="S3 path like s3://bucket/prefix",
    )
    start_cmd.add_argument(
        "--start-frame", metavar="FRAME", default="0", help="Starting frame number"
    )
    start_cmd.add_argument(
        "--frame-count",
        metavar="COUNT",
        default="-1",
        help="Number of frames to process where -1 denotes 'all frames'",
    )
    start_cmd.add_argument(
        "--matcher",
        default="ground-truth",
        help="Matcher type: 'ground-truth' or 'hammerhead'",
    )
    start_cmd.add_argument(
        "--pixel-format",
        metavar="FMT",
        default="BGR",
        help="Input image pixel format: BGR, Bayer_RGGB, Bayer_GRBG, Bayer_BGGR, Bayer_GBRG",
    )
    start_cmd.add_argument(
        "--max-disp",
        metavar="DISP",
        default="0",
        help="Maximum disparity. Must be a multiple of 32, up to 608. 0 means use the model default (416).",
    )
    start_cmd.add_argument(
        "--split-network",
        action="store_true",
        default=False,
        help="Force the use of the split network (for memory-constrained systems). Normally auto-detected.",
    )
    start_cmd.add_argument(
        "--flags",
        default="--save-left-disparity --save-details --save-left-rectified",
        help=generate_flags_help(),
    )

    stop_cmd = commands.add_parser("stop", help="Stop execution")
    add_common_args(stop_cmd)
    _add_execution_id_args(stop_cmd)

    add_common_args(
        commands.add_parser(
            "principal",
            help="Get principal (specifically, its ARN). Add this to your S3 bucket",
        )
    )

    status_cmd = commands.add_parser("status", help="Get status of execution")
    add_common_args(status_cmd)
    _add_execution_id_args(status_cmd)

    for name, description in _LIST_COMMANDS.items():
        add_common_args(commands.add_parser(name, help=description))
    return parser


def main():
    """CLI entry point: parse ``sys.argv`` and dispatch to the chosen command.

    Returns:
        The process exit code produced by the command handler.
    """
    args = _build_parser().parse_args()
    command = args.command

    if command == "start":
        s3_path = args.s3_path or args.s3_path_opt
        if not s3_path:
            print("You must provide an s3-path")
            return 1
        return start(
            client_id=args.client_id,
            region=args.region,
            s3_path=s3_path,
            start_frame=args.start_frame,
            frame_count=args.frame_count,
            matcher=args.matcher,
            flags=args.flags,
            pixel_format=args.pixel_format,
            max_disp=args.max_disp,
            split_network=args.split_network,
        )
    if command in ("stop", "status"):
        execution_id = args.execution_id or args.execution_id_opt
        if not execution_id:
            print("You must provide an execution ID")
            return 1
        handler = stop if command == "stop" else status
        return handler(
            client_id=args.client_id, region=args.region, execution_id=execution_id
        )
    if command == "principal":
        return principal(client_id=args.client_id, region=args.region)
    if command in _LIST_COMMANDS:
        return list_executions(args.client_id, args.region, f"get-{command}", command)
    if command == "cloud-version":
        return cloud_version(client_id=args.client_id, region=args.region)
    print(f"Unknown command: {command}. Use --help for available commands.")
    return 1


if __name__ == "__main__":
    # Use main()'s return value as the process exit status.
    raise SystemExit(main())
