package service_common

import (
	"math/rand"
	"net/http"
	"sync"
	"sync/atomic"
	"time"

	"code.justin.tv/feeds/ctxlog"
	"code.justin.tv/feeds/ctxlog/ctxlogaws"
	"code.justin.tv/feeds/distconf"
	"code.justin.tv/feeds/errors"
	"github.com/aws/aws-sdk-go/aws"
	"github.com/aws/aws-sdk-go/aws/client"
	"github.com/aws/aws-sdk-go/aws/credentials"
	"github.com/aws/aws-sdk-go/aws/credentials/stscreds"
	"github.com/aws/aws-sdk-go/aws/endpoints"
	"github.com/aws/aws-sdk-go/aws/request"
	"github.com/aws/aws-sdk-go/aws/session"
	"github.com/aws/aws-sdk-go/service/sts"
	"golang.org/x/net/context"
)

// CreateAWSSession returns an AWS session that uses the SDK's default HTTP
// client, together with the extra configs read from dconf: the region and
// regional STS endpoint ("aws_region", default "us-west-2") and, when
// "aws.assume_role" is set, AssumeRole credentials for that role ARN.
//
// It is exactly CreateAWSSessionWithHttpClient with a nil *http.Client. A nil
// aws.Config.HTTPClient is skipped when configs are merged, so the session
// keeps the SDK default client. Delegating keeps the two constructors from
// drifting apart.
func CreateAWSSession(dconf *distconf.Distconf) (*session.Session, []*aws.Config) {
	return CreateAWSSessionWithHttpClient(dconf, nil)
}

// CreateAWSSessionWithHttpClient returns an AWS session whose requests are sent
// through httpClient, together with extra configs read from dconf:
//   - the region and regional STS endpoint ("aws_region", default "us-west-2"),
//     omitted when the configured region is empty;
//   - AssumeRole credentials for the "aws.assume_role" role ARN, when set.
//
// The returned configs are presumably meant to be passed alongside the session
// when constructing service clients (e.g. svc.New(sess, configs...)). Verify
// against callers.
//
// A nil httpClient leaves the SDK's default HTTP client in place.
func CreateAWSSessionWithHttpClient(dconf *distconf.Distconf, httpClient *http.Client) (*session.Session, []*aws.Config) {
	clientProvider := session.Must(session.NewSession(&aws.Config{
		HTTPClient: httpClient,
	}))
	return clientProvider, distconfAWSConfigs(dconf, clientProvider)
}

// distconfAWSConfigs builds the per-client configs described on
// CreateAWSSessionWithHttpClient. The STS client used for AssumeRole is derived
// from provider, so it shares provider's HTTP client and base credentials.
func distconfAWSConfigs(dconf *distconf.Distconf, provider client.ConfigProvider) []*aws.Config {
	var configs []*aws.Config

	if region := dconf.Str("aws_region", "us-west-2").Get(); region != "" {
		configs = append(configs, &aws.Config{
			Region:              aws.String(region),
			STSRegionalEndpoint: endpoints.RegionalSTSEndpoint,
		})
	}

	if roleARN := dconf.Str("aws.assume_role", "").Get(); roleARN != "" {
		// Previously a separate session.New(...) (deprecated) was created here.
		// That bypassed the caller's HTTP client for every AssumeRole call.
		// Deriving the STS client from the shared session fixes that. Only the
		// region config collected so far is applied, so STS authenticates with
		// the base credentials rather than the role being assumed.
		stsClient := sts.New(provider, configs...)
		configs = append(configs, &aws.Config{
			Credentials: credentials.NewCredentials(&stscreds.AssumeRoleProvider{
				ExpiryWindow: 10 * time.Second,
				RoleARN:      roleARN,
				Client:       stsClient,
			}),
		})
	}

	return configs
}

// ContextSend attaches ctx to req and sends it with ctxlogaws.DoAWSSend, passing
// logger through. If the send fails, the error is wrapped with the message
// "unable to issue aws request". On success it returns nil.
//
// No throttle-specific handling happens here. Despite what an earlier comment
// claimed, throttled errors are NOT converted to a dedicated Throttled type.
// They are wrapped like any other error.
func ContextSend(ctx context.Context, req *request.Request, logger ctxlog.Logger) error {
	req.SetContext(ctx)
	if err := ctxlogaws.DoAWSSend(req, logger); err != nil {
		return errors.Wrap(err, "unable to issue aws request")
	}
	return nil
}

// PerOperationTimeout bounds AWS requests with a timeout chosen by operation
// name.
//
// TimeoutPerOperation maps an AWS API operation name (e.g. "GetItem") to a
// callback that returns the timeout for that operation. The callback is invoked
// each time a matching request is built, so the timeout may change at runtime.
// A request whose context already has an earlier deadline keeps that deadline.
//
// NOTE(review): the map is read without locking when requests are built.
// Populate it (e.g. via SetTimeouts) before the client issues requests, since
// concurrent map read/write is a data race in Go.
type PerOperationTimeout struct {
	TimeoutPerOperation map[string]func() time.Duration
}

// DynamoDBReadOperations lists the read-only DynamoDB API operation names, for
// example to hand to PerOperationTimeout.SetTimeouts. A new slice is built on
// every call, so callers may modify the result without affecting others.
func DynamoDBReadOperations() []string {
	ops := make([]string, 0, 5)
	ops = append(ops,
		"BatchGetItem",
		"DescribeTable",
		"GetItem",
		"Query",
		"Scan",
	)
	return ops
}

// SetTimeouts registers the same timeout callback for every operation name in
// operations. Existing entries for those names are overwritten. The
// TimeoutPerOperation map is allocated on first use if it is nil.
func (p *PerOperationTimeout) SetTimeouts(operations []string, timeout func() time.Duration) {
	table := p.TimeoutPerOperation
	if table == nil {
		table = make(map[string]func() time.Duration, len(operations))
		p.TimeoutPerOperation = table
	}
	for i := range operations {
		table[operations[i]] = timeout
	}
}

// AddToClient appends p's throttle handler, named "set_operation_context", to
// c's Build handler list. Requests created by c then have their per-operation
// timeout applied while they are built.
func (p *PerOperationTimeout) AddToClient(c *client.Client) {
	setTimeout := request.NamedHandler{Name: "set_operation_context", Fn: p.throttle}
	c.Handlers.Build.PushBackNamed(setTimeout)
}

// throttle is a Build handler. When req's operation has a configured timeout,
// it replaces req's context with a child context bounded by that timeout. It
// leaves the context untouched when the existing deadline already falls before
// now+timeout. The child context's cancel func is registered on req's Complete
// handler list as "end_operation_context".
func (p *PerOperationTimeout) throttle(req *request.Request) {
	if req == nil || req.Operation == nil {
		return
	}
	getTimeout, ok := p.TimeoutPerOperation[req.Operation.Name]
	if !ok {
		return
	}
	timeout := getTimeout()

	parent := req.Context()
	if parent == nil {
		parent = context.Background()
	}
	if deadline, hasDeadline := parent.Deadline(); hasDeadline && deadline.Before(time.Now().Add(timeout)) {
		// A sooner deadline already bounds this request; keep it.
		return
	}

	ctx, cancel := context.WithTimeout(parent, timeout)
	req.SetContext(ctx)
	req.Handlers.Complete.PushBackNamed(request.NamedHandler{
		Name: "end_operation_context",
		Fn:   func(*request.Request) { cancel() },
	})
}

// ThrottledBackoff lets multiple goroutines share a single throttling delay.
// The delay is stored in timeToSleep and only read or modified with sync/atomic
// operations, so its methods may be called concurrently.
//
// Rand must be initialized before ThrottledSleep is used, for example with
// Rand: *rand.New(rand.NewSource(seed)). The zero rand.Rand has no source and
// panics when asked for a number.
type ThrottledBackoff struct {
	// Multiplier scales the delay on every SignalThrottled call.
	// Should be >= 1.0.
	Multiplier float64
	// SleepBackoff is added to the delay (before scaling) by SignalThrottled
	// and subtracted from it by DecreaseBackoff. Should be > 1.
	SleepBackoff time.Duration
	// MaxSleepTime caps the delay. If it is zero, SignalThrottled pins the
	// delay at zero, so ThrottledSleep never sleeps.
	MaxSleepTime time.Duration
	// Rand draws the jittered sleep duration in ThrottledSleep.
	Rand rand.Rand

	// mu only for rand.Rand
	mu          sync.Mutex
	timeToSleep int64
}

// DecreaseBackoff signals that throttling may no longer be needed. It lowers
// the shared delay by SleepBackoff, never below zero, and does nothing when
// the delay is already zero. A compare-and-swap loop keeps the update atomic
// with respect to concurrent callers.
func (t *ThrottledBackoff) DecreaseBackoff() {
	for {
		currentSleep := atomic.LoadInt64(&t.timeToSleep)
		if currentSleep == 0 {
			return
		}
		newSleep := currentSleep - t.SleepBackoff.Nanoseconds()
		if newSleep < 0 {
			newSleep = 0
		}
		if atomic.CompareAndSwapInt64(&t.timeToSleep, currentSleep, newSleep) {
			return
		}
	}
}

// int63n returns t.Rand.Int63n(max) while holding mu, because rand.Rand is not
// safe for concurrent use. max must be > 0.
func (t *ThrottledBackoff) int63n(max int64) int64 {
	t.mu.Lock()
	ret := t.Rand.Int63n(max)
	t.mu.Unlock()
	return ret
}

// ThrottledSleep blocks for a random duration in [0, current delay) while
// throttling is in effect. It returns ctx.Err() if ctx is done first and nil
// otherwise. When there is no positive delay it returns nil immediately.
func (t *ThrottledBackoff) ThrottledSleep(ctx context.Context) error {
	delay := atomic.LoadInt64(&t.timeToSleep)
	// Use <= 0 rather than == 0, because rand.Int63n panics on a non-positive
	// bound and a misconfigured backoff could store a negative delay.
	if delay <= 0 {
		return nil
	}
	timer := time.NewTimer(time.Duration(t.int63n(delay)))
	// Release the timer promptly when ctx wins the race.
	defer timer.Stop()
	select {
	case <-ctx.Done():
		return ctx.Err()
	case <-timer.C:
		return nil
	}
}

// SignalThrottled records a throttling event. It grows the shared delay to
// (delay + SleepBackoff) * Multiplier + 1 nanoseconds, capped at MaxSleepTime.
// Like DecreaseBackoff, it uses a compare-and-swap loop, so concurrent
// SignalThrottled and DecreaseBackoff calls do not overwrite each other.
func (t *ThrottledBackoff) SignalThrottled() {
	maxSleep := t.MaxSleepTime.Nanoseconds()
	for {
		current := atomic.LoadInt64(&t.timeToSleep)
		next := t.grownSleep(current, maxSleep)
		if atomic.CompareAndSwapInt64(&t.timeToSleep, current, next) {
			return
		}
	}
}

// grownSleep computes the delay that follows current after a throttling event,
// capped at maxSleep. The sum and product are done in float64 and compared
// against the cap before converting back. Converting an out-of-range float64
// to int64 is implementation-dependent, and previously yielded negative delays
// when MaxSleepTime was very large.
func (t *ThrottledBackoff) grownSleep(current, maxSleep int64) int64 {
	scaled := (float64(current) + float64(t.SleepBackoff.Nanoseconds())) * t.Multiplier
	if scaled >= float64(maxSleep) {
		return maxSleep
	}
	next := int64(scaled) + 1
	if next > maxSleep {
		return maxSleep
	}
	return next
}
