package dbx_test

import (
	"errors"
	"testing"

	"code.justin.tv/devrel/dbx"

	sqlmock "github.com/DATA-DOG/go-sqlmock"
	"github.com/stretchr/testify/require"
)

// TestTransactions_Commit checks that two inserts run inside one transaction
// are committed together, and that RollbackUnlessComitted becomes a no-op once
// the transaction has been committed.
func TestTransactions_Commit(t *testing.T) {
	mock, dbxDB := NewTestDBX()
	mock.ExpectBegin()
	for i := 0; i < 2; i++ {
		mock.ExpectExec("INSERT INTO heroes").WillReturnResult(sqlmock.NewResult(1, 1))
	}
	mock.ExpectCommit()

	hero := &Hero{}
	ctx := dbx.MustBegin(ctxBck, dbxDB.DB)
	// Safety net in case a require below aborts the test before Commit.
	defer dbx.RollbackUnlessComitted(ctx, nil)

	for i := 0; i < 2; i++ {
		insertErr := dbxDB.InsertOne(ctx, "heroes", hero)
		require.NoError(t, insertErr)
	}

	commitErr := dbx.Commit(ctx)
	require.NoError(t, commitErr)

	// Already committed: these calls must not issue a ROLLBACK. The mock has
	// no Rollback expectation, so any ROLLBACK would be reported as unexpected.
	for i := 0; i < 2; i++ {
		dbx.RollbackUnlessComitted(ctx, nil)
	}
	requireExpectations(t, mock)
}

// TestTransactions_Rollback checks that RollbackUnlessComitted rolls back a
// transaction that was never committed.
func TestTransactions_Rollback(t *testing.T) {
	mock, dbxDB := NewTestDBX()
	mock.ExpectBegin()
	mock.ExpectExec("INSERT INTO heroes").WillReturnResult(sqlmock.NewResult(1, 1))
	mock.ExpectRollback()

	ctx := dbx.MustBegin(ctxBck, dbxDB.DB)
	require.NoError(t, dbxDB.InsertOne(ctx, "heroes", &Hero{}))

	// No Commit was issued, so this must emit a ROLLBACK.
	dbx.RollbackUnlessComitted(ctx, nil)
	requireExpectations(t, mock)
}

// TestTransactions_Rollback_WithErrHandler_NoError checks that when a rollback
// succeeds, the error handler passed to RollbackUnlessComitted is not invoked.
func TestTransactions_Rollback_WithErrHandler_NoError(t *testing.T) {
	mockSQL, db := NewTestDBX()
	mockSQL.ExpectBegin()
	mockSQL.ExpectExec("INSERT INTO heroes").WillReturnResult(sqlmock.NewResult(1, 1))
	mockSQL.ExpectRollback()

	h := &Hero{}
	txCtx := dbx.MustBegin(ctxBck, db.DB) // begin transaction
	err := db.InsertOne(txCtx, "heroes", h)
	require.NoError(t, err)

	// Record both whether the handler ran and what it received, so a failure
	// reports the actual rollback error instead of a bare "should be false".
	var (
		errHandlerCalled bool
		handlerErr       error
	)
	dbx.RollbackUnlessComitted(txCtx, func(err error) {
		errHandlerCalled = true
		handlerErr = err
	})
	require.False(t, errHandlerCalled, "rollback error handler called unexpectedly with: %v", handlerErr)
	requireExpectations(t, mockSQL)
}

// TestTransactions_Rollback_WithErrHandler_Error checks that an error returned
// by ROLLBACK is passed to the handler given to RollbackUnlessComitted.
func TestTransactions_Rollback_WithErrHandler_Error(t *testing.T) {
	const rollbackMsg = "Rollback Error Yo"

	mock, dbxDB := NewTestDBX()
	mock.ExpectBegin()
	mock.ExpectExec("INSERT INTO heroes").WillReturnResult(sqlmock.NewResult(1, 1))
	mock.ExpectRollback().WillReturnError(errors.New(rollbackMsg))

	ctx := dbx.MustBegin(ctxBck, dbxDB.DB)
	require.NoError(t, dbxDB.InsertOne(ctx, "heroes", &Hero{}))

	var got error
	onRollbackErr := func(err error) { got = err }
	dbx.RollbackUnlessComitted(ctx, onRollbackErr)

	// The handler must have received the mock's rollback failure.
	require.EqualError(t, got, rollbackMsg)
	requireExpectations(t, mock)
}

// TestTransactions_DeferRollback_WithErrHandler_Error checks the usual
// "defer RollbackUnlessComitted" pattern: when a function returns without
// committing, the deferred call rolls back and reports a rollback failure
// to the handler.
func TestTransactions_DeferRollback_WithErrHandler_Error(t *testing.T) {
	const rollbackMsg = "Rollback Error Yo"

	mock, dbxDB := NewTestDBX()
	mock.ExpectBegin()
	mock.ExpectExec("INSERT INTO heroes").WillReturnResult(sqlmock.NewResult(1, 1))
	mock.ExpectRollback().WillReturnError(errors.New(rollbackMsg))

	var got error
	onRollbackErr := func(err error) { got = err }

	// Run the transaction in its own function scope so the deferred call
	// fires before the assertions below.
	func() {
		ctx := dbx.MustBegin(ctxBck, dbxDB.DB)
		defer dbx.RollbackUnlessComitted(ctx, onRollbackErr)

		require.NoError(t, dbxDB.InsertOne(ctx, "heroes", &Hero{}))
		// Deliberately no Commit: the deferred call must roll back.
	}()

	require.EqualError(t, got, rollbackMsg)
	requireExpectations(t, mock)
}

// TestTransactions_CommitErrors checks the error paths of dbx.Commit:
// committing a context with no transaction, and committing a transaction that
// has already been rolled back.
func TestTransactions_CommitErrors(t *testing.T) {
	mockSQL, db := NewTestDBX()
	mockSQL.ExpectBegin()
	mockSQL.ExpectRollback()
	// No ExpectCommit: the only Commit below runs after the rollback and must
	// be rejected before it reaches the driver. database/sql returns ErrTxDone
	// for a finished Tx without calling the driver, so a registered Commit
	// expectation could never be satisfied.

	err := dbx.Commit(ctxBck)
	require.EqualError(t, err, "dbx Commit: missing transaction in context")

	txCtx := dbx.MustBegin(ctxBck, db.DB)  // begin transaction
	dbx.RollbackUnlessComitted(txCtx, nil) // rollback
	err = dbx.Commit(txCtx)                // commit after rollback
	require.EqualError(t, err, "dbx Commit: already commited or rolled back")

	// Like every other test here, confirm BEGIN/ROLLBACK actually ran and no
	// unexpected SQL was issued.
	requireExpectations(t, mockSQL)
}
