package userstore

import (
	"context"
	"database/sql"
	"errors"
	"fmt"

	"code.justin.tv/vapour/authproxy/pkg"
	gotwitchapi "github.com/dankeroni/gotwitch"
	"github.com/go-sql-driver/mysql"
	"golang.org/x/oauth2"
)

// UserStore caches users in memory and persists their OAuth2 tokens to the
// Users table of a SQL database. Create it with New.
type UserStore struct {
	// sqlClient is the database the Users table lives in.
	sqlClient    *sql.DB
	// oauth2Config is used to exchange authorization codes and to build
	// token sources for stored tokens.
	oauth2Config *oauth2.Config
	// api is used to validate access tokens and learn the owning user.
	api          *gotwitchapi.TwitchAPI

	// users caches loaded/added users keyed by user ID.
	// NOTE(review): not guarded by a mutex; concurrent Add/Get/Load would
	// race on this map — confirm that callers serialize access.
	users map[string]pkg.User
}

// User is an authenticated user together with the OAuth2 token source used
// to obtain its access token.
type User struct {
	// login and id are taken from the token validation response.
	login string
	id    string

	// tokenSource yields the user's current token; it is built from
	// UserStore.oauth2Config.
	tokenSource oauth2.TokenSource
}

// SQLUser holds one row of the Users table as scanned by Load.
type SQLUser struct {
	UserID       string
	AccessToken  string
	RefreshToken string
	// Expiry is nullable in the scan target; Load rejects rows where it is NULL.
	Expiry       mysql.NullTime
}

// Token returns the user's current OAuth2 token as produced by the user's
// token source, together with any error that source reports.
func (u *User) Token() (*oauth2.Token, error) {
	source := u.tokenSource
	return source.Token()
}

// ID returns the user ID recorded when the user was validated.
func (u *User) ID() string { return u.id }

// Login returns the login name recorded when the user was validated.
func (u *User) Login() string { return u.login }

// SaveToDB upserts the user's current token into the Users table. It inserts
// a row for the user ID, or, on a duplicate key (presumably user_id — the
// schema is not visible here), overwrites access_token, refresh_token and
// expiry.
//
// The token comes from u.Token, so an error from the token source is
// returned before the database is touched.
func (u *User) SaveToDB(sqlClient *sql.DB) error {
	tok, err := u.Token()
	if err != nil {
		return err
	}

	const upsert = `INSERT INTO Users
	(user_id, access_token, refresh_token, expiry)
	VALUES (?, ?, ?, ?)
	ON DUPLICATE KEY UPDATE
	access_token=?, refresh_token=?, expiry=?;`

	// The UPDATE clause rebinds the token columns, so they appear twice.
	insertArgs := []interface{}{u.ID(), tok.AccessToken, tok.RefreshToken, tok.Expiry}
	updateArgs := []interface{}{tok.AccessToken, tok.RefreshToken, tok.Expiry}

	_, err = sqlClient.Exec(upsert, append(insertArgs, updateArgs...)...)
	return err
}

// New returns an empty UserStore. The store persists users through
// sqlClient, builds token sources from oauth2Config and validates tokens
// with api. Load fills the cache from the database.
func New(sqlClient *sql.DB, oauth2Config *oauth2.Config, api *gotwitchapi.TwitchAPI) *UserStore {
	store := new(UserStore)
	store.sqlClient = sqlClient
	store.oauth2Config = oauth2Config
	store.api = api
	store.users = map[string]pkg.User{}
	return store
}

// Load populates the in-memory cache from the Users table. For each stored
// row it obtains a usable token and checks with the API that the token still
// belongs to the stored user ID. It then writes the token back to the
// database and caches the user.
//
// Load stops at the first failure and returns its error. Users restored
// before the failure remain cached.
func (us *UserStore) Load() error {
	sqlUsers, err := us.loadSQLUsers()
	if err != nil {
		return err
	}

	for _, sqlUser := range sqlUsers {
		u, err := us.restoreUser(sqlUser)
		if err != nil {
			return err
		}
		us.users[u.ID()] = u
	}

	return nil
}

// loadSQLUsers reads every row of the Users table, rejecting rows whose
// expiry is NULL.
func (us *UserStore) loadSQLUsers() ([]SQLUser, error) {
	const queryF = `SELECT user_id, access_token, refresh_token, expiry FROM Users;`

	rows, err := us.sqlClient.Query(queryF)
	if err != nil {
		return nil, fmt.Errorf("querying users: %w", err)
	}
	// rows only closes itself once Next reports false. Without this, an early
	// return below would leak the underlying connection.
	defer rows.Close()

	var sqlUsers []SQLUser
	for rows.Next() {
		var sqlUser SQLUser
		if err := rows.Scan(&sqlUser.UserID, &sqlUser.AccessToken, &sqlUser.RefreshToken, &sqlUser.Expiry); err != nil {
			return nil, fmt.Errorf("scanning user row: %w", err)
		}
		if !sqlUser.Expiry.Valid {
			return nil, errors.New("Invalid expiry time")
		}
		sqlUsers = append(sqlUsers, sqlUser)
	}
	// Next also returns false on an iteration error, not just at end of data.
	if err := rows.Err(); err != nil {
		return nil, fmt.Errorf("iterating user rows: %w", err)
	}

	return sqlUsers, nil
}

// restoreUser turns one stored row into a live, validated User and writes
// its current token back to the database.
func (us *UserStore) restoreUser(sqlUser SQLUser) (*User, error) {
	rawToken := &oauth2.Token{
		AccessToken:  sqlUser.AccessToken,
		TokenType:    "bearer",
		RefreshToken: sqlUser.RefreshToken,
		Expiry:       sqlUser.Expiry.Time,
	}
	tokenSource := us.oauth2Config.TokenSource(context.Background(), rawToken)

	// Per golang.org/x/oauth2, this refreshes the token if the stored one has
	// expired.
	token, err := tokenSource.Token()
	if err != nil {
		return nil, fmt.Errorf("getting token for user %s: %w", sqlUser.UserID, err)
	}

	validated, _, err := us.api.ValidateOAuthTokenSimple(token.AccessToken)
	if err != nil {
		return nil, fmt.Errorf("validating token for user %s: %w", sqlUser.UserID, err)
	}
	if validated.UserID != sqlUser.UserID {
		return nil, errors.New("mismatching user IDs wtf")
	}

	fmt.Printf("Verified user %s(%s)\n", validated.Login, validated.UserID)

	u := &User{
		login:       validated.Login,
		id:          validated.UserID,
		tokenSource: tokenSource,
	}

	// Persist the token so that any refresh done above survives a restart.
	// This error used to be silently discarded.
	if err := u.SaveToDB(us.sqlClient); err != nil {
		return nil, fmt.Errorf("saving user %s: %w", u.ID(), err)
	}

	return u, nil
}

// Get looks up a cached user by ID. found reports whether the user is in the
// cache; when it is false, user is the zero pkg.User.
func (us *UserStore) Get(userID string) (user pkg.User, found bool) {
	user, found = us.users[userID]
	return user, found
}

// Add exchanges an OAuth2 authorization code for a token and asks the API
// who owns that token. It then caches the resulting user and persists it to
// the Users table.
//
// The user is cached before it is persisted. If the database write fails,
// the error is returned but the user stays retrievable via Get.
func (us *UserStore) Add(code string) error {
	ctx := context.Background()

	token, err := us.oauth2Config.Exchange(ctx, code)
	if err != nil {
		return err
	}

	validated, _, err := us.api.ValidateOAuthTokenSimple(token.AccessToken)
	if err != nil {
		return err
	}

	newUser := &User{
		login:       validated.Login,
		id:          validated.UserID,
		tokenSource: us.oauth2Config.TokenSource(ctx, token),
	}
	us.users[validated.UserID] = newUser

	return newUser.SaveToDB(us.sqlClient)
}
