import os
import abc
import shutil
import argparse
import re
import xml.etree.ElementTree
from pprint import pprint

import source_tools
import file_utils


class ConfigTypeMapping:

    """Describes a single type mapping from native to the target language."""

    def __init__(self):
        self.native_type_name = None # The fully qualified native type name from which to generate the code.
        self.target_type_name = None # The fully qualified target language type name.
        self.target_base_class_mappings = [] # Optional ConfigTypeMapping base classes which will be injected during code generation.
        self.output_paths = None # Array, the full paths to the output source files.
        self.context = None # The context the mapping belongs to.  Instance of ConfigMappingContext.
        self.preserve = False # Preserve the existing source file
        self.is_abstract = False # Whether or not the type should be abstract in the target language
        self.custom = None # An arbitrary dict set specific to the target language

    def parse(self, xml_entry, variables_chain):
        """Extracts info from a typemap XML <Entry> element.

        Reads the optional ``NativeType`` and ``TargetType`` attributes.
        ``variables_chain`` is accepted so subclasses can share the signature;
        it is not used here.
        """
        attrib = xml_entry.attrib

        if 'NativeType' in attrib:
            self.set_native_type_name(attrib['NativeType'])

        if 'TargetType' in attrib:
            self.target_type_name = attrib['TargetType']

        # TODO: Do we need custom data in entries?
        # TODO: Do we ever need to specify explicit output paths?
        #       If not explicitly set then it will be inferred.

    def set_native_type_name(self, native_type_name):
        """Sets the fully qualified native type name."""
        self.native_type_name = native_type_name

    def generates_native_code(self):
        """Whether the generated code is meant to be compiled into the SDK."""
        return False

    def generates_binding_code(self):
        """Whether the generated code is meant to be compiled in the target binding language."""
        return True

    def get_custom_value(self, path, default_value=None):
        """Searches the custom data for the given dot separated key.

        Returns ``default_value`` when there is no custom data, when a key
        along the path is missing, or when an intermediate value is not a
        dict (for example ``'a.b'`` where ``custom['a']`` is a string).
        """
        if not self.custom:
            return default_value

        cur = self.custom
        for token in path.split('.'):
            # Only descend into dicts: `in` on a str is a substring test, and
            # indexing a str/int with a key would raise TypeError.
            if not isinstance(cur, dict) or token not in cur:
                return default_value
            cur = cur[token]

        return cur

    def load_custom_value_as_file_path(self, path, default_value=None):
        """Assumes that the value at the key is a string that specifies a file path to load and returns the file lines as a list.

        Returns ``default_value`` when the key is absent or empty.
        Raises Exception when the key names a file that does not exist.
        """
        result = default_value
        file_path = self.get_custom_value(path, None)
        if file_path:
            if not os.path.isfile(file_path):
                raise Exception("Could not find snippet file for key '" + path + "': " + file_path)
            result = file_utils.read_file_lines(file_path)
        return result

    def get_mapping_text(self):
        """Returns a human readable 'native --> target' description."""
        return str(self.native_type_name) + ' --> ' + str(self.target_type_name)


class SimpleClassConfigTypeMapping(ConfigTypeMapping):
    """Simple classes are simple types with fields only and are marshalled directly into the target language.

    Adds no state of its own; construction is inherited from ConfigTypeMapping.
    """


class ClassProxyConfigTypeMapping(ConfigTypeMapping):
    """
    Class proxies are backed by a native instance.
    They will have public methods that are accessible from the target language.
    """
    def __init__(self):
        ConfigTypeMapping.__init__(self)
        self.create_native_in_constructor = True # Create the native instance to wrap in the constructor.

    def parse(self, xml_entry, variables_chain):
        """Extracts info from a typemap XML <Entry> element.

        In addition to the base attributes, reads the optional
        ``create_native_in_constructor`` attribute. Any value other than
        ``"true"`` (case-insensitive) sets the flag to False.
        """
        ConfigTypeMapping.parse(self, xml_entry, variables_chain)

        value = xml_entry.attrib.get('create_native_in_constructor')
        if value is not None:
            # XML attribute values are strings; mirror the "Export" == "true" convention.
            self.create_native_in_constructor = value.strip().lower() == 'true'


class EnumerationConfigTypeMapping(ConfigTypeMapping):
    """Type mapping for a native enumeration (loaded from <Enumerations> groups).

    Adds no state of its own; construction is inherited from ConfigTypeMapping.
    """


class ConstantConfigTypeMapping(ConfigTypeMapping):
    """Type mapping for native constants (loaded from <Constants> groups).

    Adds no state of its own; construction is inherited from ConfigTypeMapping.
    """


class InterfaceProxyConfigTypeMapping(ConfigTypeMapping):

    """
    Declares that a native interface needs a proxy adapter generated.
    """

    def __init__(self):
        ConfigTypeMapping.__init__(self)
        # Optional interface adapter for this proxy; only ever initialised to None in this module.
        self.interface_adapter = None

    def generates_native_code(self):
        """Whether the generated code is meant to be compiled into the SDK."""
        return True

    def generates_binding_code(self):
        """Whether the generated code is meant to be compiled in the target binding language."""
        return False

class LambdaConfigTypeMapping(ConfigTypeMapping):
    """Type mapping presumably used for native lambda/callback types (not loaded by load_config_file).

    Adds no state of its own; construction is inherited from ConfigTypeMapping.
    """



class ClassConfiguration(ConfigTypeMapping):
    """Per-class settings for a native type, such as the methods to export."""

    def __init__(self, native_type_name):
        ConfigTypeMapping.__init__(self)
        self.set_native_type_name(native_type_name)
        self.export_methods = [] # Methods to export, printed one per line by print_configuration.

    def print_configuration(self, indent=''):
        """Prints the configuration to stdout, nested under ``indent``."""
        print(indent + "ClassConfiguration:")
        field_indent = indent + '  '

        print(field_indent + "native_type_name:      " + self.native_type_name)
        if not self.export_methods:
            print(field_indent + "Methods:           <None>")
            return

        print(field_indent + "Methods:")
        for method_name in self.export_methods:
            print("                       " + method_name)


class ConfigMappingContext:
    """The data loaded from a single config file: variables, inputs and type mappings."""

    def __init__(self):
        self.name = None # String
        self.settings = None # The owning ConfigSettings
        self.config_xml_path = None # The config file that provided this info.
        self.output_dir = None # The root directory in which to output generated source files.
        self.default_binding_namespace = None # The default namespace for outputted binding files.
        self.default_native_namespace = None # The default namespace for outputted native files.
        self.variables = {}
        self.header_files = []
        self.include_paths = []
        self.enumerations = []
        self.constants = []
        self.simple_classes = []
        self.class_proxies = []
        self.interface_proxies = []
        self.namespaces = []
        self.class_configurations = {} # native type name -> config
        self.interface_adapters = [] # InterfaceAdapter
        self.generate = True # Whether or not to generate the type

    def add_enumeration(self, mapping):
        """Registers an enumeration mapping and points it back at this context."""
        mapping.context = self
        self.enumerations.append(mapping)

    def add_constant(self, mapping):
        """Registers a constant mapping and points it back at this context."""
        mapping.context = self
        self.constants.append(mapping)

    def add_simple_class(self, mapping):
        """Registers a simple class mapping and points it back at this context."""
        mapping.context = self
        self.simple_classes.append(mapping)

    def add_class_proxy(self, klass):
        """Registers a class proxy mapping and points it back at this context."""
        # The parameter keeps its historical name ``klass`` for keyword callers.
        klass.context = self
        self.class_proxies.append(klass)

    def add_interface_proxy(self, mapping):
        """Registers an interface proxy mapping and points it back at this context."""
        mapping.context = self
        self.interface_proxies.append(mapping)

    def add_class_configuration(self, config):
        """Registers a ClassConfiguration keyed by its native type name (replacing any previous one)."""
        self.class_configurations[config.native_type_name] = config

    def add_namespace(self, ns):
        """Adds a namespace unless it is already present."""
        if ns not in self.namespaces:
            self.namespaces.append(ns)

    def add_header_file(self, path):
        """Adds a header file path unless it is already present."""
        if path not in self.header_files:
            self.header_files.append(path)

    def add_include_path(self, path):
        """Adds an include path unless it is already present."""
        if path not in self.include_paths:
            self.include_paths.append(path)

    def get_include_path_relative_header(self, header_path):
        """Returns ``header_path`` relative to the first include path that contains it.

        The include paths of every context in the owning settings are searched
        in order. A match requires a path separator right after the include
        path, so ``/inc`` does not match ``/include/x.h``. A trailing separator
        on the include path is ignored. Returns None when no include path
        contains the header.
        """
        for include_path in self.settings.get_all_include_paths():
            if not include_path:
                continue
            root = include_path.rstrip('/\\')
            root_len = len(root)
            if (header_path.startswith(root)
                    and len(header_path) > root_len
                    and header_path[root_len] in '/\\'):
                return header_path[root_len + 1:]
        return None


    def print_context(self, indent=''):
        """Prints a human readable dump of this context to stdout."""

        print(indent + "Context:")
        indent = indent + '  '

        print(indent + "Name:                  " + str(self.name))
        print(indent + "Config path:           " + str(self.config_xml_path))
        print(indent + "Output dir:            " + str(self.output_dir))
        print(indent + "Default binding namespace: " + str(self.default_binding_namespace))
        print(indent + "Default native namespace: " + str(self.default_native_namespace))
        print(indent + "Generate:              " + str(self.generate))

        if len(self.variables) > 0:
            print(indent + "Variables:")
            for x in self.variables:
                print(indent + "                       " + x + ": " + str(self.variables[x]))
        else:
            print(indent + "Variables:             <None>")

        if len(self.include_paths) > 0:
            print(indent + "Include paths:")
            for x in self.include_paths:
                print("                       " + x)
        else:
            print(indent + "Include paths:         <None>")

        if len(self.header_files) > 0:
            print(indent + "Input files:")
            for x in self.header_files:
                print("                       " + x)
        else:
            print(indent + "Input files:           <None>")

        if len(self.namespaces) > 0:
            print(indent + "Namespaces:")
            for x in self.namespaces:
                print("                       " + x)
        else:
            print(indent + "Namespaces:            <None>")

        if len(self.simple_classes) > 0:
            print(indent + "Simple classes:")
            for x in self.simple_classes:
                print("                       " + x.get_mapping_text())
        else:
            print(indent + "Simple classes:        <None>")

        if len(self.class_proxies) > 0:
            print(indent + "Class proxies:")
            for x in self.class_proxies:
                print("                       " + x.get_mapping_text())
        else:
            print(indent + "Class proxies:               <None>")

        if len(self.enumerations) > 0:
            print(indent + "Enumerations:")
            for x in self.enumerations:
                print("                       " + x.get_mapping_text())
        else:
            print(indent + "Enumerations:                 <None>")

        if len(self.interface_proxies) > 0:
            print(indent + "Interface proxies:")
            for x in self.interface_proxies:
                print("                       " + x.get_mapping_text())
        else:
            print(indent + "Interface proxies:            <None>")

        #if len(self.interface_adapters) > 0:
        #    print(indent + "Interface Adapters:")
        #    for x in self.interface_adapters:
        #        print("                       " + x.native_type_name)
        #else:
        #    print(indent + "Interface Adapters:            <None>")

        #if len(self.class_configurations) > 0:
        #    print(indent + "Class configurations:")
        #    for x in self.class_configurations:
        #        self.class_configurations[x].print_configuration(indent + '  ')
        #else:
        #    print(indent + "Class configurations:  <None>")

        print('')


class ConfigSettings:
    """Global settings shared by all loaded config contexts."""

    def __init__(self, language):
        self.language = language
        self.indent = '    '
        self.variables = {} # The exported variables
        self.contexts = [] # ConfigMappingContext
        self.default_binding_namespace_delimiter = None # The namespace delimiter.
        self.precompiled_header_path = None # An optional pch header to place at the top of generated native source files.

    def add_context(self, context):
        """Adds ``context`` and sets its ``settings`` to self.

        Raises Exception when a context with the same name was already added.
        """
        if self.find_context(context.name) is not None:
            raise Exception('Context already added: ' + str(context.name))
        context.settings = self
        self.contexts.append(context)

    def find_context(self, name):
        """Returns the first context with the given name, or None."""
        # next() over a generator works on Python 2 and 3 (len(filter(...)) fails on 3).
        return next((context for context in self.contexts if context.name == name), None)

    def _find_mapping(self, list_attr, native_type_name):
        """Returns the first mapping in any context's ``list_attr`` list with the given native type name, or None."""
        for context in self.contexts:
            for mapping in getattr(context, list_attr):
                if mapping.native_type_name == native_type_name:
                    return mapping
        return None

    def get_all_header_files(self):
        """Returns the header files of every context, in context order."""
        return [x for context in self.contexts for x in context.header_files]

    def get_all_include_paths(self):
        """Returns the include paths of every context, in context order."""
        return [x for context in self.contexts for x in context.include_paths]

    def find_simple_class_mapping(self, native_type_name):
        """Returns the simple class mapping for the native type, or None."""
        return self._find_mapping('simple_classes', native_type_name)

    def find_class_proxy_mapping(self, native_type_name):
        """Returns the class proxy mapping for the native type, or None."""
        return self._find_mapping('class_proxies', native_type_name)

    def find_enumeration_mapping(self, native_type_name):
        """Returns the enumeration mapping for the native type, or None."""
        return self._find_mapping('enumerations', native_type_name)

    def find_interface_proxy_mapping(self, native_type_name):
        """Returns the interface proxy mapping for the native type, or None."""
        return self._find_mapping('interface_proxies', native_type_name)

    def find_class_configuration(self, native_type_name):
        """Returns the ClassConfiguration for the native type from the first context that has one, or None."""
        for context in self.contexts:
            if native_type_name in context.class_configurations:
                return context.class_configurations[native_type_name]
        return None

    def find_type_mapping(self, native_type_name):
        """Returns the first matching mapping among simple classes, class proxies, enumerations and interface proxies (in that order), or None."""
        finders = (
            self.find_simple_class_mapping,
            self.find_class_proxy_mapping,
            self.find_enumeration_mapping,
            self.find_interface_proxy_mapping,
        )
        for finder in finders:
            type_mapping = finder(native_type_name)
            if type_mapping:
                return type_mapping
        return None

    def print_config(self, indent=''):
        """Prints a human readable dump of the settings and all contexts to stdout."""

        print("")
        print(indent + "Config Settings:")
        indent = indent + '  '
        print(indent + "Language:              " + str(self.language))
        print(indent + "Default binding namespace delim: " + str(self.default_binding_namespace_delimiter))
        print(indent + "PCH header path:       " + str(self.precompiled_header_path))

        if len(self.variables) > 0:
            print(indent + "Variables:")
            for x in self.variables:
                print(indent + "                       " + x + ": " + str(self.variables[x]))
        else:
            print(indent + "Variables:             <None>")

        for context in self.contexts:
            context.print_context(indent)

        print("")


def substitute_variables(str, variables_chain, max_depth=10):
    """Replaces all ``$(key)`` references in the given string with the variable values.

    ``variables_chain`` is a list of dicts searched in order; when a key exists
    in several dicts the earliest one wins. Key matching is case-insensitive,
    so ``$(MyVar)`` and ``$(myvar)`` both resolve key ``MyVar``. Values may
    themselves contain references; substitution is repeated until nothing
    changes, for at most ``max_depth`` passes.

    Raises:
        Exception: if substitution has not settled after ``max_depth`` passes
            (e.g. a self-referencing variable), or if a ``$(`` reference
            remains unresolved.

    Returns the substituted text. On Python 2, unicode results are encoded
    to UTF-8 byte strings; otherwise the text is returned as-is (a ``str``
    on Python 3).
    """
    # The parameter keeps its historical name for keyword callers; work on a
    # separate local so the code does not read like it uses the builtin.
    text = str
    remaining = max_depth
    changed = True
    while changed:
        changed = False
        if remaining == 0:
            raise Exception("Variable substitution too deep: " + text)
        remaining -= 1

        for chain in variables_chain:
            for key, value in chain.items():
                pattern = re.compile(re.escape('$(' + key + ')'), re.IGNORECASE)
                # A function replacement inserts the value literally, so
                # backslashes in e.g. Windows paths are not treated as escapes.
                new_text = pattern.sub(lambda _match: value, text)
                if new_text != text:
                    changed = True
                    text = new_text

    if text.find('$(') >= 0:
        raise Exception("Unresolved variable: " + text)

    if isinstance(text, bytes):
        return text
    if bytes is type(''):
        # Python 2: keep returning UTF-8 encoded byte strings for unicode input.
        return text.encode('utf-8')
    return text


def substitute_variables_deep(obj, variables_chain):
    """Recursively applies substitute_variables to strings nested in lists and dicts.

    Lists and dicts are updated in place and returned. Strings are replaced by
    their substituted value. None and any other value are returned unchanged.
    """
    if obj is None:
        return None

    text_types = (str, unicode) if str is bytes else (str, bytes)
    if isinstance(obj, text_types):
        return substitute_variables(obj, variables_chain)

    if isinstance(obj, list):
        for index, item in enumerate(obj):
            obj[index] = substitute_variables_deep(item, variables_chain)
    elif isinstance(obj, dict):
        for key in obj:
            obj[key] = substitute_variables_deep(obj[key], variables_chain)

    return obj


def parse_group(xml_group, config_type, variables_chain):
    """Builds one ``config_type`` instance per <Entry> child of ``xml_group``.

    Each instance is created with no arguments and populated via its
    ``parse(xml_entry, variables_chain)`` method. Instances are returned in
    document order.
    """
    def _build(xml_entry):
        # Construct then populate one mapping from a single <Entry> element.
        mapping = config_type()
        mapping.parse(xml_entry, variables_chain)
        return mapping

    return [_build(xml_entry) for xml_entry in xml_group.findall("Entry")]


def _get_first_attrib(xml_element, names):
    """Returns the value of the first attribute in ``names`` present on ``xml_element``, or None."""
    for name in names:
        if name in xml_element.attrib:
            return xml_element.attrib[name]
    return None


def load_config_file(config_xml_path, config, generate):

    """
    Parses the XML config file used to describe the types to generate and the source files to examine.

    A new ConfigMappingContext, named by the root ``Name`` attribute, is
    registered on ``config`` and filled from the file.

    :param config_xml_path: Path of the config XML file.
    :param config: The ConfigSettings that receives the context and any exported variables.
    :param generate: Whether types from this file are generated (False = mappings only).
    :returns: ``config``.
    :raises Exception: if the context name is already registered, if the
        namespace delimiter differs from one loaded earlier, or if a variable
        reference cannot be resolved.
    """

    print("Reading config file: " + config_xml_path + "...")

    doc = xml.etree.ElementTree.parse(config_xml_path)
    xml_root = doc.getroot()

    context = ConfigMappingContext()
    context.name = xml_root.attrib["Name"]
    config.add_context(context)

    context.config_xml_path = config_xml_path
    context.generate = generate

    # The chain of variables to examine when searching for a value
    variables_chain = [context.variables, config.variables]

    # A variable for the directory of the config file
    context.variables['__dir__'] = file_utils.fix_path( os.path.dirname(config_xml_path) )

    # Read variables; exported ones are shared with every later config file.
    for xml_Variables in xml_root.findall("Variables"):
        for xml_Variable in xml_Variables.findall("Variable"):
            key = xml_Variable.attrib["Key"]
            # An empty <Variable/> element has text None.
            val = substitute_variables( (xml_Variable.text or '').strip(), variables_chain )
            if xml_Variable.attrib.get("Export") == "true":
                config.variables[key] = val
            else:
                context.variables[key] = val

    # Both attribute spellings are accepted; all files must agree on the delimiter.
    delimiter = _get_first_attrib(xml_root, ("DefaultTargetNamespaceDelimeter", "DefaultBindingNamespaceDelimeter"))
    if delimiter is not None:
        delimiter = substitute_variables( delimiter, variables_chain )
        if config.default_binding_namespace_delimiter is None:
            config.default_binding_namespace_delimiter = delimiter
        elif config.default_binding_namespace_delimiter != delimiter:
            raise Exception("Default binding namespace delimiters don't match: " + config.default_binding_namespace_delimiter + ", " + delimiter)

    if "OutputDir" in xml_root.attrib:
        context.output_dir = file_utils.fix_path( substitute_variables( xml_root.attrib["OutputDir"], variables_chain ) )

    if "DefaultTargetNamespace" in xml_root.attrib:
        context.default_binding_namespace = substitute_variables( xml_root.attrib["DefaultTargetNamespace"], variables_chain )

    if "DefaultNativeNamespace" in xml_root.attrib:
        context.default_native_namespace = substitute_variables( xml_root.attrib["DefaultNativeNamespace"], variables_chain )

    for xml_HeaderFiles in xml_root.findall("HeaderFiles"):
        for path in file_utils.get_lines_from_text(xml_HeaderFiles.text, trim=True):
            context.add_header_file( file_utils.fix_path( substitute_variables(path, variables_chain) ) )

    for xml_IncludePaths in xml_root.findall("IncludePaths"):
        for path in file_utils.get_lines_from_text(xml_IncludePaths.text, trim=True):
            context.add_include_path( file_utils.fix_path( substitute_variables(path, variables_chain) ) )

    for xml_Namespaces in xml_root.findall("Namespaces"):
        for ns in file_utils.get_lines_from_text(xml_Namespaces.text, trim=True):
            context.add_namespace( substitute_variables(ns, variables_chain) )

    # (XML group tag, mapping class, registration method), processed in this order.
    type_groups = (
        ("Enumerations", EnumerationConfigTypeMapping, context.add_enumeration),
        ("Constants", ConstantConfigTypeMapping, context.add_constant),
        ("SimpleClasses", SimpleClassConfigTypeMapping, context.add_simple_class),
        ("ClassProxies", ClassProxyConfigTypeMapping, context.add_class_proxy),
        ("InterfaceProxies", InterfaceProxyConfigTypeMapping, context.add_interface_proxy),
    )
    for tag, mapping_type, register in type_groups:
        for xml_group in xml_root.findall(tag):
            for mapping in parse_group( xml_group, mapping_type, variables_chain ):
                register(mapping)

    # TODO: Class configurations (ClassConfiguration) are not yet read from the XML format.

    print("Done reading config file: " + config_xml_path)

    return config


def add_common_command_line_switches(parser):
    """Registers the command line switches shared by all generators on ``parser``."""

    switches = (
        ('--generate-config', dict(
            required=True,
            metavar='<config_file>',
            action='append',
            help='Specifies a config file to generate types for.',
        )),
        ('--mapping-config', dict(
            required=False,
            metavar='<config_file>',
            action='append',
            help='Specifies a config file that contains type mappings but not to generate types for.',
        )),
        ('--pch-path', dict(
            required=False,
            metavar='<path>',
            help='Specifies the precompiled header file to be placed at the top of generated source files.',
        )),
        ('--language', dict(
            required=False,
            metavar='<language>',
            help='Specifies the language to generate bindings for.',
        )),
    )

    for flag, options in switches:
        parser.add_argument(flag, **options)


def process_command_line(parser, command_line_args, config):
    """Registers the common switches on ``parser``, parses ``command_line_args``
    and loads every referenced config file into ``config``.

    Mapping-only configs (--mapping-config) are loaded before the configs whose
    types are generated (--generate-config). When given, --pch-path and
    --language are stored on ``config``.

    Returns the updated ConfigSettings.
    """

    add_common_command_line_switches(parser)

    args = parser.parse_args(command_line_args)

    # Load configs that provide type mappings but are not being generated
    for path in args.mapping_config or []:
        config = load_config_file(path, config, generate=False)

    # Load configs that will be used to generate types
    for path in args.generate_config:
        config = load_config_file(path, config, generate=True)

    # Retain the pch path
    if args.pch_path:
        config.precompiled_header_path = args.pch_path

    # Apply the requested language; the switch is otherwise accepted but ignored.
    if args.language:
        config.language = args.language

    return config


if __name__ == "__main__":
    # Manual smoke test: load one config (mapping only, no generation) for Java.
    # Alternative input: 'X:/build/tools/bindings/test/enumeration.xml'
    config = ConfigSettings('java')
    load_config_file('X:/build/bindings/core.bindings.config.xml', config, generate=False)
