import urllib.request
import ssl
import socket
import logging
import queue
from pythonjsonlogger import jsonlogger
import json
import datetime
import threading

from logging.handlers import QueueHandler, QueueListener
from logging import StreamHandler


# Module-level logger that KasmLogHandler uses to report its own delivery
# failures (HTTP sends and database writes).
app_logger = logging.getLogger(name='kasmlogger')


class KasmLogHandler(StreamHandler):
    """Logging handler that batches JSON log records to a database and/or an HTTPS endpoint.

    Settings are read from ``db`` (via ``getConfigSetting``) and/or from a
    ``config`` mapping; when ``config`` is supplied its values take precedence
    and HTTPS delivery is always enabled.  Records are cached in memory and
    written in batches, either when a cache exceeds its size limit or when
    ``MAX_CACHE_SECONDS`` have elapsed since the previous flush.
    """

    HTTP_CACHE_SIZE = 20  # max number of logs to cache for http transfer
    DB_CACHE_SIZE = 50  # max number of logs to cache for the db
    MAX_CACHE_SECONDS = 120  # max time to cache logs

    def __init__(self, application, kasm_api_jwt, db=None, config=None, hostname=None, max_cache_size=HTTP_CACHE_SIZE):
        """Create the handler and load its initial settings.

        Args:
            application: Value stored on every locally emitted record as ``record.application``.
            kasm_api_jwt: Token appended to the HTTPS destination URL as ``?token=``.
            db: Optional data-access object providing ``getConfigSetting`` and ``createLogs``.
            config: Optional configuration mapping; must contain ``manager.hostnames``.
            hostname: Host name reported with each log; defaults to ``socket.gethostname()``.
            max_cache_size: Number of HTTP logs to cache before a send is triggered.

        Raises:
            ValueError: If both ``db`` and ``config`` are ``None``, if ``config``
                lacks a non-empty ``manager.hostnames``, or if the log port is invalid.
        """
        super().__init__()
        self.db = db
        self.config = config
        self.log_format = 'standard'
        self.server_id = None
        self.kasm_api_jwt = kasm_api_jwt
        self.max_cache_size = max_cache_size
        # Guards db_log_cache.  Re-entrant so that a failure logged from inside
        # emit_db cannot deadlock if that log is routed back to this handler.
        self._db_cache_lock = threading.RLock()

        if db is None and config is None:
            raise ValueError('KasmLogHandler initialized with null database and configuration values')
        if config is not None:
            self._validate_config(config)

        self.update_config()

        self.http_cache = []
        self.http_cache_last_flush = datetime.datetime.now()
        self.db_log_cache = []
        self.db_log_cache_last_flush = datetime.datetime.now()
        self.setFormatter(jsonlogger.JsonFormatter(fmt="%(asctime) %(name) %(processName) %(filename) %(funcName) "
                                                       "%(levelname) %(lineno) %(module) %(threadName) %(message)",
                                                   timestamp=True))
        self.hostname = socket.gethostname() if hostname is None else hostname
        self.application = application

    @staticmethod
    def _validate_config(config):
        """Raise ``ValueError`` unless ``config`` has a non-empty ``manager.hostnames``."""
        if ('manager' not in config or 'hostnames' not in config['manager']
                or len(config['manager']['hostnames']) == 0):
            raise ValueError('Invalid configuration file.')

    @staticmethod
    def _utc_now_iso():
        """Return the current UTC time formatted as ``YYYY-MM-DDTHH:MM:SSZ``."""
        return datetime.datetime.now(datetime.timezone.utc).strftime("%Y-%m-%dT%H:%M:%SZ")

    def _serialize_http_cache(self):
        """Return the cached HTTP logs serialised for the configured log format."""
        if self.log_format == 'splunk':
            # Splunk expects non-standard JSON for multiple logs: the event
            # objects are concatenated with spaces instead of sent as an array.
            return " ".join(json.dumps(log) for log in self.http_cache)
        return json.dumps(self.http_cache)

    def _build_request(self, data=None):
        """Build a fresh ``urllib`` request for the current destination.

        A new object is created for every send because ``urlopen`` stores the
        payload on the request it is given; sharing one request between
        concurrent sender threads would let them overwrite each other's body.
        """
        request = urllib.request.Request(self.destination, data=data, method=self.http_method.upper())
        request.add_header('Content-Type', 'application/json')
        if self.log_format == 'splunk':
            request.add_header('Authorization', f"Splunk {self.hec_token}")
        return request

    def flush(self):
        """Synchronously send any cached HTTP logs and reset the HTTP cache."""
        if self.log_protocol == 'https' and len(self.http_cache) > 0:
            msg = self._serialize_http_cache()
            self.async_send_http_logs(msg)
            self.http_cache_last_flush = datetime.datetime.now()
            self.http_cache.clear()

    def _load_db_settings(self):
        """Read the logging settings from the database configuration."""
        self.log_protocol = self.db.getConfigSetting('logging', 'log_protocol', default='internal').value.lower()
        self.log_port = self.db.getConfigSetting('logging', 'log_port', default=443).value
        self.log_host = self.db.getConfigSetting('logging', 'log_host', default='').value
        self.hec_token = self.db.getConfigSetting('logging', 'hec_token', default='').value
        self.http_method = self.db.getConfigSetting('logging', 'http_method', default='post').value.lower()
        self.https_insecure = self.db.getConfigSetting('logging', 'https_insecure', default='true').getValueBool()
        self.splunk_endpoint = self.db.getConfigSetting('logging', 'url_endpoint', default='/').value

        # splunk is still over https, but the format of the log needs to change
        if self.log_protocol == 'splunk':
            self.log_format = 'splunk'
            self.log_protocol = 'https'

        # If log host is empty do not log via https
        if self.log_host == '':
            self.log_protocol = 'internal'

    def _config_is_unchanged(self, config):
        """Return True when ``config`` carries the same settings that are currently applied."""
        manager = config['manager']
        server_id = None
        if 'agent' in config and 'server_id' in config['agent']:
            server_id = config['agent']['server_id']
        return (self.server_id == server_id
                and self.hec_token == (manager['token'] if 'token' in manager else 'None')
                and self.log_port == (manager['public_port'] if 'public_port' in manager else '443')
                and self.log_host == manager['hostnames'][0])

    def _load_config_settings(self):
        """Apply the settings held in ``self.config``; these always use HTTPS."""
        manager = self.config['manager']
        self.log_host = manager['hostnames'][0]
        # uses https regardless if a conf was passed instead of a db
        self.log_protocol = 'https'
        self.log_port = manager['public_port'] if 'public_port' in manager else '443'
        self.http_method = 'post'
        self.hec_token = manager['token'] if 'token' in manager else 'None'
        # TODO: The client cert is provided in the configuration
        self.https_insecure = True
        self.splunk_endpoint = self.config.get('log_path', "/api/kasm_session_log")
        if 'agent' in self.config and 'server_id' in self.config['agent']:
            self.server_id = self.config['agent']['server_id']

    def update_config(self, config=None):
        """(Re)load the handler settings.

        Called without arguments during construction.  Pass a new ``config``
        mapping to update a configuration-driven handler; nothing is rebuilt
        when the new mapping carries the same settings as the current one.

        Args:
            config: Optional replacement configuration mapping.  Only used when
                the handler was created with a configuration mapping.

        Raises:
            ValueError: If ``config`` lacks a non-empty ``manager.hostnames``
                or the resulting log port is ``0`` or ``None``.
        """
        if self.db is not None:
            self._load_db_settings()

        if self.config is not None:
            if config is not None:
                # Validate the incoming mapping (not the one already in use)
                # before any of its keys are dereferenced.
                self._validate_config(config)
                if self._config_is_unchanged(config):
                    return
                self.config = config
            self._load_config_settings()

        if self.log_protocol == 'https':
            if self.log_port == 0 or self.log_port is None:
                raise ValueError('Invalid log port')

            self.destination = f"https://{self.log_host}:{self.log_port}{self.splunk_endpoint}?token={self.kasm_api_jwt}"
            # Kept as an attribute for callers that inspect it; every send
            # builds its own request object from the same settings.
            self.request = self._build_request()
            self.insecure_context = ssl.create_default_context()
            # check_hostname must be disabled before verification can be turned off.
            self.insecure_context.check_hostname = False
            self.insecure_context.verify_mode = ssl.CERT_NONE if self.https_insecure else ssl.CERT_REQUIRED

    @staticmethod
    def create_cached_kasmloghandler(application, kasm_api_jwt, db=None, config=None, hostname=None,
                                     max_cache_size=HTTP_CACHE_SIZE):
        """Create a ``KasmLogHandler`` fed through a bounded queue.

        Returns:
            A ``QueueHandler`` whose ``log_handler`` attribute is the underlying
            ``KasmLogHandler`` and whose ``listener`` attribute is the started
            ``QueueListener`` (so callers can ``stop()`` it on shutdown).
        """
        qu = queue.Queue(100)
        qu_handler = QueueHandler(qu)
        kl = KasmLogHandler(application=application, kasm_api_jwt=kasm_api_jwt, db=db, config=config, hostname=hostname,
                            max_cache_size=max_cache_size)
        qu_listener = QueueListener(qu, kl)
        qu_listener.start()
        qu_handler.log_handler = kl
        qu_handler.listener = qu_listener
        return qu_handler

    def emit_http(self, log_dict):
        """Cache ``log_dict`` for HTTP delivery and send the batch when it is due.

        The batch is sent on a background thread once the cache holds more than
        ``max_cache_size`` entries or ``MAX_CACHE_SECONDS`` have passed since
        the last flush.  In ``standard`` format ``log_dict`` is modified in
        place (``ingest_date`` and ``host`` are set).
        """
        if self.log_format == 'standard':
            log_dict['ingest_date'] = self._utc_now_iso()
            log_dict['host'] = self.hostname
            self.http_cache.append(log_dict)
        elif self.log_format == 'splunk':
            # Use an aware UTC datetime: calling .timestamp() on the naive
            # value returned by utcnow() would be interpreted as local time
            # and skew the epoch by the host's UTC offset.
            epoch = datetime.datetime.now(datetime.timezone.utc).timestamp()
            self.http_cache.append({"event": log_dict, "time": epoch, "host": self.hostname})

        elapsed = (datetime.datetime.now() - self.http_cache_last_flush).total_seconds()

        if len(self.http_cache) > self.max_cache_size or elapsed > self.MAX_CACHE_SECONDS:
            # Serialise before clearing so the sender thread owns its payload.
            msg = self._serialize_http_cache()
            threading.Thread(target=self.async_send_http_logs, args=(msg,)).start()

            self.http_cache_last_flush = datetime.datetime.now()
            self.http_cache.clear()

    def async_send_http_logs(self, msg):
        """Send the serialised batch ``msg`` to the configured destination.

        Failures are reported through ``app_logger`` and never raised.
        """
        try:
            request = self._build_request(msg.encode('utf-8'))

            open_kwargs = {'timeout': 3}
            if self.https_insecure:
                open_kwargs['context'] = self.insecure_context

            # The context manager closes the response (and its socket).
            with urllib.request.urlopen(request, **open_kwargs) as response:
                status = response.status

            if status > 299:
                app_logger.error(f"HTTP logging failed: Invalid response code({status})")
        except Exception:
            app_logger.exception("HTTP logging failed")

    def emit_db(self, log_dict, forwarded=False):
        """Cache ``log_dict`` for the database and write the batch when it is due.

        Args:
            log_dict: The log payload.  For forwarded logs that carry both
                ``host`` and ``ingest_date`` those keys are moved out of the
                payload into the row.
            forwarded: True when the log was forwarded from another system.
        """
        if forwarded and 'host' in log_dict and 'ingest_date' in log_dict:
            log = {'host': log_dict.pop('host'), 'ingest_date': log_dict.pop('ingest_date'), 'data': log_dict}
        else:
            log = {'host': self.hostname, 'data': log_dict, 'ingest_date': self._utc_now_iso()}

        # look for specific data and add to columns in log record
        log['metric_name'] = log_dict.get('metric_name')
        log['kasm_user_name'] = log_dict.get('kasm_user_name')
        log['levelname'] = log_dict.get('levelname')
        log['disk_stats'] = None
        log['memory_stats'] = None
        log['cpu_percent'] = None
        log['server_id'] = None
        if 'heartbeat' in log_dict:
            hb = log_dict['heartbeat']
            if 'disk_stats' in hb:
                log['disk_stats'] = hb['disk_stats'].get('percent')
            if 'memory_stats' in hb:
                log['memory_stats'] = hb['memory_stats'].get('percent')
            log['cpu_percent'] = hb.get('cpu_percent')
            log['server_id'] = hb.get('server_id')

        # A lock shared by all callers of this instance; a lock created per
        # call would not synchronise anything.
        with self._db_cache_lock:
            self.db_log_cache.append(log)
            elapsed = (datetime.datetime.now() - self.db_log_cache_last_flush).total_seconds()

            if len(self.db_log_cache) > self.DB_CACHE_SIZE or elapsed > self.MAX_CACHE_SECONDS:
                # Take the batch out of the cache before writing; the batch is
                # dropped whether or not the write succeeds.
                batch = list(self.db_log_cache)
                self.db_log_cache.clear()
                self.db_log_cache_last_flush = datetime.datetime.now()
                try:
                    self.db.createLogs(batch)
                except Exception:
                    app_logger.exception("database logging failed")

    def emit(self, record):
        """Route ``record`` to the database cache and/or the HTTP cache.

        Records with a ``_json`` attribute are batches forwarded from an agent:
        a JSON array of log dicts.  They are only processed when a database is
        configured.  Any other record is formatted to JSON and tagged with
        ``application`` and, when known, ``server_id``.  Errors are passed to
        ``handleError`` as is conventional for logging handlers, so that a bad
        record cannot propagate into the caller or stop a queue listener.
        """
        try:
            # forwarded logs from agent will contain a _json field
            if hasattr(record, '_json'):
                if self.db is not None:
                    # logs are passed in bulk as array of logs
                    for log in json.loads(record._json):
                        self.emit_db(log, forwarded=True)
                        if self.log_protocol == 'https':
                            self.emit_http(log)
                return

            record.application = self.application
            log_json = json.loads(self.format(record))

            if self.server_id is not None:
                log_json['server_id'] = self.server_id

            if self.db is not None:
                self.emit_db(log_json)
            if self.log_protocol == 'https':
                self.emit_http(log_json)
        except RecursionError:
            raise
        except Exception:
            self.handleError(record)


class ForwardedLogFilter(logging.Filter):
    """Reject log records that were forwarded to this system.

    A record counts as forwarded when it carries a ``_json`` attribute.
    """

    def filter(self, record):
        """Return True to keep ``record``, False when it has a ``_json`` attribute."""
        was_forwarded = hasattr(record, '_json')
        return not was_forwarded