/*
Package twitchhttp enables quicker creation of production-ready HTTP clients and servers.
*/
package twitchhttp

import (
	"fmt"
	"io"
	"net/http"
	"net/url"
	"strings"
	"sync"

	"code.justin.tv/common/chitin"

	"golang.org/x/net/context"
)

// Client ensures that a Trace-compatible http.Client is used by default.
// If you want to roll your own RoundTripper, use SetHTTPClient.
//
// Create Clients with NewClient: the zero value has neither a base URL nor a
// client resolver, so NewRequest and Do cannot operate on it.
//
// NOTE(review): because sync.RWMutex is embedded, Lock/Unlock/RLock/RUnlock
// are also part of Client's exported method set.
type Client struct {
	// Guards resolveHTTPClient: SetHTTPClient replaces it under Lock and Do
	// reads it under RLock.
	sync.RWMutex
	// Base URL that NewRequest resolves request paths against; set by NewClient.
	host              *url.URL
	// Returns the *http.Client that Do should use for the given Context.
	resolveHTTPClient func(context.Context) (*http.Client, error)
}

// NewClient allocates and returns a new Client whose requests are built
// relative to host.
//
// host may be a bare "hostname[:port]", in which case an "http://" scheme is
// assumed, or a full URL beginning with "http://" or "https://" (matched
// case-insensitively). A blank host, an unparsable URL, or a URL without a
// hostname is rejected with an error.
//
// By default Do obtains its *http.Client from chitin.Client(ctx), so Do must
// be given a non-nil Context; use SetHTTPClient to change this.
func NewClient(host string) (*Client, error) {
	if host == "" {
		return nil, fmt.Errorf("host cannot be blank")
	}

	// Only skip adding a scheme when one is actually present. A bare "http"
	// prefix test would misclassify hosts such as "httpbin.org" and leave
	// them without a scheme, so url.Parse would not treat them as a host.
	lower := strings.ToLower(host)
	if !strings.HasPrefix(lower, "http://") && !strings.HasPrefix(lower, "https://") {
		host = "http://" + host
	}

	u, err := url.Parse(host)
	if err != nil {
		return nil, err
	}
	// Inputs such as "http://" parse cleanly but cannot produce a usable
	// request URL; fail here instead of on every request.
	if u.Host == "" {
		return nil, fmt.Errorf("host %q does not contain a hostname", host)
	}

	c := &Client{
		host: u,
		resolveHTTPClient: func(ctx context.Context) (*http.Client, error) {
			if ctx == nil {
				return nil, fmt.Errorf("chitin.Client needs a non-nil Context")
			}
			return chitin.Client(ctx), nil
		},
	}
	return c, nil
}

// SetHTTPClient replaces the way c obtains the *http.Client used by Do.
//
// httpClient must be one of:
//   - a non-nil *http.Client, which is then used for every request regardless
//     of the Context passed to Do; or
//   - a non-nil func(context.Context) (*http.Client, error), which Do calls
//     with its Context on every request so the client can depend on it.
//
// Any other value, including an untyped nil, a nil *http.Client, or a nil
// func, is rejected with an error and leaves the current configuration
// unchanged. The swap is done under the write lock, so it may be called
// concurrently with Do.
func (c *Client) SetHTTPClient(httpClient interface{}) error {
	var resolve func(context.Context) (*http.Client, error)

	// Validate before taking the lock; only the final assignment needs it.
	switch client := httpClient.(type) {
	case *http.Client:
		// A typed nil pointer still matches this case; accepting it would only
		// postpone a nil-pointer panic until the first call to Do.
		if client == nil {
			return fmt.Errorf("invalid input for SetHTTPClient: nil *http.Client")
		}
		resolve = func(_ context.Context) (*http.Client, error) {
			return client, nil
		}
	case func(context.Context) (*http.Client, error):
		if client == nil {
			return fmt.Errorf("invalid input for SetHTTPClient: nil resolver func")
		}
		resolve = client
	default:
		return fmt.Errorf("invalid input for SetHTTPClient: unsupported type %T", httpClient)
	}

	c.Lock()
	c.resolveHTTPClient = resolve
	c.Unlock()
	return nil
}

// Do executes req with the *http.Client resolved for ctx (for Trace support).
//
// The resolver configured by NewClient or SetHTTPClient is read under the
// read lock and invoked only after the lock is released. This way a slow
// resolver cannot stall SetHTTPClient, and a resolver that itself calls
// SetHTTPClient cannot deadlock. Errors returned by the resolver are passed
// through unchanged. ctx is passed only to the resolver; req is sent as-is.
func (c *Client) Do(ctx context.Context, req *http.Request) (*http.Response, error) {
	c.RLock()
	resolve := c.resolveHTTPClient
	c.RUnlock()

	// A zero-value Client (not built by NewClient) has no resolver; report it
	// instead of panicking on a nil func call.
	if resolve == nil {
		return nil, fmt.Errorf("no http client configured; create the Client with NewClient")
	}

	httpClient, err := resolve(ctx)
	if err != nil {
		return nil, err
	}
	// A custom resolver returning (nil, nil) would otherwise panic inside
	// net/http.
	if httpClient == nil {
		return nil, fmt.Errorf("http client resolver returned a nil *http.Client")
	}

	return httpClient.Do(req)
}

// NewRequest builds an *http.Request for method whose URL is path resolved
// against the Client's base host URL using url.URL.ResolveReference. Relative
// paths are merged with the base; a path that is itself an absolute URL
// replaces it. Errors from parsing path or from http.NewRequest are returned
// unchanged.
func (c *Client) NewRequest(method string, path string, body io.Reader) (*http.Request, error) {
	ref, parseErr := url.Parse(path)
	if parseErr != nil {
		return nil, parseErr
	}

	target := c.host.ResolveReference(ref)
	return http.NewRequest(method, target.String(), body)
}
