diff --git a/internal/handlers/accounts.go b/internal/handlers/accounts.go index 9ef9521..d5ddcd8 100644 --- a/internal/handlers/accounts.go +++ b/internal/handlers/accounts.go @@ -62,7 +62,7 @@ func GetTradingAccount(tx *db.Tx, userid int64, securityid int64) (*models.Accou var tradingAccount models.Account var account models.Account - user, err := GetUser(tx, userid) + user, err := tx.GetUser(userid) if err != nil { return nil, err } @@ -79,7 +79,7 @@ func GetTradingAccount(tx *db.Tx, userid int64, securityid int64) (*models.Accou return nil, err } - security, err := GetSecurity(tx, securityid, userid) + security, err := tx.GetSecurity(securityid, userid) if err != nil { return nil, err } @@ -124,7 +124,7 @@ func GetImbalanceAccount(tx *db.Tx, userid int64, securityid int64) (*models.Acc return nil, err } - security, err := GetSecurity(tx, securityid, userid) + security, err := tx.GetSecurity(securityid, userid) if err != nil { return nil, err } @@ -280,7 +280,7 @@ func AccountHandler(r *http.Request, context *Context) ResponseWriterWriter { account.UserId = user.UserId account.AccountVersion = 0 - security, err := GetSecurity(context.Tx, account.SecurityId, user.UserId) + security, err := context.Tx.GetSecurity(account.SecurityId, user.UserId) if err != nil { log.Print(err) return NewError(999 /*Internal Error*/) @@ -341,7 +341,7 @@ func AccountHandler(r *http.Request, context *Context) ResponseWriterWriter { } account.UserId = user.UserId - security, err := GetSecurity(context.Tx, account.SecurityId, user.UserId) + security, err := context.Tx.GetSecurity(account.SecurityId, user.UserId) if err != nil { log.Print(err) return NewError(999 /*Internal Error*/) diff --git a/internal/handlers/imports.go b/internal/handlers/imports.go index 08267d3..ea74042 100644 --- a/internal/handlers/imports.go +++ b/internal/handlers/imports.go @@ -159,7 +159,7 @@ func ofxImportHelper(tx *db.Tx, r io.Reader, user *models.User, accountid int64) split := new(models.Split) r := new(big.Rat) r.Neg(&imbalance) - security, err := GetSecurity(tx, imbalanced_security, user.UserId) + security, err := tx.GetSecurity(imbalanced_security, user.UserId) if err != nil { log.Print(err) return NewError(999 /*Internal Error*/) diff --git a/internal/handlers/prices.go b/internal/handlers/prices.go index 620150d..f08d37e 100644 --- a/internal/handlers/prices.go +++ b/internal/handlers/prices.go @@ -97,7 +97,7 @@ func GetClosestPrice(tx *db.Tx, security, currency *models.Security, date *time. } func PriceHandler(r *http.Request, context *Context, user *models.User, securityid int64) ResponseWriterWriter { - security, err := GetSecurity(context.Tx, securityid, user.UserId) + security, err := context.Tx.GetSecurity(securityid, user.UserId) if err != nil { return NewError(3 /*Invalid Request*/) } @@ -112,7 +112,7 @@ func PriceHandler(r *http.Request, context *Context, user *models.User, security if price.SecurityId != security.SecurityId { return NewError(3 /*Invalid Request*/) } - _, err = GetSecurity(context.Tx, price.CurrencyId, user.UserId) + _, err = context.Tx.GetSecurity(price.CurrencyId, user.UserId) if err != nil { return NewError(3 /*Invalid Request*/) } @@ -161,11 +161,11 @@ func PriceHandler(r *http.Request, context *Context, user *models.User, security return NewError(3 /*Invalid Request*/) } - _, err = GetSecurity(context.Tx, price.SecurityId, user.UserId) + _, err = context.Tx.GetSecurity(price.SecurityId, user.UserId) if err != nil { return NewError(3 /*Invalid Request*/) } - _, err = GetSecurity(context.Tx, price.CurrencyId, user.UserId) + _, err = context.Tx.GetSecurity(price.CurrencyId, user.UserId) if err != nil { return NewError(3 /*Invalid Request*/) } diff --git a/internal/handlers/securities.go b/internal/handlers/securities.go index e836822..ec9544b 100644 --- a/internal/handlers/securities.go +++ b/internal/handlers/securities.go @@ -4,7 +4,6 @@ package handlers import ( "errors" - "fmt" "github.com/aclindsa/moneygo/internal/models" "github.com/aclindsa/moneygo/internal/store/db" "log" @@ -51,90 +50,18 @@ func FindCurrencyTemplate(iso4217 int64) *models.Security { return nil } -func GetSecurity(tx *db.Tx, securityid int64, userid int64) (*models.Security, error) { - var s models.Security - - err := tx.SelectOne(&s, "SELECT * from securities where UserId=? AND SecurityId=?", userid, securityid) - if err != nil { - return nil, err - } - return &s, nil -} - -func GetSecurities(tx *db.Tx, userid int64) (*[]*models.Security, error) { - var securities []*models.Security - - _, err := tx.Select(&securities, "SELECT * from securities where UserId=?", userid) - if err != nil { - return nil, err - } - return &securities, nil -} - -func InsertSecurity(tx *db.Tx, s *models.Security) error { - err := tx.Insert(s) - if err != nil { - return err - } - return nil -} - func UpdateSecurity(tx *db.Tx, s *models.Security) (err error) { - user, err := GetUser(tx, s.UserId) + user, err := tx.GetUser(s.UserId) if err != nil { return } else if user.DefaultCurrency == s.SecurityId && s.Type != models.Currency { return errors.New("Cannot change security which is user's default currency to be non-currency") } - count, err := tx.Update(s) + err = tx.UpdateSecurity(s) if err != nil { return } - if count > 1 { - return fmt.Errorf("Updated %d securities (expected 1)", count) - } - - return nil -} - -type SecurityInUseError struct { - message string -} - -func (e SecurityInUseError) Error() string { - return e.message -} - -func DeleteSecurity(tx *db.Tx, s *models.Security) error { - // First, ensure no accounts are using this security - accounts, err := tx.SelectInt("SELECT count(*) from accounts where UserId=? and SecurityId=?", s.UserId, s.SecurityId) - - if accounts != 0 { - return SecurityInUseError{"One or more accounts still use this security"} - } - - user, err := GetUser(tx, s.UserId) - if err != nil { - return err - } else if user.DefaultCurrency == s.SecurityId { - return SecurityInUseError{"Cannot delete security which is user's default currency"} - } - - // Remove all prices involving this security (either of this security, or - // using it as a currency) - _, err = tx.Exec("DELETE FROM prices WHERE SecurityId=? OR CurrencyId=?", s.SecurityId, s.SecurityId) - if err != nil { - return err - } - - count, err := tx.Delete(s) - if err != nil { - return err - } - if count != 1 { - return errors.New("Deleted more than one security") - } return nil } @@ -143,16 +70,14 @@ func ImportGetCreateSecurity(tx *db.Tx, userid int64, security *models.Security) security.UserId = userid if len(security.AlternateId) == 0 { // Always create a new local security if we can't match on the AlternateId - err := InsertSecurity(tx, security) + err := tx.InsertSecurity(security) if err != nil { return nil, err } return security, nil } - var securities []*models.Security - - _, err := tx.Select(&securities, "SELECT * from securities where UserId=? AND Type=? AND AlternateId=? AND Preciseness=?", userid, security.Type, security.AlternateId, security.Precision) + securities, err := tx.FindMatchingSecurities(userid, security) if err != nil { return nil, err } @@ -160,7 +85,7 @@ func ImportGetCreateSecurity(tx *db.Tx, userid int64, security *models.Security) // First try to find a case insensitive match on the name or symbol upperName := strings.ToUpper(security.Name) upperSymbol := strings.ToUpper(security.Symbol) - for _, s := range securities { + for _, s := range *securities { if (len(s.Name) > 0 && strings.ToUpper(s.Name) == upperName) || (len(s.Symbol) > 0 && strings.ToUpper(s.Symbol) == upperSymbol) { return s, nil @@ -169,7 +94,7 @@ func ImportGetCreateSecurity(tx *db.Tx, userid int64, security *models.Security) // if strings.Contains(strings.ToUpper(security.Name), upperSearch) || // Try to find a partial string match on the name or symbol - for _, s := range securities { + for _, s := range *securities { sUpperName := strings.ToUpper(s.Name) sUpperSymbol := strings.ToUpper(s.Symbol) if (len(upperName) > 0 && len(s.Name) > 0 && (strings.Contains(upperName, sUpperName) || strings.Contains(sUpperName, upperName))) || @@ -179,12 +104,12 @@ func ImportGetCreateSecurity(tx *db.Tx, userid int64, security *models.Security) } // Give up and return the first security in the list - if len(securities) > 0 { - return securities[0], nil + if len(*securities) > 0 { + return (*securities)[0], nil } // If there wasn't even one security in the list, make a new one - err = InsertSecurity(tx, security) + err = tx.InsertSecurity(security) if err != nil { return nil, err } @@ -217,7 +142,7 @@ func SecurityHandler(r *http.Request, context *Context) ResponseWriterWriter { security.SecurityId = -1 security.UserId = user.UserId - err = InsertSecurity(context.Tx, &security) + err = context.Tx.InsertSecurity(&security) if err != nil { log.Print(err) return NewError(999 /*Internal Error*/) @@ -229,7 +154,7 @@ func SecurityHandler(r *http.Request, context *Context) ResponseWriterWriter { //Return all securities var sl models.SecurityList - securities, err := GetSecurities(context.Tx, user.UserId) + securities, err := context.Tx.GetSecurities(user.UserId) if err != nil { log.Print(err) return NewError(999 /*Internal Error*/) @@ -250,7 +175,7 @@ func SecurityHandler(r *http.Request, context *Context) ResponseWriterWriter { return PriceHandler(r, context, user, securityid) } - security, err := GetSecurity(context.Tx, securityid, user.UserId) + security, err := context.Tx.GetSecurity(securityid, user.UserId) if err != nil { return NewError(3 /*Invalid Request*/) } @@ -284,13 +209,13 @@ func SecurityHandler(r *http.Request, context *Context) ResponseWriterWriter { return &security } else if r.Method == "DELETE" { - security, err := GetSecurity(context.Tx, securityid, user.UserId) + security, err := context.Tx.GetSecurity(securityid, user.UserId) if err != nil { return NewError(3 /*Invalid Request*/) } - err = DeleteSecurity(context.Tx, security) - if _, ok := err.(SecurityInUseError); ok { + err = context.Tx.DeleteSecurity(security) + if _, ok := err.(db.SecurityInUseError); ok { return NewError(7 /*In Use Error*/) } else if err != nil { log.Print(err) diff --git a/internal/handlers/securities_lua.go b/internal/handlers/securities_lua.go index a5307c6..78716f2 100644 --- a/internal/handlers/securities_lua.go +++ b/internal/handlers/securities_lua.go @@ -27,7 +27,7 @@ func luaContextGetSecurities(L *lua.LState) (map[int64]*models.Security, error) return nil, errors.New("Couldn't find User in lua's Context") } - securities, err := GetSecurities(tx, user.UserId) + securities, err := tx.GetSecurities(user.UserId) if err != nil { return nil, err } diff --git a/internal/handlers/sessions.go b/internal/handlers/sessions.go index f4d6c5e..71deff4 100644 --- a/internal/handlers/sessions.go +++ b/internal/handlers/sessions.go @@ -85,12 +85,15 @@ func SessionHandler(r *http.Request, context *Context) ResponseWriterWriter { return NewError(3 /*Invalid Request*/) } - dbuser, err := GetUserByUsername(context.Tx, user.Username) + // Hash password before checking username to help mitigate timing + // attacks + user.HashPassword() + + dbuser, err := context.StoreTx.GetUserByUsername(user.Username) if err != nil { return NewError(2 /*Unauthorized Access*/) } - user.HashPassword() if user.PasswordHash != dbuser.PasswordHash { return NewError(2 /*Unauthorized Access*/) } diff --git a/internal/handlers/transactions.go b/internal/handlers/transactions.go index 4ec94d1..707d29a 100644 --- a/internal/handlers/transactions.go +++ b/internal/handlers/transactions.go @@ -542,7 +542,7 @@ func GetAccountTransactions(tx *db.Tx, user *models.User, accountid int64, sort } atl.TotalTransactions = count - security, err := GetSecurity(tx, atl.Account.SecurityId, user.UserId) + security, err := tx.GetSecurity(atl.Account.SecurityId, user.UserId) if err != nil { return nil, err } diff --git a/internal/handlers/users.go b/internal/handlers/users.go index 7225fe5..e9a468d 100644 --- a/internal/handlers/users.go +++ b/internal/handlers/users.go @@ -2,9 +2,8 @@ package handlers import ( "errors" - "fmt" "github.com/aclindsa/moneygo/internal/models" - "github.com/aclindsa/moneygo/internal/store/db" + "github.com/aclindsa/moneygo/internal/store" "log" "net/http" ) @@ -15,41 +14,21 @@ func (ueu UserExistsError) Error() string { return "User exists" } -func GetUser(tx *db.Tx, userid int64) (*models.User, error) { - var u models.User - - err := tx.SelectOne(&u, "SELECT * from users where UserId=?", userid) - if err != nil { - return nil, err - } - return &u, nil -} - -func GetUserByUsername(tx *db.Tx, username string) (*models.User, error) { - var u models.User - - err := tx.SelectOne(&u, "SELECT * from users where Username=?", username) - if err != nil { - return nil, err - } - return &u, nil -} - -func InsertUser(tx *db.Tx, u *models.User) error { +func InsertUser(tx store.Tx, u *models.User) error { security_template := FindCurrencyTemplate(u.DefaultCurrency) if security_template == nil { return errors.New("Invalid ISO4217 Default Currency") } - existing, err := tx.SelectInt("SELECT count(*) from users where Username=?", u.Username) + exists, err := tx.UsernameExists(u.Username) if err != nil { return err } - if existing > 0 { + if exists { return UserExistsError{} } - err = tx.Insert(u) + err = tx.InsertUser(u) if err != nil { return err } @@ -59,33 +38,31 @@ func InsertUser(tx *db.Tx, u *models.User) error { security = *security_template security.UserId = u.UserId - err = InsertSecurity(tx, &security) + err = tx.InsertSecurity(&security) if err != nil { return err } // Update the user's DefaultCurrency to our new SecurityId u.DefaultCurrency = security.SecurityId - count, err := tx.Update(u) + err = tx.UpdateUser(u) if err != nil { return err - } else if count != 1 { - return errors.New("Would have updated more than one user") } return nil } -func GetUserFromSession(tx *db.Tx, r *http.Request) (*models.User, error) { +func GetUserFromSession(tx store.Tx, r *http.Request) (*models.User, error) { s, err := GetSession(tx, r) if err != nil { return nil, err } - return GetUser(tx, s.UserId) + return tx.GetUser(s.UserId) } -func UpdateUser(tx *db.Tx, u *models.User) error { - security, err := GetSecurity(tx, u.DefaultCurrency, u.UserId) +func UpdateUser(tx store.Tx, u *models.User) error { + security, err := tx.GetSecurity(u.DefaultCurrency, u.UserId) if err != nil { return err } else if security.UserId != u.UserId || security.SecurityId != u.DefaultCurrency { @@ -94,49 +71,7 @@ func UpdateUser(tx *db.Tx, u *models.User) error { return errors.New("New DefaultCurrency security is not a currency") } - count, err := tx.Update(u) - if err != nil { - return err - } else if count != 1 { - return errors.New("Would have updated more than one user") - } - - return nil -} - -func DeleteUser(tx *db.Tx, u *models.User) error { - count, err := tx.Delete(u) - if err != nil { - return err - } - if count != 1 { - return fmt.Errorf("No user to delete") - } - _, err = tx.Exec("DELETE FROM prices WHERE prices.SecurityId IN (SELECT securities.SecurityId FROM securities WHERE securities.UserId=?)", u.UserId) - if err != nil { - return err - } - _, err = tx.Exec("DELETE FROM splits WHERE splits.TransactionId IN (SELECT transactions.TransactionId FROM transactions WHERE transactions.UserId=?)", u.UserId) - if err != nil { - return err - } - _, err = tx.Exec("DELETE FROM transactions WHERE transactions.UserId=?", u.UserId) - if err != nil { - return err - } - _, err = tx.Exec("DELETE FROM securities WHERE securities.UserId=?", u.UserId) - if err != nil { - return err - } - _, err = tx.Exec("DELETE FROM accounts WHERE accounts.UserId=?", u.UserId) - if err != nil { - return err - } - _, err = tx.Exec("DELETE FROM reports WHERE reports.UserId=?", u.UserId) - if err != nil { - return err - } - _, err = tx.Exec("DELETE FROM sessions WHERE sessions.UserId=?", u.UserId) + err = tx.UpdateUser(u) if err != nil { return err } @@ -205,7 +140,7 @@ func UserHandler(r *http.Request, context *Context) ResponseWriterWriter { return user } else if r.Method == "DELETE" { - err := DeleteUser(context.Tx, user) + err := context.StoreTx.DeleteUser(user) if err != nil { log.Print(err) return NewError(999 /*Internal Error*/) diff --git a/internal/store/db/securities.go b/internal/store/db/securities.go new file mode 100644 index 0000000..0acba29 --- /dev/null +++ b/internal/store/db/securities.go @@ -0,0 +1,95 @@ +package db + +import ( + "fmt" + "github.com/aclindsa/moneygo/internal/models" +) + +type SecurityInUseError struct { + message string +} + +func (e SecurityInUseError) Error() string { + return e.message +} + +func (tx *Tx) GetSecurity(securityid int64, userid int64) (*models.Security, error) { + var s models.Security + + err := tx.SelectOne(&s, "SELECT * from securities where UserId=? AND SecurityId=?", userid, securityid) + if err != nil { + return nil, err + } + return &s, nil +} + +func (tx *Tx) GetSecurities(userid int64) (*[]*models.Security, error) { + var securities []*models.Security + + _, err := tx.Select(&securities, "SELECT * from securities where UserId=?", userid) + if err != nil { + return nil, err + } + return &securities, nil +} + +func (tx *Tx) FindMatchingSecurities(userid int64, security *models.Security) (*[]*models.Security, error) { + var securities []*models.Security + + _, err := tx.Select(&securities, "SELECT * from securities where UserId=? AND Type=? AND AlternateId=? AND Preciseness=?", userid, security.Type, security.AlternateId, security.Precision) + if err != nil { + return nil, err + } + return &securities, nil +} + +func (tx *Tx) InsertSecurity(s *models.Security) error { + err := tx.Insert(s) + if err != nil { + return err + } + return nil +} + +func (tx *Tx) UpdateSecurity(security *models.Security) error { + count, err := tx.Update(security) + if err != nil { + return err + } + if count != 1 { + return fmt.Errorf("Expected to update 1 security, was going to update %d", count) + } + return nil +} + +func (tx *Tx) DeleteSecurity(s *models.Security) error { + // First, ensure no accounts are using this security + accounts, err := tx.SelectInt("SELECT count(*) from accounts where UserId=? and SecurityId=?", s.UserId, s.SecurityId) + + if accounts != 0 { + return SecurityInUseError{"One or more accounts still use this security"} + } + + user, err := tx.GetUser(s.UserId) + if err != nil { + return err + } else if user.DefaultCurrency == s.SecurityId { + return SecurityInUseError{"Cannot delete security which is user's default currency"} + } + + // Remove all prices involving this security (either of this security, or + // using it as a currency) + _, err = tx.Exec("DELETE FROM prices WHERE SecurityId=? OR CurrencyId=?", s.SecurityId, s.SecurityId) + if err != nil { + return err + } + + count, err := tx.Delete(s) + if err != nil { + return err + } + if count != 1 { + return fmt.Errorf("Expected to delete 1 security, was going to delete %d", count) + } + return nil +} diff --git a/internal/store/db/sessions.go b/internal/store/db/sessions.go index 671c55a..57a5ad5 100644 --- a/internal/store/db/sessions.go +++ b/internal/store/db/sessions.go @@ -31,6 +31,12 @@ func (tx *Tx) SessionExists(secret string) (bool, error) { } func (tx *Tx) DeleteSession(session *models.Session) error { - _, err := tx.Delete(session) - return err + count, err := tx.Delete(session) + if err != nil { + return err + } + if count != 1 { + return fmt.Errorf("Expected to delete 1 user, was going to delete %d", count) + } + return nil } diff --git a/internal/store/db/users.go b/internal/store/db/users.go new file mode 100644 index 0000000..2a44b23 --- /dev/null +++ b/internal/store/db/users.go @@ -0,0 +1,86 @@ +package db + +import ( + "fmt" + "github.com/aclindsa/moneygo/internal/models" +) + +func (tx *Tx) UsernameExists(username string) (bool, error) { + existing, err := tx.SelectInt("SELECT count(*) from users where Username=?", username) + return existing != 0, err +} + +func (tx *Tx) InsertUser(user *models.User) error { + return tx.Insert(user) +} + +func (tx *Tx) GetUser(userid int64) (*models.User, error) { + var u models.User + + err := tx.SelectOne(&u, "SELECT * from users where UserId=?", userid) + if err != nil { + return nil, err + } + return &u, nil +} + +func (tx *Tx) GetUserByUsername(username string) (*models.User, error) { + var u models.User + + err := tx.SelectOne(&u, "SELECT * from users where Username=?", username) + if err != nil { + return nil, err + } + return &u, nil +} + +func (tx *Tx) UpdateUser(user *models.User) error { + count, err := tx.Update(user) + if err != nil { + return err + } + if count != 1 { + return fmt.Errorf("Expected to update 1 user, was going to update %d", count) + } + return nil +} + +func (tx *Tx) DeleteUser(user *models.User) error { + count, err := tx.Delete(user) + if err != nil { + return err + } + if count != 1 { + return fmt.Errorf("Expected to delete 1 user, was going to delete %d", count) + } + _, err = tx.Exec("DELETE FROM prices WHERE prices.SecurityId IN (SELECT securities.SecurityId FROM securities WHERE securities.UserId=?)", user.UserId) + if err != nil { + return err + } + _, err = tx.Exec("DELETE FROM splits WHERE splits.TransactionId IN (SELECT transactions.TransactionId FROM transactions WHERE transactions.UserId=?)", user.UserId) + if err != nil { + return err + } + _, err = tx.Exec("DELETE FROM transactions WHERE transactions.UserId=?", user.UserId) + if err != nil { + return err + } + _, err = tx.Exec("DELETE FROM securities WHERE securities.UserId=?", user.UserId) + if err != nil { + return err + } + _, err = tx.Exec("DELETE FROM accounts WHERE accounts.UserId=?", user.UserId) + if err != nil { + return err + } + _, err = tx.Exec("DELETE FROM reports WHERE reports.UserId=?", user.UserId) + if err != nil { + return err + } + _, err = tx.Exec("DELETE FROM sessions WHERE sessions.UserId=?", user.UserId) + if err != nil { + return err + } + + return nil +} diff --git a/internal/store/store.go b/internal/store/store.go index 9823236..86d6c66 100644 --- a/internal/store/store.go +++ b/internal/store/store.go @@ -5,17 +5,37 @@ import ( ) type SessionStore interface { + SessionExists(secret string) (bool, error) InsertSession(session *models.Session) error GetSession(secret string) (*models.Session, error) - SessionExists(secret string) (bool, error) DeleteSession(session *models.Session) error } +type UserStore interface { + UsernameExists(username string) (bool, error) + InsertUser(user *models.User) error + GetUser(userid int64) (*models.User, error) + GetUserByUsername(username string) (*models.User, error) + UpdateUser(user *models.User) error + DeleteUser(user *models.User) error +} + +type SecurityStore interface { + InsertSecurity(security *models.Security) error + GetSecurity(securityid int64, userid int64) (*models.Security, error) + GetSecurities(userid int64) (*[]*models.Security, error) + FindMatchingSecurities(userid int64, security *models.Security) (*[]*models.Security, error) + UpdateSecurity(security *models.Security) error + DeleteSecurity(security *models.Security) error +} + type Tx interface { Commit() error Rollback() error SessionStore + UserStore + SecurityStore } type Store interface {