package internal

import (
	"fmt"
	"go/ast"
	"io"
	"strings"
)

// interfaceData describes one interface type declaration as AST nodes,
// together with the imports of the file that declares it.
type interfaceData struct {
	TypeSpec    *ast.TypeSpec // type <Name> interface { ... }
	Name        string
	Methods     *ast.FieldList // methods (*ast.FuncType) and other embedded interfaces (*ast.Ident)
	FileImports []importData   // imports on the file where this interface is defined
}

// importData is a single import: its path and an optional local name.
type importData struct {
	Nickname string
	Path     string
}

// String renders the import as "<Nickname> <Path>" when a usable nickname is
// set and as just Path otherwise. Both the empty string and the literal
// "<nil>" count as "no nickname".
func (d importData) String() string {
	switch d.Nickname {
	case "", "<nil>":
		return d.Path
	}
	return d.Nickname + " " + d.Path
}

// structData is the model from which Write generates a wrapper type named
// "<Type>Errx".
//
//   - Imports are written into the generated import block, after the fixed ones.
//   - Package, when non-empty, qualifies the wrapped type as Package.Type and is
//     passed along when method signatures are rendered.
//   - Type is the name of the wrapped type; it is also used as the field name
//     holding the wrapped value.
//   - Functions lists the methods to generate; fillFieldList appends to it.
type structData struct {
	Imports   []importData
	Package   string
	Type      string
	Functions []functionData

	WithTimings bool // include a TimingFunction to track timings
}

// fillFieldList appends one functionData to sd.Functions for every method
// found in fieldlist, which is the method list of an interface declaration.
//
// Fields whose type is a bare identifier are treated as embedded interfaces:
// if the name is present in interfaces, that interface's methods are added
// recursively; if it is not, the field is skipped. Fields that are neither an
// identifier nor a function type (for example a package-qualified embedded
// interface) are skipped as well.
//
// A nil fieldlist is a no-op. An error is returned when a function-typed field
// carries no name, and errors from recursive calls are propagated unchanged.
func (sd *structData) fillFieldList(fieldlist *ast.FieldList, interfaces map[string]*interfaceData) error {
	if fieldlist == nil {
		return nil
	}

	for i, field := range fieldlist.List {
		switch typ := field.Type.(type) {
		case *ast.Ident:
			embedded, ok := interfaces[typ.Name]
			if !ok || embedded == nil {
				continue // unknown identifier: nothing to expand
			}
			// Embedded interface: pull in its methods, surfacing any failure.
			if err := sd.fillFieldList(embedded.Methods, interfaces); err != nil {
				return err
			}
		case *ast.FuncType:
			if len(field.Names) == 0 {
				// Report malformed input instead of panicking on Names[0].
				return fmt.Errorf("interface method at field index %d has no name", i)
			}
			fd := functionDataFromFieldType(sd.Package, typ)
			fd.Name = field.Names[0].Name
			sd.Functions = append(sd.Functions, fd)
		}
	}

	return nil
}

// lineWriter writes formatted lines to w while tracking the total number of
// bytes written and the first error encountered. Once an error has been
// recorded, further writes do nothing, so callers can emit many lines and
// check the outcome once through Output.
type lineWriter struct {
	total int
	err   error
	w     io.Writer
}

// Output reports the number of bytes written so far and the recorded error, if any.
func (w *lineWriter) Output() (int, error) {
	return w.total, w.err
}

// write formats its arguments like fmt.Sprintf and writes the result followed
// by a newline. It is skipped entirely when an earlier write has failed.
func (w *lineWriter) write(format string, args ...interface{}) {
	if w.err == nil {
		line := fmt.Sprintf(format, args...)
		written, err := fmt.Fprintln(w.w, line)
		w.total += written
		w.err = err
	}
}

// Write emits the generated Go source for the wrapper of sd.Type to w. It
// returns the number of bytes written and the first write error, if any.
//
// The output consists of, in order:
//   - an import block: "fmt", the errx package and sd.Imports, plus "context"
//     and "time" when sd.WithTimings is set;
//   - the wrapper type "<Type>Errx", holding the wrapped value in a field named
//     after sd.Type and, with timings, a TimingFunc callback field;
//   - one method per entry of sd.Functions that forwards to the wrapped value
//     and passes every result of type error through errx.NewWithSkip(err, 1).
//
// With timings enabled each method measures the wrapped call and reports it to
// TimingFunc when that field is non-nil. The context passed to TimingFunc is
// the first parameter when its type is "context.Context", otherwise
// context.Background(); the error passed is the last error result, or nil when
// the method returns no error.
//
// No package clause is written and lines are emitted without indentation.
func (sd *structData) Write(w io.Writer) (int, error) {
	lw := &lineWriter{w: w}

	// Import block.
	// NOTE(review): "fmt" is always imported although nothing emitted here
	// references it — presumably a later formatting/goimports pass prunes it;
	// confirm against the caller.
	lw.write(`import (`)
	lw.write(`"fmt"`)
	if sd.WithTimings {
		lw.write(`"context"`)
		lw.write(`"time"`)
	}
	lw.write(`"code.justin.tv/chat/golibs/errx"`)
	for _, id := range sd.Imports {
		// Dynamic text goes in as an argument, never as the format string, so
		// a '%' in it is written verbatim.
		lw.write("%s", id.String())
	}
	lw.write(`)`)

	typeName := fmt.Sprintf("%sErrx", sd.Type)

	// Wrapper type declaration.
	lw.write("type %s struct {", typeName)
	if sd.Package != "" {
		lw.write("%s %s.%s", sd.Type, sd.Package, sd.Type)
	} else {
		lw.write("%s %s", sd.Type, sd.Type)
	}
	if sd.WithTimings {
		lw.write("TimingFunc func(ctx context.Context, d time.Duration, method string, err error)")
	}
	lw.write("}")

	// One wrapper method per wrapped function.
	for _, f := range sd.Functions {
		// inParams: declared parameters, e.g. "ctx context.Context", "opts ...int".
		// funcCallParams: arguments forwarded to the wrapped method, e.g. "ctx", "opts...".
		inParams := make([]string, 0, len(f.In))
		funcCallParams := make([]string, 0, len(f.In))
		for _, in := range f.In {
			inParams = append(inParams, in.Name+" "+in.Type)

			arg := in.Name
			if strings.HasPrefix(in.Type, "...") {
				arg += "..." // forward variadic arguments
			}
			funcCallParams = append(funcCallParams, arg)
		}

		// outResultTypes: declared result types, e.g. "int", "error".
		// funcCallResults: variables receiving the wrapped call's results, e.g. "r0", "err".
		// outResults: returned expressions, e.g. "r0", "errx.NewWithSkip(err, 1)".
		outResultTypes := make([]string, 0, len(f.Out))
		funcCallResults := make([]string, 0, len(f.Out))
		outResults := make([]string, 0, len(f.Out))
		errVar := "nil" // error handed to TimingFunc; stays "nil" without an error result
		for _, out := range f.Out {
			outResultTypes = append(outResultTypes, out.Type)
			funcCallResults = append(funcCallResults, out.Name)

			result := out.Name
			if out.Type == "error" {
				errVar = out.Name
				result = fmt.Sprintf("errx.NewWithSkip(%s, 1)", errVar)
			}
			outResults = append(outResults, result)
		}

		// Method declaration.
		lw.write("// %s wraps %s errors with errx", f.Name, sd.Type)
		lw.write("func (e *%s) %s(%s) (%s) {", typeName, f.Name, csv(inParams), csv(outResultTypes))

		if sd.WithTimings {
			lw.write(`start := time.Now()`)
		}

		// Call the wrapped method, capturing its results when it has any.
		funcCall := fmt.Sprintf("e.%s.%s(%s)", sd.Type, f.Name, csv(funcCallParams))
		hasReturnResults := len(f.Out) > 0
		if hasReturnResults {
			lw.write("%s := %s", csv(funcCallResults), funcCall)
		} else {
			lw.write("%s", funcCall)
		}

		if sd.WithTimings {
			ctxVar := "context.Background()"
			if len(f.In) > 0 && f.In[0].Type == "context.Context" {
				ctxVar = f.In[0].Name
			}
			lw.write(`if e.TimingFunc != nil {`)
			lw.write(`  e.TimingFunc(%s, time.Since(start), "%s", %s)`, ctxVar, f.Name, errVar)
			lw.write(`}`)
		}

		if hasReturnResults {
			lw.write("return %s", csv(outResults))
		}

		lw.write("}")
	}

	return lw.Output()
}

// csv joins list into one string with ", " between elements, e.g. "a, b, c".
// A nil or empty list yields "".
func csv(list []string) string {
	const separator = ", "
	return strings.Join(list, separator)
}

// functionDataFromFieldType builds the parameter and result lists of a
// functionData from the signature funcType; Name is not set here.
//
// Parameters are expanded with the fallback prefix "in" and results with the
// fallback prefix "r". When the last result has type "error" it is named
// "err" so generated code can refer to it by a conventional name. In and Out
// stay nil when the signature has no parameters or no results respectively.
func functionDataFromFieldType(pkg string, funcType *ast.FuncType) functionData {
	var fn functionData

	if params := funcType.Params; params != nil && len(params.List) > 0 {
		fn.In = expandFields(pkg, "in", params.List)
	}

	results := funcType.Results
	if results == nil || len(results.List) == 0 {
		return fn
	}

	fn.Out = expandFields(pkg, "r", results.List)
	if idx := len(fn.Out) - 1; fn.Out[idx].Type == "error" {
		fn.Out[idx].Name = "err" // convenience name
	}

	return fn
}

// expandFields flattens a []*ast.Field into one parameterData per parameter.
// An ast.Field may carry no names (an unnamed parameter or result) or several
// names that share one type. For example, for the declaration
// func(key string, num1, num2 int) (int, error):
//
//	.Params.List  == [{Names: ["key"], Type: "string"}, {Names: ["num1", "num2"], Type: "int"}]
//	.Results.List == [{Names: nil, Type: "int"}, {Names: nil, Type: "error"}]
//
// the expansion is:
//
//	expandFields(pkg, "in", .Params.List) == [{Name: "key", Type: "string"}, {Name: "num1", Type: "int"}, {Name: "num2", Type: "int"}]
//	expandFields(pkg, "r", .Results.List) == [{Name: "r0", Type: "int"}, {Name: "r1", Type: "error"}]
//
// A field without names is named prefix followed by the field's index in
// list. Every entry also keeps the field's original type expression in Expr.
func expandFields(pkg string, prefix string, list []*ast.Field) []parameterData {
	expanded := make([]parameterData, 0, len(list))
	for idx, field := range list {
		param := parameterData{Type: exprToType(pkg, field.Type), Expr: field.Type}

		if len(field.Names) == 0 {
			// Unnamed: synthesize "in0", "r1", etc.
			param.Name = fmt.Sprintf("%s%d", prefix, idx)
			expanded = append(expanded, param)
			continue
		}

		// Named: one entry per identifier, all sharing the field's type.
		for _, ident := range field.Names {
			param.Name = ident.Name
			expanded = append(expanded, param)
		}
	}
	return expanded
}

// exprToType renders the type expression expr as Go source text.
//
// When pkg is non-empty, exported identifiers that appear without a qualifier
// are prefixed with it (e.g. "Item" becomes "pkg.Item"); unexported
// identifiers such as "int" or "error" are left alone.
//
// Slices, fixed-size arrays, variadic parameters, maps, channels, pointers,
// parenthesised types and function types are rendered recursively. Any
// interface type is rendered as "interface{}" regardless of its method set.
//
// An unsupported expression prints a warning to standard output and yields "".
// A selector whose left-hand side is not a plain identifier panics.
func exprToType(pkg string, expr ast.Expr) string {
	switch v := expr.(type) {
	case *ast.SelectorExpr:
		// if v.X doesn't cast to ast.Ident, it's a case we haven't seen before and panic
		return v.X.(*ast.Ident).Name + "." + v.Sel.Name
	case *ast.Ident:
		// NOTE: this is the only place pkg is used and it bloats everything :/
		if v.IsExported() && pkg != "" {
			return pkg + "." + v.Name
		}
		return v.Name
	case *ast.ArrayType:
		// Len is nil for a slice; an array keeps its length expression so that
		// [4]byte is not silently turned into []byte.
		length := ""
		switch l := v.Len.(type) {
		case nil:
		case *ast.BasicLit:
			length = l.Value
		default:
			length = exprToType(pkg, l) // e.g. a named constant
		}
		return "[" + length + "]" + exprToType(pkg, v.Elt)
	case *ast.Ellipsis:
		return "..." + exprToType(pkg, v.Elt)
	case *ast.MapType:
		return "map[" + exprToType(pkg, v.Key) + "]" + exprToType(pkg, v.Value)
	case *ast.ChanType:
		var prefix string
		switch v.Dir {
		case ast.SEND:
			prefix = "chan<- "
		case ast.RECV:
			prefix = "<-chan "
		case ast.SEND | ast.RECV:
			prefix = "chan "
		}

		return prefix + exprToType(pkg, v.Value)
	case *ast.StarExpr:
		return "*" + exprToType(pkg, v.X)
	case *ast.ParenExpr:
		return "(" + exprToType(pkg, v.X) + ")"
	case *ast.FuncType:
		fd := functionDataFromFieldType(pkg, v)
		return fd.String()
	case *ast.InterfaceType:
		return "interface{}"
	default:
		// %T names the unhandled node type; %s would dump the node's fields.
		fmt.Printf("warning unhandled type: %T\n", v)
	}

	return ""
}

// functionData describes one function signature: its name plus the expanded
// lists of parameters and results.
type functionData struct {
	Name string

	In  []parameterData // Input parameters
	Out []parameterData // Variables for returned results on the wrapped method
}

// String renders the signature as a Go function type of the form
// "func(<params>) (<results>)". Entries are joined with "," and no space, and
// the result list is always parenthesised, even when empty. Name is not part
// of the output.
func (fd functionData) String() string {
	render := func(list []parameterData) string {
		var parts []string
		for _, p := range list {
			parts = append(parts, p.String())
		}
		return "(" + strings.Join(parts, ",") + ")"
	}

	return "func" + render(fd.In) + " " + render(fd.Out)
}

// parameterData is one expanded parameter or result of a function signature.
type parameterData struct {
	Name string // identifier, e.g. "ctx", or a generated name such as "in0"
	Type string // type rendered as Go source text
	Zero string // not populated by the code visible in this file

	// Expr is the original type expression, kept for later use.
	Expr ast.Expr
}

// String renders the parameter as "<Name> <Type>", or as just Type when it
// has no name.
func (pd parameterData) String() string {
	if pd.Name != "" {
		return pd.Name + " " + pd.Type
	}
	return pd.Type
}
