package xray

import (
	"bytes"
	"context"
	"database/sql"
	"database/sql/driver"
	"fmt"
	"io"
	"net/url"
	"reflect"
	"strings"
	"time"
)

// undetectableXRaySQLInfo is the placeholder SQL stores for the database
// type, version, user and name when none of its detectors succeed.
const (
	undetectableXRaySQLInfo = "Unknown"
)

// dsnParseOk reports whether url.Parse succeeded (err == nil) and the parsed
// result carries at least one hint that the DSN really is URL-shaped: a
// scheme, userinfo, a raw query, or an "@" that ended up in the path.
func dsnParseOk(u *url.URL, err error) bool {
	if err != nil {
		return false
	}
	switch {
	case u.Scheme != "", u.User != nil, u.RawQuery != "":
		return true
	default:
		return strings.Contains(u.Path, "@")
	}
}

// parseDsnForDB records a password-free description of dsn on db.
//
// DSNs that look like URLs (see dsnParseOk) are stored in db.url with the
// password removed from the userinfo and with any "password" query parameter
// deleted. Anything else is passed through stripPasswords and stored in
// db.connectionString. The only error returned is one from stripPasswords.
func parseDsnForDB(db *DB, dsn string) error {
	// url.Parse only recognises a host after "//", so scheme-less DSNs such as
	// `user:pass@host:port/db` get the prefix added here and removed again below.
	hasSlashes := strings.Contains(dsn, "//")
	urlDsn := dsn
	if !hasSlashes {
		urlDsn = "//" + dsn
	}
	// Here we're trying to detect things like `host:port/database` as a URL, which is pretty hard
	// So we just assume that if it's got a scheme, a user, or a query that it's probably a URL
	if u, err := url.Parse(urlDsn); dsnParseOk(u, err) {
		// The user/pass@host:port/db form parses as host "user" with the
		// password and real host shoved into the path. Rebuild the URL with that
		// "/" escaped so "user/pass" is parsed as userinfo instead.
		if strings.Contains(u.Path, "@") {
			// EscapedPath keeps existing percent-escapes intact (the decoded
			// Path may not re-parse). The scheme needs its ":" back, and the
			// query must be carried over explicitly.
			rebuilt := "//" + u.Host + "%2F" + strings.TrimPrefix(u.EscapedPath(), "/")
			if u.Scheme != "" {
				rebuilt = u.Scheme + ":" + rebuilt
			}
			if u.RawQuery != "" {
				rebuilt += "?" + u.RawQuery
			}
			if fixed, perr := url.Parse(rebuilt); perr == nil {
				u = fixed
			} else {
				// The path could not be reinterpreted as userinfo. Drop it
				// rather than record the password embedded in it.
				u.Path, u.RawPath = "", ""
			}
		}

		// Strip password from user:password pair in address
		if u.User != nil {
			uname := u.User.Username()

			// Some drivers use "user/pass@host:port" instead of "user:pass@host:port"
			// So we must manually attempt to chop off a potential password.
			// But we can skip this if we already found the password.
			if _, ok := u.User.Password(); !ok {
				uname = strings.Split(uname, "/")[0]
			}

			u.User = url.User(uname)
		}

		// Strip password from query parameters
		q := u.Query()
		q.Del("password")
		u.RawQuery = q.Encode()

		db.url = u.String()
		if !hasSlashes {
			// Undo the "//" added above. String() omits it when there is no
			// host or user (e.g. "/db?x=1"), so trim instead of slicing.
			db.url = strings.TrimPrefix(db.url, "//")
		}
		return nil
	}
	// We don't *think* it's a URL, so now we have to try our best to strip passwords from
	// some unknown DSL. We attempt to detect whether it's space-delimited or semicolon-delimited
	// then remove any keys with the name "password" or "pwd". This won't catch everything, but
	// from surveying the current (Jan 2017) landscape of drivers it should catch most.
	var err error
	db.connectionString, err = stripPasswords(dsn)
	return err
}

// SQL opens a normalized and traced wrapper around an *sql.DB connection.
// It uses `sql.Open` internally and shares the same function signature.
// To ensure passwords are filtered, it is HIGHLY RECOMMENDED that your DSN
// follows the format: `<schema>://<user>:<password>@<host>:<port>/<database>`
//
// After opening, SQL records a password-stripped form of dsn. It then runs a
// detection query to learn the database type, version, user and name. Drivers
// "postgres" and "mysql" use their specific query; any other driver tries each
// known detector in turn. If every detector fails, those fields are set to
// "Unknown". Detection issues real queries against the database.
func (x *XRay) SQL(driver, dsn string) (*DB, error) {
	rawDB, err := sql.Open(driver, dsn)
	if err != nil {
		return nil, err
	}

	db := &DB{x: x, db: rawDB}
	if err := parseDsnForDB(db, dsn); err != nil {
		// The caller never receives rawDB, so close it here rather than leak
		// its pool; the close error is secondary to err.
		_ = rawDB.Close()
		return nil, err
	}

	// Detect database type and use that to populate attributes
	var detectors []func(*DB) error
	switch driver {
	case "postgres":
		detectors = append(detectors, postgresDetector)
	case "mysql":
		detectors = append(detectors, mysqlDetector)
	default:
		detectors = append(detectors, postgresDetector, mysqlDetector, mssqlDetector, oracleDetector)
	}
	detected := false
	for _, detect := range detectors {
		if detect(db) == nil {
			detected = true
			break
		}
	}
	if !detected {
		// A failed detector may have left partial values behind; overwrite them.
		db.databaseType = undetectableXRaySQLInfo
		db.databaseVersion = undetectableXRaySQLInfo
		db.user = undetectableXRaySQLInfo
		db.dbname = undetectableXRaySQLInfo
	}

	// There's no standard to get SQL driver version information
	// So we invent an interface by which drivers can provide us this data
	type versionedDriver interface {
		Version() string
	}

	d := db.db.Driver()
	if vd, ok := d.(versionedDriver); ok {
		db.driverVersion = vd.Version()
	} else {
		// Fall back to the import path of the driver's concrete type.
		t := reflect.TypeOf(d)
		for t.Kind() == reflect.Ptr {
			t = t.Elem()
		}
		db.driverVersion = t.PkgPath()
	}

	return db, nil
}

// DB copies the interface of sql.DB but adds XRay tracing.
// It must be created with xray.SQL
type DB struct {
	// db is the wrapped connection pool that every method delegates to.
	db *sql.DB
	// x is the XRay instance passed to SQL. NOTE(review): not read in the
	// code visible here; presumably used by tracing methods elsewhere.
	x  *XRay

	// Metadata gathered by SQL; populate copies all but dbname onto segments.
	// connectionString is the password-stripped DSN when it is not URL-shaped.
	connectionString string
	// url is the password-stripped DSN when it is URL-shaped.
	url              string
	// databaseType is set by the detectors ("Postgres", "MySQL", "MS SQL",
	// "Oracle") or "Unknown" when none succeed.
	databaseType     string
	databaseVersion  string
	// driverVersion comes from the driver's Version() method if it has one,
	// otherwise from the package path of the driver type.
	driverVersion    string
	user             string
	// dbname is detected but not copied onto segments by populate.
	dbname           string
}

// Close closes the wrapped *sql.DB and returns its error.
func (db *DB) Close() error {
	return db.db.Close()
}

// Driver returns the driver.Driver backing the wrapped *sql.DB.
func (db *DB) Driver() driver.Driver {
	return db.db.Driver()
}

// Stats returns the connection-pool statistics of the wrapped *sql.DB.
func (db *DB) Stats() sql.DBStats {
	return db.db.Stats()
}

// SetConnMaxLifetime forwards d to SetConnMaxLifetime on the wrapped *sql.DB.
func (db *DB) SetConnMaxLifetime(d time.Duration) {
	db.db.SetConnMaxLifetime(d)
}

// SetMaxIdleConns forwards n to SetMaxIdleConns on the wrapped *sql.DB.
func (db *DB) SetMaxIdleConns(n int) {
	db.db.SetMaxIdleConns(n)
}

// SetMaxOpenConns forwards n to SetMaxOpenConns on the wrapped *sql.DB.
func (db *DB) SetMaxOpenConns(n int) {
	db.db.SetMaxOpenConns(n)
}

// populate marks the segment found in ctx as a "remote" call and, while
// holding the segment's lock, copies this DB's connection metadata and the
// given query text into the segment's SQL data.
func (db *DB) populate(ctx context.Context, query string) {
	seg := getSegment(ctx)

	seg.Lock()
	defer seg.Unlock()

	seg.Namespace = "remote"

	// seg.sql() hands back the same SQL record on every call (the original
	// assigned through it field by field), so fetch it once under the lock.
	meta := seg.sql()
	meta.ConnectionString = db.connectionString
	meta.URL = db.url
	meta.DatabaseType = db.databaseType
	meta.DatabaseVersion = db.databaseVersion
	meta.DriverVersion = db.driverVersion
	meta.User = db.user
	meta.SanitizedQuery = query
}

// Tx copies the interface of sql.Tx but adds XRay tracing.
// It must be created with xray.DB.Begin
type Tx struct {
	// db is the tracing wrapper the transaction was begun on.
	db *DB
	// tx is the underlying transaction; Commit and Rollback delegate to it.
	tx *sql.Tx
}

// Commit commits the wrapped *sql.Tx and returns its error.
func (tx *Tx) Commit() error {
	return tx.tx.Commit()
}

// Rollback rolls back the wrapped *sql.Tx and returns its error.
func (tx *Tx) Rollback() error {
	return tx.tx.Rollback()
}

// Stmt copies the interface of sql.Stmt but adds XRay tracing.
// It must be created with xray.DB.Prepare or xray.Tx.Stmt
type Stmt struct {
	// db is the tracing wrapper whose metadata populate copies onto segments.
	db    *DB
	// stmt is the underlying prepared statement; Close delegates to it.
	stmt  *sql.Stmt
	// query holds SQL text for the statement. NOTE(review): set and read
	// outside this chunk; presumably the text given to Prepare — confirm.
	query string
}

// Close closes the wrapped *sql.Stmt and returns its error.
func (stmt *Stmt) Close() error {
	return stmt.stmt.Close()
}

// populate records the parent DB's connection metadata and query on the
// segment in ctx (via DB.populate), then tags the SQL data with
// Preparation "statement" under the segment's lock.
func (stmt *Stmt) populate(ctx context.Context, query string) {
	stmt.db.populate(ctx, query)

	seg := getSegment(ctx)
	seg.Lock()
	defer seg.Unlock()
	seg.sql().Preparation = "statement"
}

// postgresDetector tags db as "Postgres" and fills in the server version,
// current user and current database from a single query. It returns the
// query/scan error, if any.
func postgresDetector(db *DB) error {
	db.databaseType = "Postgres"
	const probe = "SELECT version(), current_user, current_database()"
	return db.db.QueryRow(probe).Scan(&db.databaseVersion, &db.user, &db.dbname)
}

// mysqlDetector tags db as "MySQL" and fills in the server version, current
// user and selected database from a single query. It returns the query/scan
// error, if any.
func mysqlDetector(db *DB) error {
	db.databaseType = "MySQL"
	const probe = "SELECT version(), current_user(), database()"
	return db.db.QueryRow(probe).Scan(&db.databaseVersion, &db.user, &db.dbname)
}

// mssqlDetector tags db as "MS SQL" and fills in the server version, current
// user and database name from a single query. It returns the query/scan
// error, if any.
func mssqlDetector(db *DB) error {
	db.databaseType = "MS SQL"
	const probe = "SELECT @@version, current_user, db_name()"
	return db.db.QueryRow(probe).Scan(&db.databaseVersion, &db.user, &db.dbname)
}

// oracleDetector tags db as "Oracle" and fills in the server version, current
// user and database name. It returns the query/scan error, if any.
//
// All three values must arrive in one row to be scanned together, so they are
// selected side by side. VERSION comes from v$instance, while USER and
// ORA_DATABASE_NAME can appear in any select list. UNION cannot be used here:
// it requires branches with equal column counts and stacks them as rows.
// NOTE(review): reading v$instance needs SELECT on V_$INSTANCE, as before.
func oracleDetector(db *DB) error {
	db.databaseType = "Oracle"
	row := db.db.QueryRow("SELECT version, user, ora_database_name FROM v$instance")
	return row.Scan(&db.databaseVersion, &db.user, &db.dbname)
}

// stripPasswordsFlushFunc returns the closure stripPasswords uses to finish
// the current token. While *inBraces is set the closure does nothing, so
// delimiters inside a braced value do not end the token. Otherwise the
// buffered token is written to res unless *isPassword is set, and then tok and
// *isPassword are reset. A write error is returned with tok left intact.
func stripPasswordsFlushFunc(inBraces *bool, isPassword *bool, tok *bytes.Buffer, res io.Writer) func() error {
	return func() error {
		if *inBraces {
			return nil
		}
		if !*isPassword {
			if _, err := res.Write(tok.Bytes()); err != nil {
				return err
			}
		}
		tok.Reset()
		*isPassword = false
		return nil
	}
}

// onStripPasswordsCurly runs after a '}' has been appended to tok. If the next
// byte is another '}', it is appended and consumed, and the braced value
// continues (presumably "}}" escaping, as in ODBC-style strings). Any other
// next byte ends the braced value: *inBraces is cleared and the byte is
// unread for the caller. At end of input nothing changes.
func onStripPasswordsCurly(buf *strings.Reader, tok io.ByteWriter, inBraces *bool) error {
	if buf.Len() == 0 {
		return nil
	}
	b, err := buf.ReadByte()
	if err != nil {
		// We should have a byte to read by here.  This would be very strange
		return err
	}
	if b == '}' {
		if err := tok.WriteByte(b); err != nil {
			return err
		}
	} else {
		*inBraces = false
		if err := buf.UnreadByte(); err != nil {
			return err
		}
	}
	return nil
}

// onStripPasswordsEquals runs after tok has accumulated a "key=" prefix. It
// marks the token as a password when the key is "password" or "pwd"
// (case-insensitive), and sets *inBraces when the value starts with '{'. The
// byte after '=' is only peeked: it is unread so the caller's loop sees it.
func onStripPasswordsEquals(buf io.ByteScanner, tok fmt.Stringer, isPassword *bool, inBraces *bool) error {
	tokStr := strings.ToLower(tok.String())
	*isPassword = tokStr == "password=" || tokStr == "pwd="
	b, err := buf.ReadByte()
	if err == io.EOF {
		// '=' was the last byte, so there is nothing to peek. Unreading after
		// a failed read would rewind onto the '=' and the caller would loop
		// on it forever.
		return nil
	}
	if err != nil {
		return err
	}
	if b == '{' {
		*inBraces = true
	}
	return buf.UnreadByte()
}

// stripPasswords removes "password=" / "pwd=" entries (keys compared
// case-insensitively) from a key=value style DSN and returns the rest.
//
// Entries are split on ';' when the DSN has more semicolons than spaces, and
// on ' ' otherwise. Outside a password value, ':' also ends a token. Values
// wrapped in {...} may contain delimiters, and "}}" inside them does not close
// the braces.
func stripPasswords(dsn string) (string, error) {
	var (
		tok        bytes.Buffer
		res        bytes.Buffer
		isPassword bool
		inBraces   bool
		delimiter  byte = ' '
	)
	flush := stripPasswordsFlushFunc(&inBraces, &isPassword, &tok, &res)
	// If anybody has any better ideas, I'm all ears
	if strings.Count(dsn, ";") > strings.Count(dsn, " ") {
		delimiter = ';'
	}

	buf := strings.NewReader(dsn)
	for c, err := buf.ReadByte(); err == nil; c, err = buf.ReadByte() {
		tok.WriteByte(c) // bytes.Buffer.WriteByte always returns nil
		switch {
		case c == delimiter, c == ':' && !isPassword:
			// Inside a password value ':' belongs to the secret. Only the real
			// delimiter may end it, otherwise the tail would be written out.
			if err := flush(); err != nil {
				return "", err
			}
		case c == '=' && !isPassword && !inBraces:
			// Only the '=' that ends a key is significant. A '=' inside a
			// password or braced value (e.g. base64 padding) must not re-key
			// the token, which would clear isPassword and leak the value.
			if err := onStripPasswordsEquals(buf, &tok, &isPassword, &inBraces); err != nil {
				return "", err
			}
		case c == '}':
			if err := onStripPasswordsCurly(buf, &tok, &inBraces); err != nil {
				return "", err
			}
		}
	}
	// Force the final token out even if its braces were never closed.
	inBraces = false
	if err := flush(); err != nil {
		return "", err
	}
	return res.String(), nil
}

// enhanceQuery returns query with a trailing SQL block comment appended. The
// comment carries the trace ID of the root of the segment in ctx and that
// segment's own ID.
func enhanceQuery(ctx context.Context, query string) string {
	seg := getSegment(ctx)
	traceID := seg.root().TraceID
	return fmt.Sprintf("%s /* XRay: Trace ID = %s, Segment ID = %s */", query, traceID, seg.ID)
}
