diff --git a/internal/handlers/common_test.go b/internal/handlers/common_test.go index 6cbd9ef..87fd7dc 100644 --- a/internal/handlers/common_test.go +++ b/internal/handlers/common_test.go @@ -258,10 +258,7 @@ func RunTests(m *testing.M) int { } defer db.Close() - err = db.DbMap.TruncateTables() - if err != nil { - log.Fatal(err) - } + db.Empty() // clear the DB tables server = httptest.NewTLSServer(&handlers.APIHandler{Store: db}) defer server.Close() diff --git a/internal/store/db/db.go b/internal/store/db/db.go index d8b043b..ba01e61 100644 --- a/internal/store/db/db.go +++ b/internal/store/db/db.go @@ -19,7 +19,7 @@ import ( // implementation's string type specified by the same. const luaMaxLengthBuffer int = 4096 -func GetDbMap(db *sql.DB, dbtype config.DbType) (*gorp.DbMap, error) { +func getDbMap(db *sql.DB, dbtype config.DbType) (*gorp.DbMap, error) { var dialect gorp.Dialect if dbtype == config.SQLite { dialect = gorp.SqliteDialect{} @@ -55,7 +55,7 @@ func GetDbMap(db *sql.DB, dbtype config.DbType) (*gorp.DbMap, error) { return dbmap, nil } -func GetDSN(dbtype config.DbType, dsn string) string { +func getDSN(dbtype config.DbType, dsn string) string { if dbtype == config.MySQL && !strings.Contains(dsn, "parseTime=true") { log.Fatalf("The DSN for MySQL MUST contain 'parseTime=True' but does not!") } @@ -63,25 +63,29 @@ func GetDSN(dbtype config.DbType, dsn string) string { } type DbStore struct { - DbMap *gorp.DbMap + dbMap *gorp.DbMap +} + +func (db *DbStore) Empty() error { + return db.dbMap.TruncateTables() } func (db *DbStore) Begin() (store.Tx, error) { - tx, err := db.DbMap.Begin() + tx, err := db.dbMap.Begin() if err != nil { return nil, err } - return &Tx{db.DbMap.Dialect, tx}, nil + return &Tx{db.dbMap.Dialect, tx}, nil } func (db *DbStore) Close() error { - err := db.DbMap.Db.Close() - db.DbMap = nil + err := db.dbMap.Db.Close() + db.dbMap = nil return err } func GetStore(dbtype config.DbType, dsn string) (store *DbStore, err error) { - dsn = GetDSN(dbtype, dsn) + dsn = getDSN(dbtype, dsn) database, err := sql.Open(dbtype.String(), dsn) if err != nil { return nil, err @@ -92,7 +96,7 @@ func GetStore(dbtype config.DbType, dsn string) (store *DbStore, err error) { } }() - dbmap, err := GetDbMap(database, dbtype) + dbmap, err := getDbMap(database, dbtype) if err != nil { return nil, err } diff --git a/internal/store/store.go b/internal/store/store.go index 3f87880..412b7c6 100644 --- a/internal/store/store.go +++ b/internal/store/store.go @@ -117,6 +117,7 @@ type Tx interface { } type Store interface { + Empty() error Begin() (Tx, error) Close() error }