import datetime
import json
import logging
import os
import time
from typing import Any, Dict

import boto3
import botocore

# The Tahoe Twirp endpoint path. generate_tahoe_request appends '/<method>' to
# this prefix to build the 'path' field of every request payload.
TAHOE_SERVICE = '/twirp/twitch.fulton.example.twitchtahoeapiservice.TwitchTahoeAPIService'

# Set the root logger's level to INFO so the logging.info calls in this module
# are not filtered out by level.
logging.getLogger().setLevel(logging.INFO)


def get_aws_session(role_session_name):
    """Build a boto3 session backed by temporary credentials from STS.

    The role named by the ``REGISTRATION_ROLE`` environment variable is
    assumed through the us-west-2 STS endpoint, with ``role_session_name``
    as the session name. The credentials from the response are wrapped in a
    new ``boto3.Session``, which is returned.

    Raises ``KeyError`` when ``REGISTRATION_ROLE`` is not set.
    """
    sts_client = boto3.client(
        'sts',
        region_name='us-west-2',
        endpoint_url='https://sts.us-west-2.amazonaws.com',
    )
    # A role with access to invoke the Tahoe API lambda
    role_arn = os.environ['REGISTRATION_ROLE']
    assumed = sts_client.assume_role(
        RoleArn=role_arn,
        RoleSessionName=role_session_name,
    )
    credentials = assumed['Credentials']
    return boto3.Session(
        aws_access_key_id=credentials['AccessKeyId'],
        aws_secret_access_key=credentials['SecretAccessKey'],
        aws_session_token=credentials['SessionToken'],
    )


def generate_tahoe_request(method: str, body: Dict[str, Any]) -> Dict[str, Any]:
    """Build the event payload for one Tahoe API call.

    ``method`` is appended to ``TAHOE_SERVICE`` to form the request path and
    ``body`` is serialized to a JSON string. The result is a dict with the
    keys ``httpMethod`` (always ``POST``), ``path``, ``headers`` and ``body``.
    """
    path = f'{TAHOE_SERVICE}/{method}'
    headers = {'Content-Type': 'application/json'}
    serialized_body = json.dumps(body)
    return {
        'httpMethod': 'POST',
        'path': path,
        'headers': headers,
        'body': serialized_body,
    }


def retry_call_tahoe(session, payload: Dict[str, Any], times: int = 4, interval: int = 30) -> Dict[str, Any]:
    """Call ``call_tahoe`` up to ``times`` times, pausing between failures.

    Args:
        session: Session object passed through to ``call_tahoe``.
        payload: Request payload passed through to ``call_tahoe``.
        times: Maximum number of attempts; must be at least 1.
        interval: Seconds to sleep between two consecutive attempts.

    Returns:
        The value returned by the first successful ``call_tahoe`` call.

    Raises:
        ValueError: If ``times`` is smaller than 1.
        Exception: The exception raised by the last attempt when every
            attempt failed.
    """
    if times < 1:
        # Without this guard the loop never runs and there is no exception
        # to re-raise, which used to surface as an unrelated TypeError.
        raise ValueError(f'times must be at least 1, got {times}')

    last_error = None
    for attempt in range(times):
        logging.info("call_tahoe retry#%s", attempt)
        try:
            return call_tahoe(session, payload)
        except Exception as err:
            last_error = err
            logging.warning("call_tahoe attempt %s of %s failed: %s", attempt + 1, times, err)
            # Only wait when another attempt follows; sleeping after the
            # final failure just delays the error.
            if attempt < times - 1:
                time.sleep(interval)

    logging.error("failed to call tahoe")
    raise last_error


def call_tahoe(session, payload: Dict[str, Any]) -> Dict[str, Any]:
    """Invoke a Tahoe Lambda and return the parsed response.

    Args:
        session: Object whose ``client('lambda')`` returns a Lambda client.
        payload: Request event; it is JSON-serialized and sent as the
            invocation payload. Its ``path`` entry is used in error messages.

    Returns:
        The JSON-decoded ``body`` field of the Lambda's response payload.

    Raises:
        RuntimeError: If the invocation returns a status code of 300 or
            more, if the invocation reports a function error, or if the
            ``statusCode`` inside the response payload is 300 or more.
    """
    logging.info("request payload: %s", payload)

    lambda_client = session.client('lambda')
    resp = lambda_client.invoke(FunctionName='TapLambdaFunction', Payload=json.dumps(payload))
    raw_payload = resp['Payload'].read()

    if resp['StatusCode'] >= 300:
        raise RuntimeError(
            f'Response code {resp["StatusCode"]} from {payload["path"]}: {raw_payload}')

    # A failure inside the function is reported through 'FunctionError' while
    # the invocation status stays in the 200 range; the payload then holds
    # error details rather than a 'statusCode'/'body' response.
    function_error = resp.get('FunctionError')
    if function_error:
        raise RuntimeError(
            f'Function error {function_error} from {payload["path"]}: {raw_payload}')

    parsed = json.loads(raw_payload)
    logging.info("payload: %s", parsed)

    if parsed['statusCode'] >= 300:
        raise RuntimeError(
            f'Response code {parsed["statusCode"]} from {payload["path"]}: {parsed}')

    return json.loads(parsed['body'])


def az_id_for_cluster(clusterId: str):
    """Return the zone ID of the availability zone hosting a Redshift cluster.

    Describes the cluster identified by ``clusterId``, takes its
    ``AvailabilityZone`` name, and looks up that zone through EC2 to return
    its ``ZoneId``.

    Raises ``RuntimeError`` unless exactly one cluster is returned.
    """
    redshift = boto3.client("redshift")
    matches = redshift.describe_clusters(ClusterIdentifier=clusterId)['Clusters']
    if len(matches) != 1:
        raise RuntimeError(
            f'Expected to find one redshift cluster with id={clusterId} but found {len(matches)}')

    zone_name = matches[0]['AvailabilityZone']
    ec2 = boto3.client("ec2")
    zones = ec2.describe_availability_zones(ZoneNames=[zone_name])
    first_zone = zones["AvailabilityZones"][0]
    return first_zone["ZoneId"]


def string_to_relation_schema(sub):
    """Split a dotted ``schema.relation`` string into a dict.

    The first dot-separated part is stored under ``'schema'`` and the second
    under ``'relation'``. Parts beyond the second are dropped, and a string
    without a dot yields only the ``'schema'`` key.
    """
    keys = ('schema', 'relation')
    parts = sub.split('.')
    return {key: part for key, part in zip(keys, parts)}


def get_subscription_changes(expected_subs, current_subs):
    expected_subs = set(expected_subs)
    existing_subs = set()
    for sub in current_subs:
        existing_subs.add(f"{sub['target']['schema']}.{sub['target']['relation']}")
    to_unsubscribe = [string_to_relation_schema(x) for x in existing_subs - expected_subs]
    to_subscribe = [string_to_relation_schema(x) for x in expected_subs - existing_subs]
    return to_unsubscribe, to_subscribe


def get_subscriptions_for_tap(session, tap_id):
    """Return the subscriptions Tahoe reports for ``tap_id``.

    Sends a ``ListSubscriptions`` request for the single tap and returns the
    ``subscriptions`` entry of the response, or an empty list when the
    response has no such entry.
    """
    body = {'tap_ids': [tap_id]}
    listing = call_tahoe(session, generate_tahoe_request('ListSubscriptions', body))
    return listing.get('subscriptions', [])


def unsubscribe_from_view(session, tap_id, view):
    """Send an ``Unsubscribe`` request for ``view`` on ``tap_id``.

    ``view`` is sent as the request's ``target``. Returns the parsed Tahoe
    response.
    """
    body = {'tap_id': tap_id, 'target': view}
    return call_tahoe(session, generate_tahoe_request('Unsubscribe', body))


def subscribe_to_view(session, tap_id, view):
    """Send a ``Subscribe`` request for ``view`` on ``tap_id``.

    ``view`` is sent as the request's ``target``. Returns the parsed Tahoe
    response.
    """
    body = {'tap_id': tap_id, 'target': view}
    return call_tahoe(session, generate_tahoe_request('Subscribe', body))


def subscription_handler(event, context, session):
    """Bring a tap's subscriptions in line with the requested view list.

    For ``Create`` and ``Update`` requests, the tap's current subscriptions
    are fetched and compared with ``ResourceProperties['Views']``; views that
    are no longer listed are unsubscribed first, then newly listed views are
    subscribed. Any other request type is a no-op.

    Always returns ``None``. ``context`` is not used.
    """
    if event['RequestType'] not in ('Create', 'Update'):
        return

    resource_props = event['ResourceProperties']
    tap_id = resource_props['TapId']
    current = get_subscriptions_for_tap(session, tap_id)
    removals, additions = get_subscription_changes(resource_props['Views'], current)

    for view in removals:
        logging.info(f'Attempting to unsubscribe from view {view}')
        unsubscribe_from_view(session, tap_id, view)

    for view in additions:
        logging.info(f'Attempting to subscribe to view {view}')
        subscribe_to_view(session, tap_id, view)
    return


def get_registration_response(tap_id):
    """Build the registration result dict for ``tap_id``.

    The tap ID is used both as ``PhysicalResourceId`` and as the ``TapId``
    value inside ``Data``.
    """
    data = {'TapId': tap_id}
    return {'PhysicalResourceId': tap_id, 'Data': data}


def enable_redshift_audit_logging(rs_client, cluster_id: str, logging_bucket_name: str, s3_key_prefix: str,
                                  retry_times: int = 4, retry_interval: int = 30):
    """Enable audit logging on a Redshift cluster, retrying on failure.

    Args:
        rs_client: Client object providing ``enable_logging``.
        cluster_id: Passed as ``ClusterIdentifier``.
        logging_bucket_name: Passed as ``BucketName``.
        s3_key_prefix: Passed as ``S3KeyPrefix``.
        retry_times: Maximum number of attempts; must be at least 1.
        retry_interval: Seconds to sleep between two consecutive attempts.

    Returns:
        None, as soon as one ``enable_logging`` call succeeds.

    Raises:
        ValueError: If ``retry_times`` is smaller than 1.
        Exception: The exception raised by the last attempt when every
            attempt failed.
    """
    if retry_times < 1:
        raise ValueError(f'retry_times must be at least 1, got {retry_times}')

    last_error = None
    for attempt in range(retry_times):
        try:
            rs_client.enable_logging(
                ClusterIdentifier=cluster_id,
                BucketName=logging_bucket_name,
                S3KeyPrefix=s3_key_prefix
            )
        except Exception as err:
            last_error = err
            logging.warning("enable_logging attempt %s of %s failed: %s", attempt + 1, retry_times, err)
            # Only wait when another attempt follows; sleeping after the
            # final failure just delays the error.
            if attempt < retry_times - 1:
                time.sleep(retry_interval)
        else:
            return

    logging.error("failed enable redshift audit logging")
    raise last_error


def registration_handler(event, context, session):
    """Handle a ``Registration`` custom-resource request.

    * ``Delete``: removes ``TapDataRoleArn`` from the cluster's IAM roles and
      sends ``DeregisterTap`` for ``event['PhysicalResourceId']``; returns
      ``None``.
    * ``Create``: adds ``TapDataRoleArn`` to the cluster's IAM roles, sends
      ``RegisterTap`` (with retries), enables Redshift audit logging using the
      new tap ID as the S3 key prefix, and returns the registration response.
    * ``Update``: adds ``TapDataRoleArn`` to the cluster's IAM roles, sends
      ``RegisterTap`` including the existing ``tap_id`` (no retries), and
      returns the registration response.

    Any other request type falls through and returns ``None``. The
    ``LOGGING_BUCKET`` environment variable is read up front for every
    request type. ``context`` is not used.
    """
    logging_bucket_name = os.environ['LOGGING_BUCKET']
    properties = event['ResourceProperties']
    rs_client = boto3.client('redshift')
    request_type = event['RequestType']

    if request_type == 'Delete':
        rs_client.modify_cluster_iam_roles(
            ClusterIdentifier=properties['ClusterId'],
            RemoveIamRoles=[properties['TapDataRoleArn']],
        )
        deregister = generate_tahoe_request('DeregisterTap', {
            'tap_id': event['PhysicalResourceId']
        })
        call_tahoe(session, deregister)
        return

    if request_type == 'Create':
        rs_client.modify_cluster_iam_roles(
            ClusterIdentifier=properties['ClusterId'],
            AddIamRoles=[properties['TapDataRoleArn']],
        )
        register = generate_tahoe_request('RegisterTap', _register_tap_fields(event, properties))
        response = retry_call_tahoe(session, register)
        tap_id = response['tap']['id']
        enable_redshift_audit_logging(
            rs_client,
            properties['ClusterId'],
            logging_bucket_name,
            tap_id,
        )
        return get_registration_response(tap_id)

    if request_type == 'Update':
        rs_client.modify_cluster_iam_roles(
            ClusterIdentifier=properties['ClusterId'],
            AddIamRoles=[properties['TapDataRoleArn']],
        )
        # 'tap_id' goes first so the serialized body keeps its key order.
        register = generate_tahoe_request('RegisterTap', {
            'tap_id': event['PhysicalResourceId'],
            **_register_tap_fields(event, properties),
        })
        response = call_tahoe(session, register)
        return get_registration_response(response['tap']['id'])


def _register_tap_fields(event, properties):
    """Return the RegisterTap body fields shared by Create and Update."""
    return {
        'tap_name': properties['TapName'],
        'bindle_id': properties['BindleId'],
        'purpose': properties['Purpose'],
        'stack_id': event['StackId'],
        'tap_admin_role': properties['TapAdminRoleArn'],
        'redshift_props': {
            'tap_data_role': properties['TapDataRoleArn'],
            'vpc_endpoint_service': properties['VpcEndpointService'],
            'az_id': az_id_for_cluster(properties['ClusterId']),
            'redshift_secret_arn': properties['RedshiftSecretArn'],
            'redshift_cluster_id': properties['RedshiftClusterId']
        },
    }


def handler(event, context):
    """Entrypoint handler for the lambda function.

    Creates the assumed-role session, then dispatches on
    ``event['ResourceProperties']['CustomResource']``: ``SubscriptionList``
    goes to ``subscription_handler`` and ``Registration`` goes to
    ``registration_handler``. The selected handler's result is returned.

    Raises ``RuntimeError`` for any other ``CustomResource`` value.
    """
    session = get_aws_session('tap-invocation')
    resource_kind = event['ResourceProperties']['CustomResource']

    if resource_kind == 'SubscriptionList':
        return subscription_handler(event, context, session)
    elif resource_kind == 'Registration':
        return registration_handler(event, context, session)

    raise RuntimeError('No Valid CustomResource specified')


if __name__ == '__main__':
    # TODO we should have more unit/integration tests in lieu of this function
    # Manual driver: picks a canned event by --resource and feeds it to handler.
    import argparse
    import sys

    cli = argparse.ArgumentParser()
    cli.add_argument('--resource', required=True)
    resource = cli.parse_args().resource

    sample_events = {
        # NOTE(review): this event has no 'RequestType' key and no
        # 'RedshiftClusterId' property, both of which registration_handler
        # reads - confirm whether this sample is stale.
        'registration': {
            'StackId': 'some-id',
            'ResourceProperties': {
                'TapName': 'tap-name',
                'CustomResource': 'Registration',
                'ClusterId': 'tahoetapconsumer-tahoe1812e036-1vk48wxx27kzd',
                'BindleId': 'amzn1.bindle.resource.ABCDEfghij1234567890',
                'Purpose': 'test purpose',
                'RedshiftSecretArn': 'test secret',
                'TapAdminRoleArn': 'arn:aws:iam::986713075947:root',
                'TapDataRoleArn': 'arn:aws:iam::986713075947:root',
                'VpcEndpointService': 'com.amazonaws.vpce.us-west-2.vpce-svc-0d1f57062dc48aecd',
            }
        },
        'subscriptions': {
            'StackId': 'some-id',
            'ResourceProperties': {
                'CustomResource': 'SubscriptionList',
                'TapId': 'dazwiogDmaeaDdBu',
                'Views': [
                    'spade.invariant_ping',
                    'spade.buffer-empty'
                ],
            },
            'RequestType': 'Update',
        },
    }

    if resource not in sample_events:
        print(f'Invalid resource {resource}. Please select a valid resource.')
        sys.exit(1)

    handler(sample_events[resource], {})
