import os
import time
import re
import traceback
import threading
import logging
import subprocess
import argparse
import ipaddress

from flask import Flask, request, jsonify
from handlers import KasmLogHandler
from log_entry_pb2 import LogEntry
from datetime import datetime
from google.protobuf.message import DecodeError
from logging.handlers import RotatingFileHandler


# Directory that holds the driver's own log and the per-container log files.
log_path = "/var/log/kasm-logger"
os.makedirs(log_path, exist_ok=True)
# Directory of the Unix socket the server binds to when run as a script.
os.makedirs("/run/docker/plugins", exist_ok=True)


# Flask application exposing the plugin endpoints defined in this module.
app = Flask(__name__)
# Logger used for the driver's own diagnostics (not for container output).
app_logger = logging.getLogger('kasmlogger')


def parse_container_data(data):
    """Extract the fields this driver uses from a log-driver request body.

    Args:
        data: Decoded JSON request body (a dict). The ``Info`` object and its
            ``ContainerEnv`` list may be missing or ``null``; both cases are
            treated as empty.

    Returns:
        A tuple ``(container_id, container_image, container_env, file)`` where
        ``container_env`` is a dict built from the ``NAME=value`` entries of
        ``Info.ContainerEnv`` (entries without ``=`` are ignored) and the
        other items are ``None`` when absent from the request.
    """
    # ``.get(key, default)`` only covers a *missing* key; a key that is
    # present with a JSON null value comes back as None, hence ``or``.
    container_info = data.get("Info") or {}

    file = data.get("File")
    container_id = container_info.get("ContainerID")
    container_image = container_info.get("ContainerImageName")

    container_env = {}
    for entry in container_info.get("ContainerEnv") or []:
        if not isinstance(entry, str):
            continue
        # Split on the first "=" only: values may themselves contain "=".
        name, separator, value = entry.partition("=")
        if separator:
            container_env[name] = value

    return container_id, container_image, container_env, file


def resolve_container_hostname(container_id, hostname):
    """Resolve ``hostname`` to an IP address as seen from inside a container.

    If ``hostname`` is already an IP literal it is returned (normalised)
    without running any command. Otherwise ``getent hosts`` is executed in
    the container through ``docker exec`` and the first address it prints is
    returned.

    Args:
        container_id: ID of the container in which the lookup is performed.
        hostname: Host name or IP literal taken from the container
            environment.

    Returns:
        The resolved IP address as a string, or ``None`` when the name cannot
        be resolved.
    """
    # An IP literal needs no lookup.
    try:
        return str(ipaddress.ip_address(hostname))
    except ValueError:
        pass

    # The hostname comes from the container's environment, so it is never
    # interpolated into a shell string. A leading "-" would be parsed as a
    # command-line option by getent and is not a valid host name anyway.
    if not hostname or str(hostname).startswith("-"):
        app_logger.error(f"failed to resolve hostname({hostname}) for container({container_id}): invalid hostname")
        return None

    cmd = ["docker", "exec", str(container_id), "getent", "hosts", str(hostname)]

    try:
        completed = subprocess.run(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, check=True)
    except (subprocess.CalledProcessError, OSError) as exception:
        app_logger.error(f"failed to resolve hostname({hostname}) for container({container_id}): {exception}")
        return None

    # getent prints "<address> <name> ..." per match; keep the first address.
    fields = completed.stdout.decode("utf-8", errors="replace").split()
    candidate = fields[0] if fields else ""
    try:
        ipaddress.ip_address(candidate)
    except ValueError:
        app_logger.error(
            f"failed to resolve hostname({hostname}) for container({container_id}): unexpected output {candidate!r}"
        )
        return None

    app_logger.info(f"resolved {hostname} to {candidate} for container({container_id})")
    return candidate


def consume_container_logs(container_id, container_image, container_env, file_path):
    """Read a container's log stream and forward every entry to its logger.

    The stream at ``file_path`` is read as a sequence of messages, each a
    4-byte big-endian length followed by a serialized ``LogEntry``. Every
    decoded line is logged through the ``container-<id>`` logger, which
    writes to a rotating file under ``log_path`` and to a ``KasmLogHandler``.
    A ``<KASM|N>`` prefix on a line is stripped and passed on as ``channel``.

    The function returns when the stream ends. It never raises: failures are
    reported through ``app_logger``.

    Args:
        container_id: ID of the container, used in logger and file names.
        container_image: Image name attached to every forwarded record.
        container_env: Container environment as a dict; must provide
            ``KASM_API_HOST``, ``KASM_API_PORT``, ``KASM_API_JWT`` and
            ``KASM_ID``.
        file_path: Path of the stream to read the log messages from.
    """
    container_logger = logging.getLogger(f"container-{container_id}")
    file_handler = None
    # Compiled once instead of being looked up for every log line.
    channel_pattern = re.compile(r'^<KASM\|(\d+)>\s*(.*)')

    try:
        container_logger.setLevel(logging.INFO)

        file_handler = RotatingFileHandler(
            filename=f"{log_path}/container-{container_id}.log",
            maxBytes=10 * 1024 * 1024,
            backupCount=5
        )
        file_handler.setFormatter(logging.Formatter('%(asctime)s - %(message)s'))
        container_logger.addHandler(file_handler)

        api_hostname = container_env.get("KASM_API_HOST")
        if not api_hostname:
            raise ValueError("Missing KASM_API_HOST in container environment")
        api_address = resolve_container_hostname(container_id, api_hostname)

        if not api_address:
            raise ValueError("Failed to resolve KASM_API_HOST")

        config = {
            'log_path': "/api/kasm_session_log",
            'manager': {
                'hostnames': [api_address],
                'public_port': container_env["KASM_API_PORT"]
            }
        }

        # NOTE(review): the handler comes from a cache whose lifetime is not
        # visible here, so it is deliberately left attached and open on exit
        # -- confirm against KasmLogHandler.
        kasm_handler = KasmLogHandler.create_cached_kasmloghandler(
            application="Session",
            kasm_api_jwt=container_env["KASM_API_JWT"],
            config=config,
            hostname=container_env["KASM_ID"]
        )
        container_logger.addHandler(kasm_handler)

        app_logger.info(f"starting logging for container({container_id})")

        with open(file_path, "rb") as fifo:
            while True:
                message_size_bytes = fifo.read(4)
                if len(message_size_bytes) < 4:
                    # End of stream (or a truncated length prefix).
                    break

                message_size = int.from_bytes(message_size_bytes, byteorder="big")
                message_bytes = fifo.read(message_size)
                if len(message_bytes) < message_size:
                    # The stream ended in the middle of a message; the
                    # partial payload cannot be a complete entry.
                    app_logger.warning(f"truncated log entry for container({container_id})")
                    break

                entry = LogEntry()
                try:
                    entry.ParseFromString(message_bytes)
                except Exception:
                    container_logger.exception(f"unexpected error while parsing log entry for container({container_id})")
                    continue

                channel = None
                # Container output is arbitrary bytes; one undecodable line
                # must not terminate logging for the whole container.
                line = entry.line.decode("utf-8", errors="replace")
                channel_match = channel_pattern.match(line)
                if channel_match:
                    channel = channel_match.group(1)
                    line = channel_match.group(2)

                container_logger.info(line, extra={
                    'channel': channel,
                    'timestamp': entry.time_nano,
                    'container_id': container_id,
                    'container_image': container_image
                })

        app_logger.info(f"stopping logging for container({container_id})")
    except FileNotFoundError:
        app_logger.warning(f"log stream not found for container({container_id})")
    except Exception:
        app_logger.exception(f"error processing logs for container({container_id})")
    finally:
        # Loggers are process-wide singletons keyed by name: without this,
        # every new logging session for the same container would leave one
        # more open file handler attached and duplicate each line.
        if file_handler is not None:
            container_logger.removeHandler(file_handler)
            file_handler.close()


@app.route("/Plugin.Activate", methods=["POST"])
def plugin_activate():
    """Answer the activation request by declaring the LogDriver interface."""
    app_logger.info("Plugin.Activate endpoint called")
    implemented = {"Implements": ["LogDriver"]}
    return jsonify(implemented)


@app.route("/LogDriver.Capabilities", methods=["POST"])
def capabilities():
    """Report that this driver supports reading logs back."""
    supported = {"ReadLogs": True}
    return jsonify({"Cap": supported})


@app.route("/LogDriver.ReadLogs", methods=["POST"])
def read_logs():
    """Stream a container's stored log file back as length-prefixed entries.

    Each line of ``container-<id>.log`` is sent as a serialized ``LogEntry``
    preceded by its 4-byte big-endian size. When ``Config.Follow`` is set in
    the request, the stream stays open and new lines are sent as they appear.

    Returns:
        The streaming response (HTTP 200), a 404 JSON error when the
        container has no log file, or a 500 JSON error when the file cannot
        be opened for another reason.
    """
    data = request.get_json(force=True)
    follow = (data.get("Config") or {}).get("Follow", False)

    container_id, container_image, container_env, file = parse_container_data(data)

    # Open the file here rather than inside the generator: the generator body
    # only runs while the response is being streamed, which is too late to
    # turn a missing file into a 404.
    try:
        log_file = open(f"{log_path}/container-{container_id}.log", "r")
    except FileNotFoundError:
        message = f"No logs found for container {container_id}"
        # "Err" is the field the Docker plugin protocol reads; "Error" is
        # kept for existing consumers.
        return jsonify({"Err": message, "Error": message}), 404
    except Exception as exception:
        app_logger.exception("error reading logs.")
        message = f"Error reading logs: {str(exception)}"
        return jsonify({"Err": message, "Error": message}), 500

    def timestamp_ns(line, fallback):
        """Return the line's asctime prefix in nanoseconds, else ``fallback``."""
        # Lines are written as "<asctime> - <message>" with logging's default
        # asctime layout, e.g. "2024-01-31 12:34:56,789".
        stamp = line.partition(" - ")[0]
        try:
            parsed = datetime.strptime(stamp, "%Y-%m-%d %H:%M:%S,%f")
        except ValueError:
            return fallback
        # Integer arithmetic avoids float rounding at nanosecond scale.
        whole_seconds = int(parsed.replace(microsecond=0).timestamp())
        return whole_seconds * 10 ** 9 + parsed.microsecond * 1000

    def generate_logs():
        # Lines without a timestamp prefix (continuation lines of multi-line
        # messages, tracebacks) inherit the time of the preceding line.
        last_time_nano = 0
        with log_file as f:
            while True:
                line = f.readline()
                if line:
                    last_time_nano = timestamp_ns(line, last_time_nano)
                    entry = LogEntry()
                    entry.time_nano = last_time_nano
                    entry.line = line.encode("utf-8")
                    entry.source = "stdout"
                    entry.partial = False
                    payload = entry.SerializeToString()
                    payload_size = len(payload).to_bytes(4, byteorder="big")
                    yield payload_size + payload
                elif follow:
                    # Re-seek to the current position so the next readline()
                    # looks at the file again instead of a cached EOF.
                    position = f.tell()
                    f.seek(position)
                    time.sleep(0.1)
                else:
                    break

    response = app.make_response(
        (generate_logs(), 200, {"Content-Type": "application/x-protobuf"})
    )
    # Close the file even if the generator is never started.
    response.call_on_close(log_file.close)
    return response


@app.route("/LogDriver.StartLogging", methods=["POST"])
def start_logging():
    """Start a background thread that consumes a container's log stream.

    The request must name the stream (``File``) and the container
    (``Info.ContainerID``), and the container environment must define
    ``KASM_ID``, ``KASM_API_HOST``, ``KASM_API_JWT`` and ``KASM_API_PORT``.

    Returns:
        An empty JSON object (HTTP 200) once the worker thread is started, or
        a JSON error (HTTP 400) when the request is incomplete.
    """
    data = request.get_json(force=True)

    # The request carries the container's API token; mask it so the secret
    # does not end up in the driver's log file.
    loggable = data
    info = data.get("Info")
    if isinstance(info, dict) and info.get("ContainerEnv"):
        masked_env = [
            "KASM_API_JWT=<redacted>" if str(item).startswith("KASM_API_JWT=") else item
            for item in info["ContainerEnv"]
        ]
        loggable = {**data, "Info": {**info, "ContainerEnv": masked_env}}
    app_logger.info(f"container data:\n{loggable}")

    container_id, container_image, container_env, file = parse_container_data(data)

    def reject(message):
        # "Err" is the field the Docker plugin protocol reads; "Error" is
        # kept for existing consumers.
        return jsonify({"Err": message, "Error": message}), 400

    # Without these the worker thread would fail after Docker had already
    # been told that logging started.
    if not container_id or not file:
        return reject("Missing ContainerID or File in request")

    required = ("KASM_ID", "KASM_API_HOST", "KASM_API_JWT", "KASM_API_PORT")
    missing = [name for name in required if not container_env.get(name)]
    if missing:
        app_logger.warning(f"container({container_id}) is missing environment variables: {missing}")
        return reject("Missing KASM environment variables")

    thread = threading.Thread(
        target=consume_container_logs,
        args=(container_id, container_image, container_env, file),
        daemon=True,
    )
    thread.start()

    return jsonify({}), 200


@app.route("/LogDriver.StopLogging", methods=["POST"])
def stop_logging():
    """Acknowledge a stop request with an empty JSON object; no other action is taken."""
    acknowledgement = jsonify({})
    return acknowledgement, 200


if __name__ == "__main__":
    # Command line: --debug takes an integer; any non-zero value turns on
    # Flask's debug mode.
    cli = argparse.ArgumentParser()
    cli.add_argument("--debug", type=int, default=0)
    options = cli.parse_args()

    # The driver's own diagnostics go to a size-rotated file under log_path.
    driver_handler = RotatingFileHandler(
        filename=os.path.join(log_path, "logging-driver.log"),
        maxBytes=10 * 1024 * 1024,
        backupCount=5,
    )
    driver_handler.setFormatter(
        logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')
    )
    driver_handler.setLevel(logging.DEBUG)

    # Attach to the root logger so records from every logger reach the file.
    root = logging.getLogger()
    root.addHandler(driver_handler)
    root.setLevel(logging.DEBUG)

    # Serve on the plugin's Unix socket; threaded so that a long-lived
    # streaming request does not block the other endpoints.
    app.run(
        host="unix:///run/docker/plugins/kasmlogger.sock",
        threaded=True,
        debug=bool(options.debug),
    )