mirror of
				https://github.com/aclindsa/moneygo.git
				synced 2025-11-03 18:13:27 -05:00 
			
		
		
		
	Move users and securities to store
This commit is contained in:
		@@ -62,7 +62,7 @@ func GetTradingAccount(tx *db.Tx, userid int64, securityid int64) (*models.Accou
 | 
				
			|||||||
	var tradingAccount models.Account
 | 
						var tradingAccount models.Account
 | 
				
			||||||
	var account models.Account
 | 
						var account models.Account
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	user, err := GetUser(tx, userid)
 | 
						user, err := tx.GetUser(userid)
 | 
				
			||||||
	if err != nil {
 | 
						if err != nil {
 | 
				
			||||||
		return nil, err
 | 
							return nil, err
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
@@ -79,7 +79,7 @@ func GetTradingAccount(tx *db.Tx, userid int64, securityid int64) (*models.Accou
 | 
				
			|||||||
		return nil, err
 | 
							return nil, err
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	security, err := GetSecurity(tx, securityid, userid)
 | 
						security, err := tx.GetSecurity(securityid, userid)
 | 
				
			||||||
	if err != nil {
 | 
						if err != nil {
 | 
				
			||||||
		return nil, err
 | 
							return nil, err
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
@@ -124,7 +124,7 @@ func GetImbalanceAccount(tx *db.Tx, userid int64, securityid int64) (*models.Acc
 | 
				
			|||||||
		return nil, err
 | 
							return nil, err
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	security, err := GetSecurity(tx, securityid, userid)
 | 
						security, err := tx.GetSecurity(securityid, userid)
 | 
				
			||||||
	if err != nil {
 | 
						if err != nil {
 | 
				
			||||||
		return nil, err
 | 
							return nil, err
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
@@ -280,7 +280,7 @@ func AccountHandler(r *http.Request, context *Context) ResponseWriterWriter {
 | 
				
			|||||||
		account.UserId = user.UserId
 | 
							account.UserId = user.UserId
 | 
				
			||||||
		account.AccountVersion = 0
 | 
							account.AccountVersion = 0
 | 
				
			||||||
 | 
					
 | 
				
			||||||
		security, err := GetSecurity(context.Tx, account.SecurityId, user.UserId)
 | 
							security, err := context.Tx.GetSecurity(account.SecurityId, user.UserId)
 | 
				
			||||||
		if err != nil {
 | 
							if err != nil {
 | 
				
			||||||
			log.Print(err)
 | 
								log.Print(err)
 | 
				
			||||||
			return NewError(999 /*Internal Error*/)
 | 
								return NewError(999 /*Internal Error*/)
 | 
				
			||||||
@@ -341,7 +341,7 @@ func AccountHandler(r *http.Request, context *Context) ResponseWriterWriter {
 | 
				
			|||||||
			}
 | 
								}
 | 
				
			||||||
			account.UserId = user.UserId
 | 
								account.UserId = user.UserId
 | 
				
			||||||
 | 
					
 | 
				
			||||||
			security, err := GetSecurity(context.Tx, account.SecurityId, user.UserId)
 | 
								security, err := context.Tx.GetSecurity(account.SecurityId, user.UserId)
 | 
				
			||||||
			if err != nil {
 | 
								if err != nil {
 | 
				
			||||||
				log.Print(err)
 | 
									log.Print(err)
 | 
				
			||||||
				return NewError(999 /*Internal Error*/)
 | 
									return NewError(999 /*Internal Error*/)
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -159,7 +159,7 @@ func ofxImportHelper(tx *db.Tx, r io.Reader, user *models.User, accountid int64)
 | 
				
			|||||||
				split := new(models.Split)
 | 
									split := new(models.Split)
 | 
				
			||||||
				r := new(big.Rat)
 | 
									r := new(big.Rat)
 | 
				
			||||||
				r.Neg(&imbalance)
 | 
									r.Neg(&imbalance)
 | 
				
			||||||
				security, err := GetSecurity(tx, imbalanced_security, user.UserId)
 | 
									security, err := tx.GetSecurity(imbalanced_security, user.UserId)
 | 
				
			||||||
				if err != nil {
 | 
									if err != nil {
 | 
				
			||||||
					log.Print(err)
 | 
										log.Print(err)
 | 
				
			||||||
					return NewError(999 /*Internal Error*/)
 | 
										return NewError(999 /*Internal Error*/)
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -97,7 +97,7 @@ func GetClosestPrice(tx *db.Tx, security, currency *models.Security, date *time.
 | 
				
			|||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func PriceHandler(r *http.Request, context *Context, user *models.User, securityid int64) ResponseWriterWriter {
 | 
					func PriceHandler(r *http.Request, context *Context, user *models.User, securityid int64) ResponseWriterWriter {
 | 
				
			||||||
	security, err := GetSecurity(context.Tx, securityid, user.UserId)
 | 
						security, err := context.Tx.GetSecurity(securityid, user.UserId)
 | 
				
			||||||
	if err != nil {
 | 
						if err != nil {
 | 
				
			||||||
		return NewError(3 /*Invalid Request*/)
 | 
							return NewError(3 /*Invalid Request*/)
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
@@ -112,7 +112,7 @@ func PriceHandler(r *http.Request, context *Context, user *models.User, security
 | 
				
			|||||||
		if price.SecurityId != security.SecurityId {
 | 
							if price.SecurityId != security.SecurityId {
 | 
				
			||||||
			return NewError(3 /*Invalid Request*/)
 | 
								return NewError(3 /*Invalid Request*/)
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
		_, err = GetSecurity(context.Tx, price.CurrencyId, user.UserId)
 | 
							_, err = context.Tx.GetSecurity(price.CurrencyId, user.UserId)
 | 
				
			||||||
		if err != nil {
 | 
							if err != nil {
 | 
				
			||||||
			return NewError(3 /*Invalid Request*/)
 | 
								return NewError(3 /*Invalid Request*/)
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
@@ -161,11 +161,11 @@ func PriceHandler(r *http.Request, context *Context, user *models.User, security
 | 
				
			|||||||
				return NewError(3 /*Invalid Request*/)
 | 
									return NewError(3 /*Invalid Request*/)
 | 
				
			||||||
			}
 | 
								}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
			_, err = GetSecurity(context.Tx, price.SecurityId, user.UserId)
 | 
								_, err = context.Tx.GetSecurity(price.SecurityId, user.UserId)
 | 
				
			||||||
			if err != nil {
 | 
								if err != nil {
 | 
				
			||||||
				return NewError(3 /*Invalid Request*/)
 | 
									return NewError(3 /*Invalid Request*/)
 | 
				
			||||||
			}
 | 
								}
 | 
				
			||||||
			_, err = GetSecurity(context.Tx, price.CurrencyId, user.UserId)
 | 
								_, err = context.Tx.GetSecurity(price.CurrencyId, user.UserId)
 | 
				
			||||||
			if err != nil {
 | 
								if err != nil {
 | 
				
			||||||
				return NewError(3 /*Invalid Request*/)
 | 
									return NewError(3 /*Invalid Request*/)
 | 
				
			||||||
			}
 | 
								}
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -4,7 +4,6 @@ package handlers
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
import (
 | 
					import (
 | 
				
			||||||
	"errors"
 | 
						"errors"
 | 
				
			||||||
	"fmt"
 | 
					 | 
				
			||||||
	"github.com/aclindsa/moneygo/internal/models"
 | 
						"github.com/aclindsa/moneygo/internal/models"
 | 
				
			||||||
	"github.com/aclindsa/moneygo/internal/store/db"
 | 
						"github.com/aclindsa/moneygo/internal/store/db"
 | 
				
			||||||
	"log"
 | 
						"log"
 | 
				
			||||||
@@ -51,90 +50,18 @@ func FindCurrencyTemplate(iso4217 int64) *models.Security {
 | 
				
			|||||||
	return nil
 | 
						return nil
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func GetSecurity(tx *db.Tx, securityid int64, userid int64) (*models.Security, error) {
 | 
					 | 
				
			||||||
	var s models.Security
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	err := tx.SelectOne(&s, "SELECT * from securities where UserId=? AND SecurityId=?", userid, securityid)
 | 
					 | 
				
			||||||
	if err != nil {
 | 
					 | 
				
			||||||
		return nil, err
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
	return &s, nil
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
func GetSecurities(tx *db.Tx, userid int64) (*[]*models.Security, error) {
 | 
					 | 
				
			||||||
	var securities []*models.Security
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	_, err := tx.Select(&securities, "SELECT * from securities where UserId=?", userid)
 | 
					 | 
				
			||||||
	if err != nil {
 | 
					 | 
				
			||||||
		return nil, err
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
	return &securities, nil
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
func InsertSecurity(tx *db.Tx, s *models.Security) error {
 | 
					 | 
				
			||||||
	err := tx.Insert(s)
 | 
					 | 
				
			||||||
	if err != nil {
 | 
					 | 
				
			||||||
		return err
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
	return nil
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
func UpdateSecurity(tx *db.Tx, s *models.Security) (err error) {
 | 
					func UpdateSecurity(tx *db.Tx, s *models.Security) (err error) {
 | 
				
			||||||
	user, err := GetUser(tx, s.UserId)
 | 
						user, err := tx.GetUser(s.UserId)
 | 
				
			||||||
	if err != nil {
 | 
						if err != nil {
 | 
				
			||||||
		return
 | 
							return
 | 
				
			||||||
	} else if user.DefaultCurrency == s.SecurityId && s.Type != models.Currency {
 | 
						} else if user.DefaultCurrency == s.SecurityId && s.Type != models.Currency {
 | 
				
			||||||
		return errors.New("Cannot change security which is user's default currency to be non-currency")
 | 
							return errors.New("Cannot change security which is user's default currency to be non-currency")
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	count, err := tx.Update(s)
 | 
						err = tx.UpdateSecurity(s)
 | 
				
			||||||
	if err != nil {
 | 
						if err != nil {
 | 
				
			||||||
		return
 | 
							return
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	if count > 1 {
 | 
					 | 
				
			||||||
		return fmt.Errorf("Updated %d securities (expected 1)", count)
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	return nil
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
type SecurityInUseError struct {
 | 
					 | 
				
			||||||
	message string
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
func (e SecurityInUseError) Error() string {
 | 
					 | 
				
			||||||
	return e.message
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
func DeleteSecurity(tx *db.Tx, s *models.Security) error {
 | 
					 | 
				
			||||||
	// First, ensure no accounts are using this security
 | 
					 | 
				
			||||||
	accounts, err := tx.SelectInt("SELECT count(*) from accounts where UserId=? and SecurityId=?", s.UserId, s.SecurityId)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	if accounts != 0 {
 | 
					 | 
				
			||||||
		return SecurityInUseError{"One or more accounts still use this security"}
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	user, err := GetUser(tx, s.UserId)
 | 
					 | 
				
			||||||
	if err != nil {
 | 
					 | 
				
			||||||
		return err
 | 
					 | 
				
			||||||
	} else if user.DefaultCurrency == s.SecurityId {
 | 
					 | 
				
			||||||
		return SecurityInUseError{"Cannot delete security which is user's default currency"}
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	// Remove all prices involving this security (either of this security, or
 | 
					 | 
				
			||||||
	// using it as a currency)
 | 
					 | 
				
			||||||
	_, err = tx.Exec("DELETE FROM prices WHERE SecurityId=? OR CurrencyId=?", s.SecurityId, s.SecurityId)
 | 
					 | 
				
			||||||
	if err != nil {
 | 
					 | 
				
			||||||
		return err
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	count, err := tx.Delete(s)
 | 
					 | 
				
			||||||
	if err != nil {
 | 
					 | 
				
			||||||
		return err
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
	if count != 1 {
 | 
					 | 
				
			||||||
		return errors.New("Deleted more than one security")
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
	return nil
 | 
						return nil
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
@@ -143,16 +70,14 @@ func ImportGetCreateSecurity(tx *db.Tx, userid int64, security *models.Security)
 | 
				
			|||||||
	security.UserId = userid
 | 
						security.UserId = userid
 | 
				
			||||||
	if len(security.AlternateId) == 0 {
 | 
						if len(security.AlternateId) == 0 {
 | 
				
			||||||
		// Always create a new local security if we can't match on the AlternateId
 | 
							// Always create a new local security if we can't match on the AlternateId
 | 
				
			||||||
		err := InsertSecurity(tx, security)
 | 
							err := tx.InsertSecurity(security)
 | 
				
			||||||
		if err != nil {
 | 
							if err != nil {
 | 
				
			||||||
			return nil, err
 | 
								return nil, err
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
		return security, nil
 | 
							return security, nil
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	var securities []*models.Security
 | 
						securities, err := tx.FindMatchingSecurities(userid, security)
 | 
				
			||||||
 | 
					 | 
				
			||||||
	_, err := tx.Select(&securities, "SELECT * from securities where UserId=? AND Type=? AND AlternateId=? AND Preciseness=?", userid, security.Type, security.AlternateId, security.Precision)
 | 
					 | 
				
			||||||
	if err != nil {
 | 
						if err != nil {
 | 
				
			||||||
		return nil, err
 | 
							return nil, err
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
@@ -160,7 +85,7 @@ func ImportGetCreateSecurity(tx *db.Tx, userid int64, security *models.Security)
 | 
				
			|||||||
	// First try to find a case insensitive match on the name or symbol
 | 
						// First try to find a case insensitive match on the name or symbol
 | 
				
			||||||
	upperName := strings.ToUpper(security.Name)
 | 
						upperName := strings.ToUpper(security.Name)
 | 
				
			||||||
	upperSymbol := strings.ToUpper(security.Symbol)
 | 
						upperSymbol := strings.ToUpper(security.Symbol)
 | 
				
			||||||
	for _, s := range securities {
 | 
						for _, s := range *securities {
 | 
				
			||||||
		if (len(s.Name) > 0 && strings.ToUpper(s.Name) == upperName) ||
 | 
							if (len(s.Name) > 0 && strings.ToUpper(s.Name) == upperName) ||
 | 
				
			||||||
			(len(s.Symbol) > 0 && strings.ToUpper(s.Symbol) == upperSymbol) {
 | 
								(len(s.Symbol) > 0 && strings.ToUpper(s.Symbol) == upperSymbol) {
 | 
				
			||||||
			return s, nil
 | 
								return s, nil
 | 
				
			||||||
@@ -169,7 +94,7 @@ func ImportGetCreateSecurity(tx *db.Tx, userid int64, security *models.Security)
 | 
				
			|||||||
	//		if strings.Contains(strings.ToUpper(security.Name), upperSearch) ||
 | 
						//		if strings.Contains(strings.ToUpper(security.Name), upperSearch) ||
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	// Try to find a partial string match on the name or symbol
 | 
						// Try to find a partial string match on the name or symbol
 | 
				
			||||||
	for _, s := range securities {
 | 
						for _, s := range *securities {
 | 
				
			||||||
		sUpperName := strings.ToUpper(s.Name)
 | 
							sUpperName := strings.ToUpper(s.Name)
 | 
				
			||||||
		sUpperSymbol := strings.ToUpper(s.Symbol)
 | 
							sUpperSymbol := strings.ToUpper(s.Symbol)
 | 
				
			||||||
		if (len(upperName) > 0 && len(s.Name) > 0 && (strings.Contains(upperName, sUpperName) || strings.Contains(sUpperName, upperName))) ||
 | 
							if (len(upperName) > 0 && len(s.Name) > 0 && (strings.Contains(upperName, sUpperName) || strings.Contains(sUpperName, upperName))) ||
 | 
				
			||||||
@@ -179,12 +104,12 @@ func ImportGetCreateSecurity(tx *db.Tx, userid int64, security *models.Security)
 | 
				
			|||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	// Give up and return the first security in the list
 | 
						// Give up and return the first security in the list
 | 
				
			||||||
	if len(securities) > 0 {
 | 
						if len(*securities) > 0 {
 | 
				
			||||||
		return securities[0], nil
 | 
							return (*securities)[0], nil
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	// If there wasn't even one security in the list, make a new one
 | 
						// If there wasn't even one security in the list, make a new one
 | 
				
			||||||
	err = InsertSecurity(tx, security)
 | 
						err = tx.InsertSecurity(security)
 | 
				
			||||||
	if err != nil {
 | 
						if err != nil {
 | 
				
			||||||
		return nil, err
 | 
							return nil, err
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
@@ -217,7 +142,7 @@ func SecurityHandler(r *http.Request, context *Context) ResponseWriterWriter {
 | 
				
			|||||||
		security.SecurityId = -1
 | 
							security.SecurityId = -1
 | 
				
			||||||
		security.UserId = user.UserId
 | 
							security.UserId = user.UserId
 | 
				
			||||||
 | 
					
 | 
				
			||||||
		err = InsertSecurity(context.Tx, &security)
 | 
							err = context.Tx.InsertSecurity(&security)
 | 
				
			||||||
		if err != nil {
 | 
							if err != nil {
 | 
				
			||||||
			log.Print(err)
 | 
								log.Print(err)
 | 
				
			||||||
			return NewError(999 /*Internal Error*/)
 | 
								return NewError(999 /*Internal Error*/)
 | 
				
			||||||
@@ -229,7 +154,7 @@ func SecurityHandler(r *http.Request, context *Context) ResponseWriterWriter {
 | 
				
			|||||||
			//Return all securities
 | 
								//Return all securities
 | 
				
			||||||
			var sl models.SecurityList
 | 
								var sl models.SecurityList
 | 
				
			||||||
 | 
					
 | 
				
			||||||
			securities, err := GetSecurities(context.Tx, user.UserId)
 | 
								securities, err := context.Tx.GetSecurities(user.UserId)
 | 
				
			||||||
			if err != nil {
 | 
								if err != nil {
 | 
				
			||||||
				log.Print(err)
 | 
									log.Print(err)
 | 
				
			||||||
				return NewError(999 /*Internal Error*/)
 | 
									return NewError(999 /*Internal Error*/)
 | 
				
			||||||
@@ -250,7 +175,7 @@ func SecurityHandler(r *http.Request, context *Context) ResponseWriterWriter {
 | 
				
			|||||||
				return PriceHandler(r, context, user, securityid)
 | 
									return PriceHandler(r, context, user, securityid)
 | 
				
			||||||
			}
 | 
								}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
			security, err := GetSecurity(context.Tx, securityid, user.UserId)
 | 
								security, err := context.Tx.GetSecurity(securityid, user.UserId)
 | 
				
			||||||
			if err != nil {
 | 
								if err != nil {
 | 
				
			||||||
				return NewError(3 /*Invalid Request*/)
 | 
									return NewError(3 /*Invalid Request*/)
 | 
				
			||||||
			}
 | 
								}
 | 
				
			||||||
@@ -284,13 +209,13 @@ func SecurityHandler(r *http.Request, context *Context) ResponseWriterWriter {
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
			return &security
 | 
								return &security
 | 
				
			||||||
		} else if r.Method == "DELETE" {
 | 
							} else if r.Method == "DELETE" {
 | 
				
			||||||
			security, err := GetSecurity(context.Tx, securityid, user.UserId)
 | 
								security, err := context.Tx.GetSecurity(securityid, user.UserId)
 | 
				
			||||||
			if err != nil {
 | 
								if err != nil {
 | 
				
			||||||
				return NewError(3 /*Invalid Request*/)
 | 
									return NewError(3 /*Invalid Request*/)
 | 
				
			||||||
			}
 | 
								}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
			err = DeleteSecurity(context.Tx, security)
 | 
								err = context.Tx.DeleteSecurity(security)
 | 
				
			||||||
			if _, ok := err.(SecurityInUseError); ok {
 | 
								if _, ok := err.(db.SecurityInUseError); ok {
 | 
				
			||||||
				return NewError(7 /*In Use Error*/)
 | 
									return NewError(7 /*In Use Error*/)
 | 
				
			||||||
			} else if err != nil {
 | 
								} else if err != nil {
 | 
				
			||||||
				log.Print(err)
 | 
									log.Print(err)
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -27,7 +27,7 @@ func luaContextGetSecurities(L *lua.LState) (map[int64]*models.Security, error)
 | 
				
			|||||||
			return nil, errors.New("Couldn't find User in lua's Context")
 | 
								return nil, errors.New("Couldn't find User in lua's Context")
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
		securities, err := GetSecurities(tx, user.UserId)
 | 
							securities, err := tx.GetSecurities(user.UserId)
 | 
				
			||||||
		if err != nil {
 | 
							if err != nil {
 | 
				
			||||||
			return nil, err
 | 
								return nil, err
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -85,12 +85,15 @@ func SessionHandler(r *http.Request, context *Context) ResponseWriterWriter {
 | 
				
			|||||||
			return NewError(3 /*Invalid Request*/)
 | 
								return NewError(3 /*Invalid Request*/)
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
		dbuser, err := GetUserByUsername(context.Tx, user.Username)
 | 
							// Hash password before checking username to help mitigate timing
 | 
				
			||||||
 | 
							// attacks
 | 
				
			||||||
 | 
							user.HashPassword()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
							dbuser, err := context.StoreTx.GetUserByUsername(user.Username)
 | 
				
			||||||
		if err != nil {
 | 
							if err != nil {
 | 
				
			||||||
			return NewError(2 /*Unauthorized Access*/)
 | 
								return NewError(2 /*Unauthorized Access*/)
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
		user.HashPassword()
 | 
					 | 
				
			||||||
		if user.PasswordHash != dbuser.PasswordHash {
 | 
							if user.PasswordHash != dbuser.PasswordHash {
 | 
				
			||||||
			return NewError(2 /*Unauthorized Access*/)
 | 
								return NewError(2 /*Unauthorized Access*/)
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -542,7 +542,7 @@ func GetAccountTransactions(tx *db.Tx, user *models.User, accountid int64, sort
 | 
				
			|||||||
	}
 | 
						}
 | 
				
			||||||
	atl.TotalTransactions = count
 | 
						atl.TotalTransactions = count
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	security, err := GetSecurity(tx, atl.Account.SecurityId, user.UserId)
 | 
						security, err := tx.GetSecurity(atl.Account.SecurityId, user.UserId)
 | 
				
			||||||
	if err != nil {
 | 
						if err != nil {
 | 
				
			||||||
		return nil, err
 | 
							return nil, err
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -2,9 +2,8 @@ package handlers
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
import (
 | 
					import (
 | 
				
			||||||
	"errors"
 | 
						"errors"
 | 
				
			||||||
	"fmt"
 | 
					 | 
				
			||||||
	"github.com/aclindsa/moneygo/internal/models"
 | 
						"github.com/aclindsa/moneygo/internal/models"
 | 
				
			||||||
	"github.com/aclindsa/moneygo/internal/store/db"
 | 
						"github.com/aclindsa/moneygo/internal/store"
 | 
				
			||||||
	"log"
 | 
						"log"
 | 
				
			||||||
	"net/http"
 | 
						"net/http"
 | 
				
			||||||
)
 | 
					)
 | 
				
			||||||
@@ -15,41 +14,21 @@ func (ueu UserExistsError) Error() string {
 | 
				
			|||||||
	return "User exists"
 | 
						return "User exists"
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func GetUser(tx *db.Tx, userid int64) (*models.User, error) {
 | 
					func InsertUser(tx store.Tx, u *models.User) error {
 | 
				
			||||||
	var u models.User
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	err := tx.SelectOne(&u, "SELECT * from users where UserId=?", userid)
 | 
					 | 
				
			||||||
	if err != nil {
 | 
					 | 
				
			||||||
		return nil, err
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
	return &u, nil
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
func GetUserByUsername(tx *db.Tx, username string) (*models.User, error) {
 | 
					 | 
				
			||||||
	var u models.User
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	err := tx.SelectOne(&u, "SELECT * from users where Username=?", username)
 | 
					 | 
				
			||||||
	if err != nil {
 | 
					 | 
				
			||||||
		return nil, err
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
	return &u, nil
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
func InsertUser(tx *db.Tx, u *models.User) error {
 | 
					 | 
				
			||||||
	security_template := FindCurrencyTemplate(u.DefaultCurrency)
 | 
						security_template := FindCurrencyTemplate(u.DefaultCurrency)
 | 
				
			||||||
	if security_template == nil {
 | 
						if security_template == nil {
 | 
				
			||||||
		return errors.New("Invalid ISO4217 Default Currency")
 | 
							return errors.New("Invalid ISO4217 Default Currency")
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	existing, err := tx.SelectInt("SELECT count(*) from users where Username=?", u.Username)
 | 
						exists, err := tx.UsernameExists(u.Username)
 | 
				
			||||||
	if err != nil {
 | 
						if err != nil {
 | 
				
			||||||
		return err
 | 
							return err
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	if existing > 0 {
 | 
						if exists {
 | 
				
			||||||
		return UserExistsError{}
 | 
							return UserExistsError{}
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	err = tx.Insert(u)
 | 
						err = tx.InsertUser(u)
 | 
				
			||||||
	if err != nil {
 | 
						if err != nil {
 | 
				
			||||||
		return err
 | 
							return err
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
@@ -59,33 +38,31 @@ func InsertUser(tx *db.Tx, u *models.User) error {
 | 
				
			|||||||
	security = *security_template
 | 
						security = *security_template
 | 
				
			||||||
	security.UserId = u.UserId
 | 
						security.UserId = u.UserId
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	err = InsertSecurity(tx, &security)
 | 
						err = tx.InsertSecurity(&security)
 | 
				
			||||||
	if err != nil {
 | 
						if err != nil {
 | 
				
			||||||
		return err
 | 
							return err
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	// Update the user's DefaultCurrency to our new SecurityId
 | 
						// Update the user's DefaultCurrency to our new SecurityId
 | 
				
			||||||
	u.DefaultCurrency = security.SecurityId
 | 
						u.DefaultCurrency = security.SecurityId
 | 
				
			||||||
	count, err := tx.Update(u)
 | 
						err = tx.UpdateUser(u)
 | 
				
			||||||
	if err != nil {
 | 
						if err != nil {
 | 
				
			||||||
		return err
 | 
							return err
 | 
				
			||||||
	} else if count != 1 {
 | 
					 | 
				
			||||||
		return errors.New("Would have updated more than one user")
 | 
					 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	return nil
 | 
						return nil
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func GetUserFromSession(tx *db.Tx, r *http.Request) (*models.User, error) {
 | 
					func GetUserFromSession(tx store.Tx, r *http.Request) (*models.User, error) {
 | 
				
			||||||
	s, err := GetSession(tx, r)
 | 
						s, err := GetSession(tx, r)
 | 
				
			||||||
	if err != nil {
 | 
						if err != nil {
 | 
				
			||||||
		return nil, err
 | 
							return nil, err
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	return GetUser(tx, s.UserId)
 | 
						return tx.GetUser(s.UserId)
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func UpdateUser(tx *db.Tx, u *models.User) error {
 | 
					func UpdateUser(tx store.Tx, u *models.User) error {
 | 
				
			||||||
	security, err := GetSecurity(tx, u.DefaultCurrency, u.UserId)
 | 
						security, err := tx.GetSecurity(u.DefaultCurrency, u.UserId)
 | 
				
			||||||
	if err != nil {
 | 
						if err != nil {
 | 
				
			||||||
		return err
 | 
							return err
 | 
				
			||||||
	} else if security.UserId != u.UserId || security.SecurityId != u.DefaultCurrency {
 | 
						} else if security.UserId != u.UserId || security.SecurityId != u.DefaultCurrency {
 | 
				
			||||||
@@ -94,49 +71,7 @@ func UpdateUser(tx *db.Tx, u *models.User) error {
 | 
				
			|||||||
		return errors.New("New DefaultCurrency security is not a currency")
 | 
							return errors.New("New DefaultCurrency security is not a currency")
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	count, err := tx.Update(u)
 | 
						err = tx.UpdateUser(u)
 | 
				
			||||||
	if err != nil {
 | 
					 | 
				
			||||||
		return err
 | 
					 | 
				
			||||||
	} else if count != 1 {
 | 
					 | 
				
			||||||
		return errors.New("Would have updated more than one user")
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	return nil
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
func DeleteUser(tx *db.Tx, u *models.User) error {
 | 
					 | 
				
			||||||
	count, err := tx.Delete(u)
 | 
					 | 
				
			||||||
	if err != nil {
 | 
					 | 
				
			||||||
		return err
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
	if count != 1 {
 | 
					 | 
				
			||||||
		return fmt.Errorf("No user to delete")
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
	_, err = tx.Exec("DELETE FROM prices WHERE prices.SecurityId IN (SELECT securities.SecurityId FROM securities WHERE securities.UserId=?)", u.UserId)
 | 
					 | 
				
			||||||
	if err != nil {
 | 
					 | 
				
			||||||
		return err
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
	_, err = tx.Exec("DELETE FROM splits WHERE splits.TransactionId IN (SELECT transactions.TransactionId FROM transactions WHERE transactions.UserId=?)", u.UserId)
 | 
					 | 
				
			||||||
	if err != nil {
 | 
					 | 
				
			||||||
		return err
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
	_, err = tx.Exec("DELETE FROM transactions WHERE transactions.UserId=?", u.UserId)
 | 
					 | 
				
			||||||
	if err != nil {
 | 
					 | 
				
			||||||
		return err
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
	_, err = tx.Exec("DELETE FROM securities WHERE securities.UserId=?", u.UserId)
 | 
					 | 
				
			||||||
	if err != nil {
 | 
					 | 
				
			||||||
		return err
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
	_, err = tx.Exec("DELETE FROM accounts WHERE accounts.UserId=?", u.UserId)
 | 
					 | 
				
			||||||
	if err != nil {
 | 
					 | 
				
			||||||
		return err
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
	_, err = tx.Exec("DELETE FROM reports WHERE reports.UserId=?", u.UserId)
 | 
					 | 
				
			||||||
	if err != nil {
 | 
					 | 
				
			||||||
		return err
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
	_, err = tx.Exec("DELETE FROM sessions WHERE sessions.UserId=?", u.UserId)
 | 
					 | 
				
			||||||
	if err != nil {
 | 
						if err != nil {
 | 
				
			||||||
		return err
 | 
							return err
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
@@ -205,7 +140,7 @@ func UserHandler(r *http.Request, context *Context) ResponseWriterWriter {
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
			return user
 | 
								return user
 | 
				
			||||||
		} else if r.Method == "DELETE" {
 | 
							} else if r.Method == "DELETE" {
 | 
				
			||||||
			err := DeleteUser(context.Tx, user)
 | 
								err := context.StoreTx.DeleteUser(user)
 | 
				
			||||||
			if err != nil {
 | 
								if err != nil {
 | 
				
			||||||
				log.Print(err)
 | 
									log.Print(err)
 | 
				
			||||||
				return NewError(999 /*Internal Error*/)
 | 
									return NewError(999 /*Internal Error*/)
 | 
				
			||||||
 
 | 
				
			|||||||
							
								
								
									
										95
									
								
								internal/store/db/securities.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										95
									
								
								internal/store/db/securities.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,95 @@
 | 
				
			|||||||
 | 
					package db
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import (
 | 
				
			||||||
 | 
						"fmt"
 | 
				
			||||||
 | 
						"github.com/aclindsa/moneygo/internal/models"
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					type SecurityInUseError struct {
 | 
				
			||||||
 | 
						message string
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (e SecurityInUseError) Error() string {
 | 
				
			||||||
 | 
						return e.message
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (tx *Tx) GetSecurity(securityid int64, userid int64) (*models.Security, error) {
 | 
				
			||||||
 | 
						var s models.Security
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						err := tx.SelectOne(&s, "SELECT * from securities where UserId=? AND SecurityId=?", userid, securityid)
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							return nil, err
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						return &s, nil
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (tx *Tx) GetSecurities(userid int64) (*[]*models.Security, error) {
 | 
				
			||||||
 | 
						var securities []*models.Security
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						_, err := tx.Select(&securities, "SELECT * from securities where UserId=?", userid)
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							return nil, err
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						return &securities, nil
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (tx *Tx) FindMatchingSecurities(userid int64, security *models.Security) (*[]*models.Security, error) {
 | 
				
			||||||
 | 
						var securities []*models.Security
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						_, err := tx.Select(&securities, "SELECT * from securities where UserId=? AND Type=? AND AlternateId=? AND Preciseness=?", userid, security.Type, security.AlternateId, security.Precision)
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							return nil, err
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						return &securities, nil
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (tx *Tx) InsertSecurity(s *models.Security) error {
 | 
				
			||||||
 | 
						err := tx.Insert(s)
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							return err
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						return nil
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (tx *Tx) UpdateSecurity(security *models.Security) error {
 | 
				
			||||||
 | 
						count, err := tx.Update(security)
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							return err
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						if count != 1 {
 | 
				
			||||||
 | 
							return fmt.Errorf("Expected to update 1 security, was going to update %d", count)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						return nil
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (tx *Tx) DeleteSecurity(s *models.Security) error {
 | 
				
			||||||
 | 
						// First, ensure no accounts are using this security
 | 
				
			||||||
 | 
						accounts, err := tx.SelectInt("SELECT count(*) from accounts where UserId=? and SecurityId=?", s.UserId, s.SecurityId)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						if accounts != 0 {
 | 
				
			||||||
 | 
							return SecurityInUseError{"One or more accounts still use this security"}
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						user, err := tx.GetUser(s.UserId)
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							return err
 | 
				
			||||||
 | 
						} else if user.DefaultCurrency == s.SecurityId {
 | 
				
			||||||
 | 
							return SecurityInUseError{"Cannot delete security which is user's default currency"}
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// Remove all prices involving this security (either of this security, or
 | 
				
			||||||
 | 
						// using it as a currency)
 | 
				
			||||||
 | 
						_, err = tx.Exec("DELETE FROM prices WHERE SecurityId=? OR CurrencyId=?", s.SecurityId, s.SecurityId)
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							return err
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						count, err := tx.Delete(s)
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							return err
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						if count != 1 {
 | 
				
			||||||
 | 
							return fmt.Errorf("Expected to delete 1 security, was going to delete %d", count)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						return nil
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
@@ -31,6 +31,12 @@ func (tx *Tx) SessionExists(secret string) (bool, error) {
 | 
				
			|||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func (tx *Tx) DeleteSession(session *models.Session) error {
 | 
					func (tx *Tx) DeleteSession(session *models.Session) error {
 | 
				
			||||||
	_, err := tx.Delete(session)
 | 
						count, err := tx.Delete(session)
 | 
				
			||||||
	return err
 | 
						if err != nil {
 | 
				
			||||||
 | 
							return err
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						if count != 1 {
 | 
				
			||||||
 | 
							return fmt.Errorf("Expected to delete 1 user, was going to delete %d", count)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						return nil
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 
 | 
				
			|||||||
							
								
								
									
										86
									
								
								internal/store/db/users.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										86
									
								
								internal/store/db/users.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,86 @@
 | 
				
			|||||||
 | 
					package db
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import (
 | 
				
			||||||
 | 
						"fmt"
 | 
				
			||||||
 | 
						"github.com/aclindsa/moneygo/internal/models"
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (tx *Tx) UsernameExists(username string) (bool, error) {
 | 
				
			||||||
 | 
						existing, err := tx.SelectInt("SELECT count(*) from users where Username=?", username)
 | 
				
			||||||
 | 
						return existing != 0, err
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (tx *Tx) InsertUser(user *models.User) error {
 | 
				
			||||||
 | 
						return tx.Insert(user)
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (tx *Tx) GetUser(userid int64) (*models.User, error) {
 | 
				
			||||||
 | 
						var u models.User
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						err := tx.SelectOne(&u, "SELECT * from users where UserId=?", userid)
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							return nil, err
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						return &u, nil
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (tx *Tx) GetUserByUsername(username string) (*models.User, error) {
 | 
				
			||||||
 | 
						var u models.User
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						err := tx.SelectOne(&u, "SELECT * from users where Username=?", username)
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							return nil, err
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						return &u, nil
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (tx *Tx) UpdateUser(user *models.User) error {
 | 
				
			||||||
 | 
						count, err := tx.Update(user)
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							return err
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						if count != 1 {
 | 
				
			||||||
 | 
							return fmt.Errorf("Expected to update 1 user, was going to update %d", count)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						return nil
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (tx *Tx) DeleteUser(user *models.User) error {
 | 
				
			||||||
 | 
						count, err := tx.Delete(user)
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							return err
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						if count != 1 {
 | 
				
			||||||
 | 
							return fmt.Errorf("Expected to delete 1 user, was going to delete %d", count)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						_, err = tx.Exec("DELETE FROM prices WHERE prices.SecurityId IN (SELECT securities.SecurityId FROM securities WHERE securities.UserId=?)", user.UserId)
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							return err
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						_, err = tx.Exec("DELETE FROM splits WHERE splits.TransactionId IN (SELECT transactions.TransactionId FROM transactions WHERE transactions.UserId=?)", user.UserId)
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							return err
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						_, err = tx.Exec("DELETE FROM transactions WHERE transactions.UserId=?", user.UserId)
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							return err
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						_, err = tx.Exec("DELETE FROM securities WHERE securities.UserId=?", user.UserId)
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							return err
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						_, err = tx.Exec("DELETE FROM accounts WHERE accounts.UserId=?", user.UserId)
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							return err
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						_, err = tx.Exec("DELETE FROM reports WHERE reports.UserId=?", user.UserId)
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							return err
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						_, err = tx.Exec("DELETE FROM sessions WHERE sessions.UserId=?", user.UserId)
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							return err
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						return nil
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
@@ -5,17 +5,37 @@ import (
 | 
				
			|||||||
)
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
type SessionStore interface {
 | 
					type SessionStore interface {
 | 
				
			||||||
 | 
						SessionExists(secret string) (bool, error)
 | 
				
			||||||
	InsertSession(session *models.Session) error
 | 
						InsertSession(session *models.Session) error
 | 
				
			||||||
	GetSession(secret string) (*models.Session, error)
 | 
						GetSession(secret string) (*models.Session, error)
 | 
				
			||||||
	SessionExists(secret string) (bool, error)
 | 
					 | 
				
			||||||
	DeleteSession(session *models.Session) error
 | 
						DeleteSession(session *models.Session) error
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					type UserStore interface {
 | 
				
			||||||
 | 
						UsernameExists(username string) (bool, error)
 | 
				
			||||||
 | 
						InsertUser(user *models.User) error
 | 
				
			||||||
 | 
						GetUser(userid int64) (*models.User, error)
 | 
				
			||||||
 | 
						GetUserByUsername(username string) (*models.User, error)
 | 
				
			||||||
 | 
						UpdateUser(user *models.User) error
 | 
				
			||||||
 | 
						DeleteUser(user *models.User) error
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					type SecurityStore interface {
 | 
				
			||||||
 | 
						InsertSecurity(security *models.Security) error
 | 
				
			||||||
 | 
						GetSecurity(securityid int64, userid int64) (*models.Security, error)
 | 
				
			||||||
 | 
						GetSecurities(userid int64) (*[]*models.Security, error)
 | 
				
			||||||
 | 
						FindMatchingSecurities(userid int64, security *models.Security) (*[]*models.Security, error)
 | 
				
			||||||
 | 
						UpdateSecurity(security *models.Security) error
 | 
				
			||||||
 | 
						DeleteSecurity(security *models.Security) error
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
type Tx interface {
 | 
					type Tx interface {
 | 
				
			||||||
	Commit() error
 | 
						Commit() error
 | 
				
			||||||
	Rollback() error
 | 
						Rollback() error
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	SessionStore
 | 
						SessionStore
 | 
				
			||||||
 | 
						UserStore
 | 
				
			||||||
 | 
						SecurityStore
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
type Store interface {
 | 
					type Store interface {
 | 
				
			||||||
 
 | 
				
			|||||||
		Reference in New Issue
	
	Block a user