from autobahn.websocket import WebSocketClientFactory, WebSocketClientProtocol, \
  connectWS
from compiler.ast import For
from exceptions import RuntimeWarning
from mod_python import apache
from mod_pywebsocket import msgutil, handshake, common
from optparse import OptionParser
from threading import Condition, Thread, Timer
from twisted.internet import reactor, ssl
from twisted.python import log
from urlparse import parse_qs
import json
import logging.config
import os
import os.path
import sys
import time
import twisted.internet.error


try:
  home = os.path.expanduser('~')
  with open('%s/%s' % (home, 'proxy.json') , 'r') as fp:
    config = json.load(fp)
except Exception, ex:
  apache.log_error('Error loading websocket proxy configuration: %s' % str(ex))
  raise

try:
  if 'loggingConfiguration' not in config:
    msg = 'loggingConfiguration was missing from configuration'
    apache.log_error(msg)
    raise RuntimeWarning(msg)
  with open(config['loggingConfiguration'], 'rs') as fp:
    logging_config = json.load(fp)
    try:
      logging.config.dictConfig(logging_config)
    except AttributeError:
      import dictconfig
      dictconfig.dictConfig(logging_config)
except Exception, ex:
  apache.log_error('Error loading logging configuration: %s' % str(ex))
  raise

logger = logging.getLogger('root')

class ReactorThread(Thread):
  def run(self):
    logger.info('Starting reactor thread')
    reactor.run(installSignalHandlers=0)

  def stop(self):
    logger.info('Stopping reactor thread')
    reactor.stop()

class ConnectionReaper:

  def __init__(self, proxy, reap_interval):
    self._proxy = proxy
    self._reap_interval = reap_interval

  def _reap(self):
    self._proxy.reap_inactive_connections()

    self.timer = Timer(self._reap_interval, self._reap)
    self.timer.start()

  def start(self):
    logger.info('Starting connection reaper: [reapInterval=%s]' % self._reap_interval)
    self._reap()

  def stop(self):
    logger.info('Stopping connection reaper')
    if self.timer is not None:
      self.timer.cancel()

class ProxyConnection:
  def __init__(self, proxy, request, conn_string):
    self.proxy = proxy
    self._request = request
    self._ws_stream = request.ws_stream
    self._conn_string = conn_string
    self._client = None
    self._on_client_condition = Condition()

  def open(self):
    factory = RegisteringWebSocketClientFactory('ws://%s/ws' % self._conn_string,
                                                self.on_client)
    factory.protocol = ProxyClientProtocol

    if self.protocol:
      factory.setSessionParameters(factory.url, factory.origin, [self.protocol],
                                   factory.useragent)
    connector = connectWS(factory)
    connector.proxy_connection = self
    reactor.wakeUp()
    self.last_used_time = time.time()

  def get_client(self):
    if not self._client:
      self._on_client_condition.acquire()
      while not self._client:
        self._on_client_condition.wait(30)

      self._on_client_condition.release()

    return self._client

  def send_message(self, message, binary ):

    if isinstance(message, unicode):
      message = message.encode('utf-8')

    self.get_client().sendMessage(message, binary)
    self.last_used_time = time.time()

  def on_client(self, client):
    self._on_client_condition.acquire()
    self._client = client
    self._client.onMessage = self.on_message
    self._client.connectionLost = self.connection_lost
    self._on_client_condition.notify_all()
    self._on_client_condition.release()


  def on_message(self, message, isbinary):
    self._ws_stream.send_message(message, binary=isbinary)
    self.last_used_time = time.time()

  def connection_lost(self, reason):
    if reason.type  != twisted.internet.error.ConnectionDone:
      logger.error('Proxy connection %s lost: %s' % (self, reason))
      self._ws_stream.close_connection(common.STATUS_GOING_AWAY,
                                      'Proxy lost connection to backend.')

    self.proxy.connection_closed(self._request)

  def close_inward(self):
    self.get_client().sendClose()

  def close_outward(self, code, reason):
    try:
      self._ws_stream.close_connection(code, reason)
    except msgutil.ConnectionTerminatedException:
      # ignore errors on close ...
      pass

class ProxyClientProtocol(WebSocketClientProtocol):
  def onOpen(self):
    self.factory.register(self)

class RegisteringWebSocketClientFactory(WebSocketClientFactory):
  def __init__(self, url, on_client):
    self.on_client = on_client
    WebSocketClientFactory.__init__(self, url)

  def register(self, client):
    self.client = client
    self.on_client(client)

  def unregister(self, client):
    self.client = None

  def clientConnectionFailed(self, connector, reason):
    if reason.type == twisted.internet.error.ConnectionRefusedError:
      connector.proxy_connection.close_outward(common.STATUS_UNSUPPORTED,
                                               "Backend refused connection")

class WebsocketProxy():
  def __init__(self):
    self._reactor = None
    self._reaper = None
    self._shutting_down = False
    self._proxy_connection_map = {}

    # defaults are reap ever 5 minute with a 1 hour connection timeout
    self._reap_interval = 300
    self._connection_timeout = 3600

    if 'connectionReaper' in config:
      if 'reapInterval' in config['connectionReaper']:
        self._reap_interval = config['connectionReaper']['reapInterval']

      if 'connectionTimeout' in config['connectionReaper']:
        self._connection_timeout = config['connectionReaper']['connectionTimeout']

  def start(self):
    msg = 'Starting websocket proxy: [pid=%s, reapInterval=%s, connectionTimeout=%s]' % (str(os.getpid()), self._reap_interval, self._connection_timeout)
    logger.info(msg)
    apache.log_error(msg, apache.APLOG_NOTICE)

    self._start_reaper()
    self._start_reactor()

  def stop(self):
    msg = 'Stopping websocket proxy: [pid=%s]' % str(os.getpid())
    logger.info(msg)
    apache.log_error(msg, apache.APLOG_NOTICE)

    self._shutting_down = True
    self._stop_reactor()
    self._stop_reaper()

  def session_id_to_endpoint(self, session_id):
    if 'sessionMappingFile' not in config:
      msg = "No mapping file configured"
      logger.error()
      raise Exception(msg)

    mapping_file = config['sessionMappingFile']
    with open(mapping_file, 'r') as fp:
      mapping = json.load(fp)

    endpoint = mapping[session_id]

    if not endpoint:
      msg = "No entry for %s" % session_id
      logger.error(msg)
      raise Exception(msg)

    if 'host' not in endpoint or 'port' not in endpoint:
      msg = "Invalid endpoint  %s" % endpoint
      logger.error(msg)
      raise Exception(msg)

    connection_str = "%s:%s" % (endpoint['host'], endpoint['port'])

    logger.info("Mapping session %s to %s" % (session_id, connection_str))

    return connection_str

  def lookup_endpoint(self, request):
    qs = parse_qs(request.args, False, True)
    if 'connection' in qs:
      connection = qs['connection'][0]
    elif 'sessionId' in qs:
      connection = self.session_id_to_endpoint(qs['sessionId'][0])

    return connection

  def handle_connection(self, request):
    connection = self.lookup_endpoint(request)

    proxy_conn = ProxyConnection(self, request, connection)
    proxy_conn.protocol = request.ws_protocol
    proxy_conn.open()
    self._proxy_connection_map[request] = proxy_conn

    try:
      while not self._shutting_down:
        message = request.ws_stream.receive_message()
        if not message:
          logger.info('web_socket_transfer_data thread returning')
          return

        proxy_conn.send_message(message, not isinstance(message, unicode))
    except msgutil.ConnectionTerminatedException:
      logger.info('web_socket_transfer_data thread connection terminated')

    proxy_conn.close_inward();
    self.connection_closed(request)

  def handle_connection_close(self, request):
    self._proxy_connection_map[request].close_inward()
    del self._proxy_connection_map[request]

  def connection_closed(self, request):
    if request in self._proxy_connection_map:
      del self._proxy_connection_map[request]

  def _start_reaper(self):
    self._connection_reaper = ConnectionReaper(self, self._reap_interval)
    self._connection_reaper.start()

  def _stop_reaper(self):
    if self._connection_reaper is not None:
      self._connection_reaper.stop()

  def _start_reactor(self):
    self._reactor = ReactorThread()
    self._reactor.start()

  def _stop_reactor(self):
    if self._reactor is not None:
      self._reactor.stop()

  def reap_inactive_connections(self):
    to_delete = []

    for request, proxy_conn in self._proxy_connection_map.iteritems():
      if (time.time() - proxy_conn.last_used_time) > self._connection_timeout:

        # First close connection to client
        proxy_conn.close_outward(common.STATUS_GOING_AWAY,
                                 'Idle connection reaped')
        # now close connection to backend
        proxy_conn.close_inward();
        to_delete.append(request)

    # Now remove reaped connections for the map
    for request in to_delete:
      logger.info('Reaping connection: %s' % self._proxy_connection_map[request])
      del self._proxy_connection_map[request]

# entry point
proxy = WebsocketProxy()

def cleanup(data):
  proxy.stop()
  logging.shutdown()

# register cleanup callback
apache.register_cleanup(cleanup)
proxy.start()

# pywebsocket functions
def web_socket_do_extra_handshake(request):
  # choice wamp if given the option
  if request.ws_requested_protocols and 'wamp' in request.ws_requested_protocols:
    request.ws_protocol = 'wamp'

def web_socket_transfer_data(request):
  try:
    proxy.handle_connection(request)
  except:
    logger.exception('Exception in web_socket_transfer_data')

def web_socket_passive_closing_handshake(request):
  try:
    proxy.handle_connection_close(request)
  except:
    logger.exception('Exception in web_socket_passive_closing_handshake')


