diff --git a/internal/handlers/common_test.go b/internal/handlers/common_test.go index acd87f2..071ec05 100644 --- a/internal/handlers/common_test.go +++ b/internal/handlers/common_test.go @@ -174,14 +174,41 @@ func RunWith(t *testing.T, d *TestData, fn TestDataFunc) { } func RunTests(m *testing.M) int { - dsn := db.GetDSN(config.SQLite, ":memory:") - database, err := sql.Open("sqlite3", dsn) + envDbType := os.Getenv("MONEYGO_TEST_DB") + var dbType config.DbType + var dsn string + + switch envDbType { + case "", "sqlite", "sqlite3": + dbType = config.SQLite + dsn = ":memory:" + case "mariadb", "mysql": + dbType = config.MySQL + dsn = "root@127.0.0.1/moneygo_test&parseTime=true" + case "postgres", "postgresql": + dbType = config.Postgres + dsn = "postgres://postgres@localhost/moneygo_test" + default: + log.Fatalf("Invalid value for $MONEYGO_TEST_DB: %s\n", envDbType) + } + + if envDSN := os.Getenv("MONEYGO_TEST_DSN"); len(envDSN) > 0 { + dsn = envDSN + } + + dsn = db.GetDSN(dbType, dsn) + database, err := sql.Open(dbType.String(), dsn) if err != nil { log.Fatal(err) } defer database.Close() - dbmap, err := db.GetDbMap(database, config.SQLite) + dbmap, err := db.GetDbMap(database, dbType) + if err != nil { + log.Fatal(err) + } + + err = dbmap.TruncateTables() if err != nil { log.Fatal(err) }