import os
import abc
import shutil
import argparse
import re
import xml.etree.ElementTree
from pprint import pprint

import common.source_tools
import common.configuration
import common.file_utils


class LanguageBindingGeneratorBase(object):
    """
    Base class for generators that produce target language bindings for native C++ types.

    Subclasses supply the language specific pieces (language name, template directory, target
    type name inference, the per-kind generator factories, and get_type_map_path(), which
    process_config calls but this class does not declare). This class drives type map loading,
    native type extraction via CastXML and template instantiation.
    """

    # NOTE: Python 2 style metaclass declaration. Under Python 3 this class attribute is ignored,
    # so the @abc.abstractmethod declarations below are only enforced when running on Python 2.
    __metaclass__ = abc.ABCMeta

    # Matches a type ending in a fixed size array, capturing the (right-most) size: '::char[64]' -> '64'
    _FIXED_ARRAY_SIZE_RE = re.compile(r'.+\[([0-9]+)\].*')

    # Same as _FIXED_ARRAY_SIZE_RE but also matches an unsized suffix such as '::char[]'
    _ARRAY_SUFFIX_RE = re.compile(r'.+\[([0-9]*)\].*')

    def __init__(self, config):
        self.config = config # common.configuration.ConfigSettings
        self.extracted_type_info = None  # Set by process_config() from run_castxml()
        self.type_map = {}  # Normalised native type name -> target language type name

    @abc.abstractmethod
    def get_language(self):
        pass

    @abc.abstractmethod
    def get_template_dir(self):
        pass

    @abc.abstractmethod
    def infer_target_type_name(self, native_type_name):
        pass

    @abc.abstractmethod
    def create_enum_generator(self, type_mapping, type_info):
        pass

    #@abc.abstractmethod
    #def create_struct_generator(self, type_mapping, type_info):
    #    pass

    #@abc.abstractmethod
    #def create_class_generator(self, type_mapping, type_info):
    #    pass

    #@abc.abstractmethod
    #def create_interface_generator(self, type_mapping, type_info):
    #    pass

    #@abc.abstractmethod
    #def create_interface_adapter_generator(self, type_mapping, type_info):
    #    pass

    def get_default_native_namespace(self):
        """ Returns the namespace to use for generated native code. """
        return '::ttv::binding::' + self.get_language()

    def load_typemap(self, typemap_xml_path):
        """
        Loads native -> target type mappings from an XML file into self.type_map.

        Each <Mapping NativeType="..." TargetType="..."/> child of the root element becomes one
        entry. Native type names are normalised so that every '*' and '&' is its own token and
        tokens are separated by exactly one space (e.g. 'const char*' -> 'const char *'), which
        is the form the lookup helpers compare against.
        """
        # Read the file and parse as XML
        doc = xml.etree.ElementTree.parse(typemap_xml_path)
        xml_root = doc.getroot()

        for xml_mapping in xml_root.findall('Mapping'):
            native_type = xml_mapping.attrib['NativeType']
            target_type = xml_mapping.attrib['TargetType']

            # Surround every '*' / '&' with spaces so each becomes a separate token...
            spaced = re.sub(r'([*&])', r' \1 ', native_type)
            # ...then collapse all whitespace runs (tabs included) into single spaces.
            type_name = ' '.join(spaced.split())

            self.type_map[type_name] = target_type

    def is_output_parameter(self, native_type_name):
        """Determines if the given type represents an output parameter (its last token is '&')."""
        tokens = native_type_name.split()
        return bool(tokens) and tokens[-1] == '&'

    def is_pointer_type(self, native_type_name):
        """Returns True if the last or second-to-last token is '*' (e.g. '::Foo *' or '::Foo * const')."""
        tokens = native_type_name.split()
        return '*' in tokens[-2:]

    def is_array_member(self, member_info, owning_type_info):
        """
        Returns the name of the member that holds the length, or a number describing the fixed length.

        Returns None when the member is not considered an array. A pointer member named '<x>Array'
        is an array when the owning type has a member named '<x>ArrayLength'.
        """

        # Constant sized array
        if member_info.fixed_array_size is not None:
            return member_info.fixed_array_size

        # Otherwise, array members are pointers
        tokens = member_info.type_name.split()
        if not tokens or tokens[-1] != '*':
            return None

        # Array members end with 'Array' and there is a corresponding member that ends with 'ArrayLength'
        if not member_info.name.endswith('Array'):
            return None

        length_member_name = member_info.name + 'Length'
        if owning_type_info is not None and owning_type_info.find_member(length_member_name):
            return length_member_name

        return None

    def is_bitfield_member(self, type_info, native_type_name, member_name):
        """Returns True if the member is treated as a bitfield (only its name is inspected)."""
        # Currently, bitfield members have names that end with 'Bitfield'
        return member_name.endswith('Bitfield')

    def strip_rightmost_array_modifier(self, native_type_name):
        """
        Removes one level of pointer/array-ness from the type.

        If the type has a '*' token, the right-most one is removed together with a 'const'
        token immediately to its left ('::char const *' -> '::char'). Otherwise the right-most
        '[N]' or '[]' suffix is removed ('::char[64]' -> '::char'). The name is returned
        unchanged when neither is present, which callers use as the loop termination signal.
        """
        tokens = native_type_name.split()

        # Rightmost *
        if '*' in tokens:
            star = len(tokens) - 1 - tokens[::-1].index('*')
            del tokens[star]
            # Remove the redundant const now that the * is gone
            if star > 0 and tokens[star - 1] == 'const':
                del tokens[star - 1]
            return ' '.join(tokens)

        # Remove the right-most [N] or [] suffix, if any
        match = self._ARRAY_SUFFIX_RE.match(native_type_name)
        if match is None:
            return native_type_name
        # Group 1 is the size between the brackets; drop the brackets as well
        return native_type_name[:match.start(1) - 1] + native_type_name[match.end(1) + 1:]

    def strip_native_type_modifiers(self, native_type_name, remove_const=False, remove_volatile=False, remove_star=False, remove_ref=False):
        """
        Removes the selected modifiers from a native type name.

        If no remove_* flag is set, all of const, volatile, '*' and '&' are removed.
        'const' / 'volatile' are only removed as whole words, so identifiers such as
        '::constants::Mode' are left intact. Whitespace is collapsed to single spaces.
        """
        remove_all = not (remove_const or remove_volatile or remove_star or remove_ref)

        if remove_all or remove_const:
            native_type_name = re.sub(r'\bconst\b', '', native_type_name)
        if remove_all or remove_volatile:
            native_type_name = re.sub(r'\bvolatile\b', '', native_type_name)
        if remove_all or remove_star:
            native_type_name = native_type_name.replace('*', '')
        if remove_all or remove_ref:
            native_type_name = native_type_name.replace('&', '')

        # Collapse whatever whitespace the removals left behind
        return ' '.join(native_type_name.split())

    def find_mapped_type(self, native_type_name):
        """
        Given a native type an attempt is made to find the language binding type to use to represent
        it in the target language.

        The lookup progressively relaxes the name (fixed array size, const/volatile, '&', then
        one '*' or '[]' at a time). If nothing matches, one level of typedef is resolved and the
        lookup is retried, until resolution makes no further progress.

        :raises Exception: If no mapping can be found.
        """

        # All types should start in the global namespace, even built-in primitives
        if (native_type_name not in self.type_map) and (not native_type_name.startswith('::')):
            native_type_name = '::' + native_type_name

        def lookup(candidate):
            # First try the type as is
            if candidate in self.type_map:
                return self.type_map[candidate]

            # Strip out any fixed length array size: ::char[64] -> ::char[]
            match = self._FIXED_ARRAY_SIZE_RE.match(candidate)
            if match:
                candidate = candidate[:match.start(1)] + candidate[match.end(1):]

            # Remove const and volatile
            candidate = self.strip_native_type_modifiers(candidate, remove_const=True, remove_volatile=True)
            if candidate in self.type_map:
                return self.type_map[candidate]

            # Remove &
            candidate = self.strip_native_type_modifiers(candidate, remove_ref=True)
            if candidate in self.type_map:
                return self.type_map[candidate]

            # Remove one * (or one [] suffix) at a time until nothing changes
            previous = None
            while previous != candidate:
                previous = candidate
                candidate = self.strip_rightmost_array_modifier(candidate)
                if candidate in self.type_map:
                    return self.type_map[candidate]

            return None

        original_name = native_type_name
        seen = set()
        # Stop when typedef resolution makes no progress; the seen-set also breaks typedef cycles
        while native_type_name not in seen:
            seen.add(native_type_name)
            target_type_name = lookup(native_type_name)
            if target_type_name:
                return target_type_name
            native_type_name = self.resolve_typedefs(native_type_name, recurse=False)

        message = 'Type mapping not found: ' + native_type_name
        if native_type_name != original_name:
            message += ' (resolved from ' + original_name + ')'
        raise Exception(message)

    def resolve_typedefs(self, native_type_name, recurse=True):
        """
        Given a native type an attempt is made to find a simpler system type to represent it.
        For example, if '::UserId' is the input type, '::uint32_t' will be returned.

        Only the base type token (the first token that is not const/volatile/'*'/'&') is
        replaced; the surrounding qualifier tokens are kept. With recurse=False at most one
        typedef level is resolved.
        """

        # Split on spaces
        tokens = native_type_name.split()

        # Qualifier / modifier tokens that are never the base type
        ignores = ('const', 'volatile', '*', '&')

        # Find the index of the token which is the actual base type
        base = next(((i, token) for i, token in enumerate(tokens) if token not in ignores), None)
        if base is None:
            # Only qualifiers/modifiers (or an empty name): nothing to resolve
            return ' '.join(tokens)
        index, t = base

        # Resolve as far as possible; 'visited' prevents spinning forever on a typedef cycle
        all_types = self.extracted_type_info.all
        visited = set()
        while t in all_types and t not in visited:
            visited.add(t)
            mapped = all_types[t]
            if isinstance(mapped, common.source_tools.CppTypedef):
                t = mapped.aliased_type
            else:
                t = mapped.type_name
                break

            if not recurse:
                break

        # Reconstruct the full type name
        tokens[index] = t
        return ' '.join(tokens)

    def instantiate_template(self, generator, template_name, output_path):
        """
        Reads a template from the template directory and returns it filled in by the generator.

        The directory of output_path is created if needed, but the file itself is not written;
        the caller receives the filled-in lines.
        """

        # Create the output directory if needed (output_path may have no directory component)
        output_dir = os.path.dirname(output_path)
        if output_dir and not os.path.isdir(output_dir):
            os.makedirs(output_dir)

        # Read the template file as lines
        template_path = self.get_template_dir() + '/' + template_name
        lines = common.file_utils.read_file_lines(template_path)

        # Fill it in
        return generator.fill_template(lines)

    def generate_common(self, generator):
        """
        Instantiates every template of the generator and writes each result to its output path.

        Output paths not set explicitly on the type mapping are inferred as
        '<context.output_dir>/<generator.infer_output_file_path(template)>'.

        :raises Exception: If the number of output paths differs from the number of templates.
        """

        template_names = generator.get_template_names()
        type_mapping = generator.type_mapping

        # Determine the output paths if not explicitly set
        if type_mapping.output_paths is None:
            type_mapping.output_paths = [
                type_mapping.context.output_dir + '/' + generator.infer_output_file_path(template_name)
                for template_name in template_names
            ]

        if len(type_mapping.output_paths) != len(template_names):
            raise Exception("Number of output files don't match the specified output paths")

        # Regenerate the file for each template
        for template_name, output_path in zip(template_names, type_mapping.output_paths):

            # Recreate the file from the template
            print('Creating ' + output_path + ' from template: ' + template_name)
            lines = self.instantiate_template(generator, template_name, output_path)

            # Call the function that updates the relevant parts of the file
            lines = generator.update(template_name, lines)

            # Write it back out
            common.file_utils.write_file_lines(output_path, lines)

    def generate_enumeration(self, type_mapping, type_info):
        """Generates the binding files for one enumeration."""

        print("Processing enum: " + type_mapping.native_type_name + " --> " + type_mapping.target_type_name + "...")

        generator = self.create_enum_generator(type_mapping, type_info)
        self.generate_common(generator)

        print("Done processing enum: " + type_mapping.native_type_name)

    def generate_struct(self, type_mapping, type_info):
        """Generates the binding files for one struct (requires create_struct_generator from the subclass)."""

        print("Processing struct: " + type_mapping.native_type_name + " --> " + type_mapping.target_type_name + "...")

        generator = self.create_struct_generator(type_mapping, type_info)
        self.generate_common(generator)

        print("Done processing struct: " + type_mapping.native_type_name)

    def generate_interface(self, type_mapping, type_info):
        """Generates the binding files for one interface (requires create_interface_generator from the subclass)."""

        print("Processing interface: " + type_mapping.native_type_name + " --> " + type_mapping.target_type_name + "...")

        generator = self.create_interface_generator(type_mapping, type_info)
        self.generate_common(generator)

        print("Done processing interface: " + type_mapping.native_type_name)

    def generate_class(self, type_mapping, type_info):
        """Generates the binding files for one class (requires create_class_generator from the subclass)."""

        print("Processing class: " + type_mapping.native_type_name + " --> " + type_mapping.target_type_name + "...")

        generator = self.create_class_generator(type_mapping, type_info)
        self.generate_common(generator)

        print("Done processing class: " + type_mapping.native_type_name)

    def generate_interface_adapter(self, type_mapping, type_info):
        """Generates the binding files for one interface adapter (requires create_interface_adapter_generator)."""

        print("Processing interface adapter: " + type_mapping.native_type_name + "...")

        generator = self.create_interface_adapter_generator(type_mapping, type_info)
        self.generate_common(generator)

        print("Done processing interface adapter: " + type_mapping.native_type_name)

    def simplify_template_type_info(self, type_info, update_lookup):
        """
        Reduces a template instantiation to its meaningful arguments.

        E.g. '::std::vector<::UserId, ::std::allocator<::UserId> >' becomes '::std::vector<::UserId>'.
        Arguments are typedef-resolved, nested template arguments are trimmed off, and every
        argument is put in the global namespace. type_info.type_name and
        type_info.template_info.template_args are updated in place. When update_lookup is True
        the simplified name is also registered in self.extracted_type_info.all.

        :raises Exception: For templates not listed in the hardcoded table below.
        """
        template_info = type_info.template_info
        if template_info is None:
            return

        # This is hardcoded logic for certain templates: number of leading arguments that matter
        useful_params_map = {
            '::std::vector': 1,
            '::std::map': 2,
            '::std::unordered_map': 2,
            '::std::set': 1,
            '::std::unordered_set': 1,
            '::std::shared_ptr': 1,
            '::std::unique_ptr': 1
        }

        if template_info.template_name not in useful_params_map:
            raise Exception('Unhandled template: ' + type_info.type_name)

        num_useful_params = useful_params_map[template_info.template_name]

        # TODO: We don't handle nested templates well here.  They'd better be string arguments.

        # Resolve and simplify parameters
        simplified_args = []
        for arg in template_info.template_args[0:num_useful_params]:
            cleaned_up = self.resolve_typedefs(arg)

            # Trim off nested template stuff
            lindex = cleaned_up.find('<')
            rindex = cleaned_up.rfind('>')
            if 0 < lindex < rindex:
                cleaned_up = cleaned_up[:lindex]

            if not cleaned_up.startswith('::'):
                cleaned_up = '::' + cleaned_up

            simplified_args.append(cleaned_up)

        template_info.template_args = simplified_args

        new_name = template_info.template_name + '<' + ', '.join(simplified_args) + '>'
        type_info.type_name = new_name

        # Update the global type lookup table to add the simplified type name
        if update_lookup:
            self.extracted_type_info.all[new_name] = type_info

    def simplify_template_types(self):
        """Simplifies template types of extracted classes and of their members, return types and arguments."""

        print('Simplifying template types...')

        # Simplify template fields and arguments
        for klass in self.extracted_type_info.classes:
            self.simplify_template_type_info(klass, update_lookup=True)
            for member in klass.members:
                self.simplify_template_type_info(member, update_lookup=False)
            for method in klass.methods:
                if method.return_type:
                    self.simplify_template_type_info(method.return_type, update_lookup=False)
                for arg in method.arguments:
                    self.simplify_template_type_info(arg, update_lookup=False)

        print('Done simplifying template types.')

    def infer_default_native_namespace(self):
        """ Sets default_native_namespace on every context that does not already have one. """

        for context in self.config.contexts:
            if not context.default_native_namespace:
                context.default_native_namespace = self.get_default_native_namespace()

    def infer_target_type_names(self):

        """ Determines the target language type name for types that do not have it explicitly set via configuration. """

        # Only referenced by the currently disabled inference rules below
        def infer_short_name(native_type_name):
            index = native_type_name.rfind('::')
            if index >= 0:
                name = native_type_name[index+2:]
            else:
                name = native_type_name

            if name.startswith('TTV_'):
                name = name[len('TTV_'):]

            return name

        for context in self.config.contexts:

            # Enumerations
            for x in context.enumerations:
                if not x.target_type_name:
                    x.target_type_name = self.infer_target_type_name(x.native_type_name)

            ## Constants
            #for x in context.constants:
            #    if not x.target_type_name: x.target_type_name = context.default_binding_namespace + self.config.default_binding_namespace_delimiter + infer_short_name(x.native_type_name)

            ## Class proxies
            #for x in context.class_proxies:
            #    if not x.target_type_name: x.target_type_name = context.default_binding_namespace + self.config.default_binding_namespace_delimiter + infer_short_name(x.native_type_name)

            ## Simple classes
            #for x in context.simple_classes:
            #    if not x.target_type_name: x.target_type_name = context.default_binding_namespace + self.config.default_binding_namespace_delimiter + infer_short_name(x.native_type_name)

            ## Interface proxies
            #for x in context.interface_proxies:
            #    if not x.target_type_name: x.target_type_name = context.default_binding_namespace + self.config.default_binding_namespace_delimiter + infer_short_name(x.native_type_name)

    def process_config(self):

        """
        This is the main entrypoint for the binding generation.
        - Loads typemaps
        - Setup mappings between native and target language types
        - Extract native type information from C++ source files
        - Generate the target language source files and any native code needed to back them

        Expects the subclass to provide get_type_map_path(); it is not declared on this class.
        """

        # Load the primitives type map for the target language
        self.load_typemap(self.get_type_map_path())

        # Make sure the native namespace is set
        self.infer_default_native_namespace()

        # Infer missing target type names
        self.infer_target_type_names()

        # Print the configuration
        self.config.print_config()

        # Gather up the names of all the native types we want to extract (without duplicates)
        params = common.source_tools.CastXmlConfig()

        def add_unique(names, name):
            if name not in names:
                names.append(name)

        for context in self.config.contexts:

            for x in context.enumerations:
                add_unique(params.enumeration_names, x.native_type_name)

            for x in context.simple_classes:
                add_unique(params.class_names, x.native_type_name)

            for x in context.class_proxies:
                add_unique(params.class_names, x.native_type_name)

            for x in context.interface_proxies:
                add_unique(params.class_names, x.native_type_name)

            for x in context.namespaces:
                add_unique(params.namespace_names, x)

        # Write a dummy source file that includes all the header files we care about
        temp_file_path = os.path.realpath(os.path.dirname(__file__) + '/all.cpp')
        header_files = self.config.get_all_header_files()
        params.source_files = [temp_file_path]
        params.include_paths = self.config.get_all_include_paths()
        lines = ['#include "' + path + '"' for path in header_files]
        common.file_utils.write_file_lines(temp_file_path, lines)

        # Run the preprocessor and extract type info. The temporary file is deleted even if
        # extraction fails, so a stale all.cpp is never left next to this module.
        try:
            self.extracted_type_info = common.source_tools.run_castxml(params)
            #self.extracted_type_info.print_all()
        finally:
            os.remove(temp_file_path)

        # Simplify template fields and arguments
        self.simplify_template_types()

        # Now that we have extracted types from native source we merge those mappings into the typemap
        for context in self.config.contexts:
            for enum in context.enumerations: self.type_map[enum.native_type_name] = enum.target_type_name
            #for struct in context.structs: self.type_map[struct.native_type_name] = struct.target_type_name
            #for klass in context.classes: self.type_map[klass.native_type_name] = klass.target_type_name
            #for union in context.unions: self.type_map[union.native_type_name] = union.target_type_name
            #for interface in context.interfaces: self.type_map[interface.native_type_name] = interface.target_type_name

        pprint(self.type_map)

        ## Interface adapters use the native interface as a base class
        #for context in self.config.contexts:
        #    for adapter_info in context.interface_adapters:
        #        if adapter_info.native_type_name not in self.type_map:
        #            raise Exception("Couldn't find in type map: " + adapter_info.native_type_name)
        #        base_type_mapping = self.config.find_type_mapping(adapter_info.native_type_name) # ConfigTypeMapping
        #        adapter_info.target_base_class_mappings.append(base_type_mapping)

        # Generate types that have been marked for generation
        for context in self.config.contexts:

            # Only generate if marked for generation
            if not context.generate:
                continue

            # Generate enumerations
            for enum in context.enumerations:

                # Find the type info
                if enum.native_type_name not in self.extracted_type_info.all:
                    raise Exception("Type info missing: " + enum.native_type_name)
                type_info = self.extracted_type_info.all[enum.native_type_name]

                self.generate_enumeration(enum, type_info)

            ## Generate structs
            #for struct in context.structs:

            #    # Find the type info
            #    if not struct.native_type_name in self.extracted_type_info.all:
            #        raise Exception("Type info missing: " + struct.native_type_name)
            #    type_info = self.extracted_type_info.all[struct.native_type_name]

            #    self.generate_struct(struct, type_info)

            ## Generate classes
            #for klass in context.classes:

            #    # Find the type info
            #    if not klass.native_type_name in self.extracted_type_info.all:
            #        raise Exception("Type info missing: " + klass.native_type_name)
            #    type_info = self.extracted_type_info.all[klass.native_type_name]

            #    self.generate_class(klass, type_info)

            ## Generate interfaces
            #for interface in context.interfaces:

            #    # Find the type info
            #    if not interface.native_type_name in self.extracted_type_info.all:
            #        raise Exception("Type info missing: " + interface.native_type_name)
            #    type_info = self.extracted_type_info.all[interface.native_type_name]

            #    self.generate_interface(interface, type_info)

            ## Generate interface adapters
            #for adapter_info in context.interface_adapters:

            #    # Find the type info
            #    if not adapter_info.native_type_name in self.extracted_type_info.all:
            #        raise Exception("Type info missing: " + adapter_info.native_type_name)
            #    type_info = self.extracted_type_info.all[adapter_info.native_type_name]

            #    self.generate_interface_adapter(adapter_info, type_info)
