mirror of
https://github.com/aclindsa/moneygo.git
synced 2024-12-26 07:33:21 -05:00
Move users and securities to store
This commit is contained in:
parent
c452984f23
commit
bec5152e53
@ -62,7 +62,7 @@ func GetTradingAccount(tx *db.Tx, userid int64, securityid int64) (*models.Accou
|
|||||||
var tradingAccount models.Account
|
var tradingAccount models.Account
|
||||||
var account models.Account
|
var account models.Account
|
||||||
|
|
||||||
user, err := GetUser(tx, userid)
|
user, err := tx.GetUser(userid)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@ -79,7 +79,7 @@ func GetTradingAccount(tx *db.Tx, userid int64, securityid int64) (*models.Accou
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
security, err := GetSecurity(tx, securityid, userid)
|
security, err := tx.GetSecurity(securityid, userid)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@ -124,7 +124,7 @@ func GetImbalanceAccount(tx *db.Tx, userid int64, securityid int64) (*models.Acc
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
security, err := GetSecurity(tx, securityid, userid)
|
security, err := tx.GetSecurity(securityid, userid)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@ -280,7 +280,7 @@ func AccountHandler(r *http.Request, context *Context) ResponseWriterWriter {
|
|||||||
account.UserId = user.UserId
|
account.UserId = user.UserId
|
||||||
account.AccountVersion = 0
|
account.AccountVersion = 0
|
||||||
|
|
||||||
security, err := GetSecurity(context.Tx, account.SecurityId, user.UserId)
|
security, err := context.Tx.GetSecurity(account.SecurityId, user.UserId)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Print(err)
|
log.Print(err)
|
||||||
return NewError(999 /*Internal Error*/)
|
return NewError(999 /*Internal Error*/)
|
||||||
@ -341,7 +341,7 @@ func AccountHandler(r *http.Request, context *Context) ResponseWriterWriter {
|
|||||||
}
|
}
|
||||||
account.UserId = user.UserId
|
account.UserId = user.UserId
|
||||||
|
|
||||||
security, err := GetSecurity(context.Tx, account.SecurityId, user.UserId)
|
security, err := context.Tx.GetSecurity(account.SecurityId, user.UserId)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Print(err)
|
log.Print(err)
|
||||||
return NewError(999 /*Internal Error*/)
|
return NewError(999 /*Internal Error*/)
|
||||||
|
@ -159,7 +159,7 @@ func ofxImportHelper(tx *db.Tx, r io.Reader, user *models.User, accountid int64)
|
|||||||
split := new(models.Split)
|
split := new(models.Split)
|
||||||
r := new(big.Rat)
|
r := new(big.Rat)
|
||||||
r.Neg(&imbalance)
|
r.Neg(&imbalance)
|
||||||
security, err := GetSecurity(tx, imbalanced_security, user.UserId)
|
security, err := tx.GetSecurity(imbalanced_security, user.UserId)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Print(err)
|
log.Print(err)
|
||||||
return NewError(999 /*Internal Error*/)
|
return NewError(999 /*Internal Error*/)
|
||||||
|
@ -97,7 +97,7 @@ func GetClosestPrice(tx *db.Tx, security, currency *models.Security, date *time.
|
|||||||
}
|
}
|
||||||
|
|
||||||
func PriceHandler(r *http.Request, context *Context, user *models.User, securityid int64) ResponseWriterWriter {
|
func PriceHandler(r *http.Request, context *Context, user *models.User, securityid int64) ResponseWriterWriter {
|
||||||
security, err := GetSecurity(context.Tx, securityid, user.UserId)
|
security, err := context.Tx.GetSecurity(securityid, user.UserId)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return NewError(3 /*Invalid Request*/)
|
return NewError(3 /*Invalid Request*/)
|
||||||
}
|
}
|
||||||
@ -112,7 +112,7 @@ func PriceHandler(r *http.Request, context *Context, user *models.User, security
|
|||||||
if price.SecurityId != security.SecurityId {
|
if price.SecurityId != security.SecurityId {
|
||||||
return NewError(3 /*Invalid Request*/)
|
return NewError(3 /*Invalid Request*/)
|
||||||
}
|
}
|
||||||
_, err = GetSecurity(context.Tx, price.CurrencyId, user.UserId)
|
_, err = context.Tx.GetSecurity(price.CurrencyId, user.UserId)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return NewError(3 /*Invalid Request*/)
|
return NewError(3 /*Invalid Request*/)
|
||||||
}
|
}
|
||||||
@ -161,11 +161,11 @@ func PriceHandler(r *http.Request, context *Context, user *models.User, security
|
|||||||
return NewError(3 /*Invalid Request*/)
|
return NewError(3 /*Invalid Request*/)
|
||||||
}
|
}
|
||||||
|
|
||||||
_, err = GetSecurity(context.Tx, price.SecurityId, user.UserId)
|
_, err = context.Tx.GetSecurity(price.SecurityId, user.UserId)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return NewError(3 /*Invalid Request*/)
|
return NewError(3 /*Invalid Request*/)
|
||||||
}
|
}
|
||||||
_, err = GetSecurity(context.Tx, price.CurrencyId, user.UserId)
|
_, err = context.Tx.GetSecurity(price.CurrencyId, user.UserId)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return NewError(3 /*Invalid Request*/)
|
return NewError(3 /*Invalid Request*/)
|
||||||
}
|
}
|
||||||
|
@ -4,7 +4,6 @@ package handlers
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
|
||||||
"github.com/aclindsa/moneygo/internal/models"
|
"github.com/aclindsa/moneygo/internal/models"
|
||||||
"github.com/aclindsa/moneygo/internal/store/db"
|
"github.com/aclindsa/moneygo/internal/store/db"
|
||||||
"log"
|
"log"
|
||||||
@ -51,90 +50,18 @@ func FindCurrencyTemplate(iso4217 int64) *models.Security {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func GetSecurity(tx *db.Tx, securityid int64, userid int64) (*models.Security, error) {
|
|
||||||
var s models.Security
|
|
||||||
|
|
||||||
err := tx.SelectOne(&s, "SELECT * from securities where UserId=? AND SecurityId=?", userid, securityid)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
return &s, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func GetSecurities(tx *db.Tx, userid int64) (*[]*models.Security, error) {
|
|
||||||
var securities []*models.Security
|
|
||||||
|
|
||||||
_, err := tx.Select(&securities, "SELECT * from securities where UserId=?", userid)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
return &securities, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func InsertSecurity(tx *db.Tx, s *models.Security) error {
|
|
||||||
err := tx.Insert(s)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func UpdateSecurity(tx *db.Tx, s *models.Security) (err error) {
|
func UpdateSecurity(tx *db.Tx, s *models.Security) (err error) {
|
||||||
user, err := GetUser(tx, s.UserId)
|
user, err := tx.GetUser(s.UserId)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
} else if user.DefaultCurrency == s.SecurityId && s.Type != models.Currency {
|
} else if user.DefaultCurrency == s.SecurityId && s.Type != models.Currency {
|
||||||
return errors.New("Cannot change security which is user's default currency to be non-currency")
|
return errors.New("Cannot change security which is user's default currency to be non-currency")
|
||||||
}
|
}
|
||||||
|
|
||||||
count, err := tx.Update(s)
|
err = tx.UpdateSecurity(s)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if count > 1 {
|
|
||||||
return fmt.Errorf("Updated %d securities (expected 1)", count)
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
type SecurityInUseError struct {
|
|
||||||
message string
|
|
||||||
}
|
|
||||||
|
|
||||||
func (e SecurityInUseError) Error() string {
|
|
||||||
return e.message
|
|
||||||
}
|
|
||||||
|
|
||||||
func DeleteSecurity(tx *db.Tx, s *models.Security) error {
|
|
||||||
// First, ensure no accounts are using this security
|
|
||||||
accounts, err := tx.SelectInt("SELECT count(*) from accounts where UserId=? and SecurityId=?", s.UserId, s.SecurityId)
|
|
||||||
|
|
||||||
if accounts != 0 {
|
|
||||||
return SecurityInUseError{"One or more accounts still use this security"}
|
|
||||||
}
|
|
||||||
|
|
||||||
user, err := GetUser(tx, s.UserId)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
} else if user.DefaultCurrency == s.SecurityId {
|
|
||||||
return SecurityInUseError{"Cannot delete security which is user's default currency"}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Remove all prices involving this security (either of this security, or
|
|
||||||
// using it as a currency)
|
|
||||||
_, err = tx.Exec("DELETE FROM prices WHERE SecurityId=? OR CurrencyId=?", s.SecurityId, s.SecurityId)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
count, err := tx.Delete(s)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
if count != 1 {
|
|
||||||
return errors.New("Deleted more than one security")
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
@ -143,16 +70,14 @@ func ImportGetCreateSecurity(tx *db.Tx, userid int64, security *models.Security)
|
|||||||
security.UserId = userid
|
security.UserId = userid
|
||||||
if len(security.AlternateId) == 0 {
|
if len(security.AlternateId) == 0 {
|
||||||
// Always create a new local security if we can't match on the AlternateId
|
// Always create a new local security if we can't match on the AlternateId
|
||||||
err := InsertSecurity(tx, security)
|
err := tx.InsertSecurity(security)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
return security, nil
|
return security, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
var securities []*models.Security
|
securities, err := tx.FindMatchingSecurities(userid, security)
|
||||||
|
|
||||||
_, err := tx.Select(&securities, "SELECT * from securities where UserId=? AND Type=? AND AlternateId=? AND Preciseness=?", userid, security.Type, security.AlternateId, security.Precision)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@ -160,7 +85,7 @@ func ImportGetCreateSecurity(tx *db.Tx, userid int64, security *models.Security)
|
|||||||
// First try to find a case insensitive match on the name or symbol
|
// First try to find a case insensitive match on the name or symbol
|
||||||
upperName := strings.ToUpper(security.Name)
|
upperName := strings.ToUpper(security.Name)
|
||||||
upperSymbol := strings.ToUpper(security.Symbol)
|
upperSymbol := strings.ToUpper(security.Symbol)
|
||||||
for _, s := range securities {
|
for _, s := range *securities {
|
||||||
if (len(s.Name) > 0 && strings.ToUpper(s.Name) == upperName) ||
|
if (len(s.Name) > 0 && strings.ToUpper(s.Name) == upperName) ||
|
||||||
(len(s.Symbol) > 0 && strings.ToUpper(s.Symbol) == upperSymbol) {
|
(len(s.Symbol) > 0 && strings.ToUpper(s.Symbol) == upperSymbol) {
|
||||||
return s, nil
|
return s, nil
|
||||||
@ -169,7 +94,7 @@ func ImportGetCreateSecurity(tx *db.Tx, userid int64, security *models.Security)
|
|||||||
// if strings.Contains(strings.ToUpper(security.Name), upperSearch) ||
|
// if strings.Contains(strings.ToUpper(security.Name), upperSearch) ||
|
||||||
|
|
||||||
// Try to find a partial string match on the name or symbol
|
// Try to find a partial string match on the name or symbol
|
||||||
for _, s := range securities {
|
for _, s := range *securities {
|
||||||
sUpperName := strings.ToUpper(s.Name)
|
sUpperName := strings.ToUpper(s.Name)
|
||||||
sUpperSymbol := strings.ToUpper(s.Symbol)
|
sUpperSymbol := strings.ToUpper(s.Symbol)
|
||||||
if (len(upperName) > 0 && len(s.Name) > 0 && (strings.Contains(upperName, sUpperName) || strings.Contains(sUpperName, upperName))) ||
|
if (len(upperName) > 0 && len(s.Name) > 0 && (strings.Contains(upperName, sUpperName) || strings.Contains(sUpperName, upperName))) ||
|
||||||
@ -179,12 +104,12 @@ func ImportGetCreateSecurity(tx *db.Tx, userid int64, security *models.Security)
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Give up and return the first security in the list
|
// Give up and return the first security in the list
|
||||||
if len(securities) > 0 {
|
if len(*securities) > 0 {
|
||||||
return securities[0], nil
|
return (*securities)[0], nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// If there wasn't even one security in the list, make a new one
|
// If there wasn't even one security in the list, make a new one
|
||||||
err = InsertSecurity(tx, security)
|
err = tx.InsertSecurity(security)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@ -217,7 +142,7 @@ func SecurityHandler(r *http.Request, context *Context) ResponseWriterWriter {
|
|||||||
security.SecurityId = -1
|
security.SecurityId = -1
|
||||||
security.UserId = user.UserId
|
security.UserId = user.UserId
|
||||||
|
|
||||||
err = InsertSecurity(context.Tx, &security)
|
err = context.Tx.InsertSecurity(&security)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Print(err)
|
log.Print(err)
|
||||||
return NewError(999 /*Internal Error*/)
|
return NewError(999 /*Internal Error*/)
|
||||||
@ -229,7 +154,7 @@ func SecurityHandler(r *http.Request, context *Context) ResponseWriterWriter {
|
|||||||
//Return all securities
|
//Return all securities
|
||||||
var sl models.SecurityList
|
var sl models.SecurityList
|
||||||
|
|
||||||
securities, err := GetSecurities(context.Tx, user.UserId)
|
securities, err := context.Tx.GetSecurities(user.UserId)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Print(err)
|
log.Print(err)
|
||||||
return NewError(999 /*Internal Error*/)
|
return NewError(999 /*Internal Error*/)
|
||||||
@ -250,7 +175,7 @@ func SecurityHandler(r *http.Request, context *Context) ResponseWriterWriter {
|
|||||||
return PriceHandler(r, context, user, securityid)
|
return PriceHandler(r, context, user, securityid)
|
||||||
}
|
}
|
||||||
|
|
||||||
security, err := GetSecurity(context.Tx, securityid, user.UserId)
|
security, err := context.Tx.GetSecurity(securityid, user.UserId)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return NewError(3 /*Invalid Request*/)
|
return NewError(3 /*Invalid Request*/)
|
||||||
}
|
}
|
||||||
@ -284,13 +209,13 @@ func SecurityHandler(r *http.Request, context *Context) ResponseWriterWriter {
|
|||||||
|
|
||||||
return &security
|
return &security
|
||||||
} else if r.Method == "DELETE" {
|
} else if r.Method == "DELETE" {
|
||||||
security, err := GetSecurity(context.Tx, securityid, user.UserId)
|
security, err := context.Tx.GetSecurity(securityid, user.UserId)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return NewError(3 /*Invalid Request*/)
|
return NewError(3 /*Invalid Request*/)
|
||||||
}
|
}
|
||||||
|
|
||||||
err = DeleteSecurity(context.Tx, security)
|
err = context.Tx.DeleteSecurity(security)
|
||||||
if _, ok := err.(SecurityInUseError); ok {
|
if _, ok := err.(db.SecurityInUseError); ok {
|
||||||
return NewError(7 /*In Use Error*/)
|
return NewError(7 /*In Use Error*/)
|
||||||
} else if err != nil {
|
} else if err != nil {
|
||||||
log.Print(err)
|
log.Print(err)
|
||||||
|
@ -27,7 +27,7 @@ func luaContextGetSecurities(L *lua.LState) (map[int64]*models.Security, error)
|
|||||||
return nil, errors.New("Couldn't find User in lua's Context")
|
return nil, errors.New("Couldn't find User in lua's Context")
|
||||||
}
|
}
|
||||||
|
|
||||||
securities, err := GetSecurities(tx, user.UserId)
|
securities, err := tx.GetSecurities(user.UserId)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -85,12 +85,15 @@ func SessionHandler(r *http.Request, context *Context) ResponseWriterWriter {
|
|||||||
return NewError(3 /*Invalid Request*/)
|
return NewError(3 /*Invalid Request*/)
|
||||||
}
|
}
|
||||||
|
|
||||||
dbuser, err := GetUserByUsername(context.Tx, user.Username)
|
// Hash password before checking username to help mitigate timing
|
||||||
|
// attacks
|
||||||
|
user.HashPassword()
|
||||||
|
|
||||||
|
dbuser, err := context.StoreTx.GetUserByUsername(user.Username)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return NewError(2 /*Unauthorized Access*/)
|
return NewError(2 /*Unauthorized Access*/)
|
||||||
}
|
}
|
||||||
|
|
||||||
user.HashPassword()
|
|
||||||
if user.PasswordHash != dbuser.PasswordHash {
|
if user.PasswordHash != dbuser.PasswordHash {
|
||||||
return NewError(2 /*Unauthorized Access*/)
|
return NewError(2 /*Unauthorized Access*/)
|
||||||
}
|
}
|
||||||
|
@ -542,7 +542,7 @@ func GetAccountTransactions(tx *db.Tx, user *models.User, accountid int64, sort
|
|||||||
}
|
}
|
||||||
atl.TotalTransactions = count
|
atl.TotalTransactions = count
|
||||||
|
|
||||||
security, err := GetSecurity(tx, atl.Account.SecurityId, user.UserId)
|
security, err := tx.GetSecurity(atl.Account.SecurityId, user.UserId)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -2,9 +2,8 @@ package handlers
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
|
||||||
"github.com/aclindsa/moneygo/internal/models"
|
"github.com/aclindsa/moneygo/internal/models"
|
||||||
"github.com/aclindsa/moneygo/internal/store/db"
|
"github.com/aclindsa/moneygo/internal/store"
|
||||||
"log"
|
"log"
|
||||||
"net/http"
|
"net/http"
|
||||||
)
|
)
|
||||||
@ -15,41 +14,21 @@ func (ueu UserExistsError) Error() string {
|
|||||||
return "User exists"
|
return "User exists"
|
||||||
}
|
}
|
||||||
|
|
||||||
func GetUser(tx *db.Tx, userid int64) (*models.User, error) {
|
func InsertUser(tx store.Tx, u *models.User) error {
|
||||||
var u models.User
|
|
||||||
|
|
||||||
err := tx.SelectOne(&u, "SELECT * from users where UserId=?", userid)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
return &u, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func GetUserByUsername(tx *db.Tx, username string) (*models.User, error) {
|
|
||||||
var u models.User
|
|
||||||
|
|
||||||
err := tx.SelectOne(&u, "SELECT * from users where Username=?", username)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
return &u, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func InsertUser(tx *db.Tx, u *models.User) error {
|
|
||||||
security_template := FindCurrencyTemplate(u.DefaultCurrency)
|
security_template := FindCurrencyTemplate(u.DefaultCurrency)
|
||||||
if security_template == nil {
|
if security_template == nil {
|
||||||
return errors.New("Invalid ISO4217 Default Currency")
|
return errors.New("Invalid ISO4217 Default Currency")
|
||||||
}
|
}
|
||||||
|
|
||||||
existing, err := tx.SelectInt("SELECT count(*) from users where Username=?", u.Username)
|
exists, err := tx.UsernameExists(u.Username)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
if existing > 0 {
|
if exists {
|
||||||
return UserExistsError{}
|
return UserExistsError{}
|
||||||
}
|
}
|
||||||
|
|
||||||
err = tx.Insert(u)
|
err = tx.InsertUser(u)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@ -59,33 +38,31 @@ func InsertUser(tx *db.Tx, u *models.User) error {
|
|||||||
security = *security_template
|
security = *security_template
|
||||||
security.UserId = u.UserId
|
security.UserId = u.UserId
|
||||||
|
|
||||||
err = InsertSecurity(tx, &security)
|
err = tx.InsertSecurity(&security)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
// Update the user's DefaultCurrency to our new SecurityId
|
// Update the user's DefaultCurrency to our new SecurityId
|
||||||
u.DefaultCurrency = security.SecurityId
|
u.DefaultCurrency = security.SecurityId
|
||||||
count, err := tx.Update(u)
|
err = tx.UpdateUser(u)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
} else if count != 1 {
|
|
||||||
return errors.New("Would have updated more than one user")
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func GetUserFromSession(tx *db.Tx, r *http.Request) (*models.User, error) {
|
func GetUserFromSession(tx store.Tx, r *http.Request) (*models.User, error) {
|
||||||
s, err := GetSession(tx, r)
|
s, err := GetSession(tx, r)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
return GetUser(tx, s.UserId)
|
return tx.GetUser(s.UserId)
|
||||||
}
|
}
|
||||||
|
|
||||||
func UpdateUser(tx *db.Tx, u *models.User) error {
|
func UpdateUser(tx store.Tx, u *models.User) error {
|
||||||
security, err := GetSecurity(tx, u.DefaultCurrency, u.UserId)
|
security, err := tx.GetSecurity(u.DefaultCurrency, u.UserId)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
} else if security.UserId != u.UserId || security.SecurityId != u.DefaultCurrency {
|
} else if security.UserId != u.UserId || security.SecurityId != u.DefaultCurrency {
|
||||||
@ -94,49 +71,7 @@ func UpdateUser(tx *db.Tx, u *models.User) error {
|
|||||||
return errors.New("New DefaultCurrency security is not a currency")
|
return errors.New("New DefaultCurrency security is not a currency")
|
||||||
}
|
}
|
||||||
|
|
||||||
count, err := tx.Update(u)
|
err = tx.UpdateUser(u)
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
} else if count != 1 {
|
|
||||||
return errors.New("Would have updated more than one user")
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func DeleteUser(tx *db.Tx, u *models.User) error {
|
|
||||||
count, err := tx.Delete(u)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
if count != 1 {
|
|
||||||
return fmt.Errorf("No user to delete")
|
|
||||||
}
|
|
||||||
_, err = tx.Exec("DELETE FROM prices WHERE prices.SecurityId IN (SELECT securities.SecurityId FROM securities WHERE securities.UserId=?)", u.UserId)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
_, err = tx.Exec("DELETE FROM splits WHERE splits.TransactionId IN (SELECT transactions.TransactionId FROM transactions WHERE transactions.UserId=?)", u.UserId)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
_, err = tx.Exec("DELETE FROM transactions WHERE transactions.UserId=?", u.UserId)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
_, err = tx.Exec("DELETE FROM securities WHERE securities.UserId=?", u.UserId)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
_, err = tx.Exec("DELETE FROM accounts WHERE accounts.UserId=?", u.UserId)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
_, err = tx.Exec("DELETE FROM reports WHERE reports.UserId=?", u.UserId)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
_, err = tx.Exec("DELETE FROM sessions WHERE sessions.UserId=?", u.UserId)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@ -205,7 +140,7 @@ func UserHandler(r *http.Request, context *Context) ResponseWriterWriter {
|
|||||||
|
|
||||||
return user
|
return user
|
||||||
} else if r.Method == "DELETE" {
|
} else if r.Method == "DELETE" {
|
||||||
err := DeleteUser(context.Tx, user)
|
err := context.StoreTx.DeleteUser(user)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Print(err)
|
log.Print(err)
|
||||||
return NewError(999 /*Internal Error*/)
|
return NewError(999 /*Internal Error*/)
|
||||||
|
95
internal/store/db/securities.go
Normal file
95
internal/store/db/securities.go
Normal file
@ -0,0 +1,95 @@
|
|||||||
|
package db
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"github.com/aclindsa/moneygo/internal/models"
|
||||||
|
)
|
||||||
|
|
||||||
|
type SecurityInUseError struct {
|
||||||
|
message string
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e SecurityInUseError) Error() string {
|
||||||
|
return e.message
|
||||||
|
}
|
||||||
|
|
||||||
|
func (tx *Tx) GetSecurity(securityid int64, userid int64) (*models.Security, error) {
|
||||||
|
var s models.Security
|
||||||
|
|
||||||
|
err := tx.SelectOne(&s, "SELECT * from securities where UserId=? AND SecurityId=?", userid, securityid)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return &s, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (tx *Tx) GetSecurities(userid int64) (*[]*models.Security, error) {
|
||||||
|
var securities []*models.Security
|
||||||
|
|
||||||
|
_, err := tx.Select(&securities, "SELECT * from securities where UserId=?", userid)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return &securities, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (tx *Tx) FindMatchingSecurities(userid int64, security *models.Security) (*[]*models.Security, error) {
|
||||||
|
var securities []*models.Security
|
||||||
|
|
||||||
|
_, err := tx.Select(&securities, "SELECT * from securities where UserId=? AND Type=? AND AlternateId=? AND Preciseness=?", userid, security.Type, security.AlternateId, security.Precision)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return &securities, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (tx *Tx) InsertSecurity(s *models.Security) error {
|
||||||
|
err := tx.Insert(s)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (tx *Tx) UpdateSecurity(security *models.Security) error {
|
||||||
|
count, err := tx.Update(security)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if count != 1 {
|
||||||
|
return fmt.Errorf("Expected to update 1 security, was going to update %d", count)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (tx *Tx) DeleteSecurity(s *models.Security) error {
|
||||||
|
// First, ensure no accounts are using this security
|
||||||
|
accounts, err := tx.SelectInt("SELECT count(*) from accounts where UserId=? and SecurityId=?", s.UserId, s.SecurityId)
|
||||||
|
|
||||||
|
if accounts != 0 {
|
||||||
|
return SecurityInUseError{"One or more accounts still use this security"}
|
||||||
|
}
|
||||||
|
|
||||||
|
user, err := tx.GetUser(s.UserId)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
} else if user.DefaultCurrency == s.SecurityId {
|
||||||
|
return SecurityInUseError{"Cannot delete security which is user's default currency"}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Remove all prices involving this security (either of this security, or
|
||||||
|
// using it as a currency)
|
||||||
|
_, err = tx.Exec("DELETE FROM prices WHERE SecurityId=? OR CurrencyId=?", s.SecurityId, s.SecurityId)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
count, err := tx.Delete(s)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if count != 1 {
|
||||||
|
return fmt.Errorf("Expected to delete 1 security, was going to delete %d", count)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
@ -31,6 +31,12 @@ func (tx *Tx) SessionExists(secret string) (bool, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (tx *Tx) DeleteSession(session *models.Session) error {
|
func (tx *Tx) DeleteSession(session *models.Session) error {
|
||||||
_, err := tx.Delete(session)
|
count, err := tx.Delete(session)
|
||||||
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
if count != 1 {
|
||||||
|
return fmt.Errorf("Expected to delete 1 user, was going to delete %d", count)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
86
internal/store/db/users.go
Normal file
86
internal/store/db/users.go
Normal file
@ -0,0 +1,86 @@
|
|||||||
|
package db
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"github.com/aclindsa/moneygo/internal/models"
|
||||||
|
)
|
||||||
|
|
||||||
|
func (tx *Tx) UsernameExists(username string) (bool, error) {
|
||||||
|
existing, err := tx.SelectInt("SELECT count(*) from users where Username=?", username)
|
||||||
|
return existing != 0, err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (tx *Tx) InsertUser(user *models.User) error {
|
||||||
|
return tx.Insert(user)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (tx *Tx) GetUser(userid int64) (*models.User, error) {
|
||||||
|
var u models.User
|
||||||
|
|
||||||
|
err := tx.SelectOne(&u, "SELECT * from users where UserId=?", userid)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return &u, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (tx *Tx) GetUserByUsername(username string) (*models.User, error) {
|
||||||
|
var u models.User
|
||||||
|
|
||||||
|
err := tx.SelectOne(&u, "SELECT * from users where Username=?", username)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return &u, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (tx *Tx) UpdateUser(user *models.User) error {
|
||||||
|
count, err := tx.Update(user)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if count != 1 {
|
||||||
|
return fmt.Errorf("Expected to update 1 user, was going to update %d", count)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (tx *Tx) DeleteUser(user *models.User) error {
|
||||||
|
count, err := tx.Delete(user)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if count != 1 {
|
||||||
|
return fmt.Errorf("Expected to delete 1 user, was going to delete %d", count)
|
||||||
|
}
|
||||||
|
_, err = tx.Exec("DELETE FROM prices WHERE prices.SecurityId IN (SELECT securities.SecurityId FROM securities WHERE securities.UserId=?)", user.UserId)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
_, err = tx.Exec("DELETE FROM splits WHERE splits.TransactionId IN (SELECT transactions.TransactionId FROM transactions WHERE transactions.UserId=?)", user.UserId)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
_, err = tx.Exec("DELETE FROM transactions WHERE transactions.UserId=?", user.UserId)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
_, err = tx.Exec("DELETE FROM securities WHERE securities.UserId=?", user.UserId)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
_, err = tx.Exec("DELETE FROM accounts WHERE accounts.UserId=?", user.UserId)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
_, err = tx.Exec("DELETE FROM reports WHERE reports.UserId=?", user.UserId)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
_, err = tx.Exec("DELETE FROM sessions WHERE sessions.UserId=?", user.UserId)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
@ -5,17 +5,37 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
type SessionStore interface {
|
type SessionStore interface {
|
||||||
|
SessionExists(secret string) (bool, error)
|
||||||
InsertSession(session *models.Session) error
|
InsertSession(session *models.Session) error
|
||||||
GetSession(secret string) (*models.Session, error)
|
GetSession(secret string) (*models.Session, error)
|
||||||
SessionExists(secret string) (bool, error)
|
|
||||||
DeleteSession(session *models.Session) error
|
DeleteSession(session *models.Session) error
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type UserStore interface {
|
||||||
|
UsernameExists(username string) (bool, error)
|
||||||
|
InsertUser(user *models.User) error
|
||||||
|
GetUser(userid int64) (*models.User, error)
|
||||||
|
GetUserByUsername(username string) (*models.User, error)
|
||||||
|
UpdateUser(user *models.User) error
|
||||||
|
DeleteUser(user *models.User) error
|
||||||
|
}
|
||||||
|
|
||||||
|
type SecurityStore interface {
|
||||||
|
InsertSecurity(security *models.Security) error
|
||||||
|
GetSecurity(securityid int64, userid int64) (*models.Security, error)
|
||||||
|
GetSecurities(userid int64) (*[]*models.Security, error)
|
||||||
|
FindMatchingSecurities(userid int64, security *models.Security) (*[]*models.Security, error)
|
||||||
|
UpdateSecurity(security *models.Security) error
|
||||||
|
DeleteSecurity(security *models.Security) error
|
||||||
|
}
|
||||||
|
|
||||||
type Tx interface {
|
type Tx interface {
|
||||||
Commit() error
|
Commit() error
|
||||||
Rollback() error
|
Rollback() error
|
||||||
|
|
||||||
SessionStore
|
SessionStore
|
||||||
|
UserStore
|
||||||
|
SecurityStore
|
||||||
}
|
}
|
||||||
|
|
||||||
type Store interface {
|
type Store interface {
|
||||||
|
Loading…
Reference in New Issue
Block a user