1
0
mirror of https://github.com/aclindsa/moneygo.git synced 2024-12-26 23:42:29 -05:00

Remove duplicate *Tx versions of database access methods

Simplify naming to remove "Tx" now that all handlers only have access to
transactions anyway, and always use "tx" as the name of the variable
representing the SQL transactions (to make it less likely to cause
confusion with monetary transactions).
This commit is contained in:
Aaron Lindsay 2017-10-14 19:41:13 -04:00
parent 4e53a5e59c
commit 2ff1f47432
8 changed files with 67 additions and 147 deletions

View File

@ -3,7 +3,6 @@ package handlers
import ( import (
"encoding/json" "encoding/json"
"errors" "errors"
"gopkg.in/gorp.v1"
"log" "log"
"net/http" "net/http"
"regexp" "regexp"
@ -139,17 +138,6 @@ func GetAccount(tx *Tx, accountid int64, userid int64) (*Account, error) {
return &a, nil return &a, nil
} }
func GetAccountTx(transaction *gorp.Transaction, accountid int64, userid int64) (*Account, error) {
var a Account
err := transaction.SelectOne(&a, "SELECT * from accounts where UserId=? AND AccountId=?", userid, accountid)
if err != nil {
return nil, err
}
return &a, nil
}
func GetAccounts(tx *Tx, userid int64) (*[]Account, error) { func GetAccounts(tx *Tx, userid int64) (*[]Account, error) {
var accounts []Account var accounts []Account
@ -162,12 +150,12 @@ func GetAccounts(tx *Tx, userid int64) (*[]Account, error) {
// Get (and attempt to create if it doesn't exist). Matches on UserId, // Get (and attempt to create if it doesn't exist). Matches on UserId,
// SecurityId, Type, Name, and ParentAccountId // SecurityId, Type, Name, and ParentAccountId
func GetCreateAccountTx(transaction *gorp.Transaction, a Account) (*Account, error) { func GetCreateAccount(tx *Tx, a Account) (*Account, error) {
var accounts []Account var accounts []Account
var account Account var account Account
// Try to find the top-level trading account // Try to find the top-level trading account
_, err := transaction.Select(&accounts, "SELECT * from accounts where UserId=? AND SecurityId=? AND Type=? AND Name=? AND ParentAccountId=? ORDER BY AccountId ASC LIMIT 1", a.UserId, a.SecurityId, a.Type, a.Name, a.ParentAccountId) _, err := tx.Select(&accounts, "SELECT * from accounts where UserId=? AND SecurityId=? AND Type=? AND Name=? AND ParentAccountId=? ORDER BY AccountId ASC LIMIT 1", a.UserId, a.SecurityId, a.Type, a.Name, a.ParentAccountId)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -180,7 +168,7 @@ func GetCreateAccountTx(transaction *gorp.Transaction, a Account) (*Account, err
account.Name = a.Name account.Name = a.Name
account.ParentAccountId = a.ParentAccountId account.ParentAccountId = a.ParentAccountId
err = transaction.Insert(&account) err = tx.Insert(&account)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -190,11 +178,11 @@ func GetCreateAccountTx(transaction *gorp.Transaction, a Account) (*Account, err
// Get (and attempt to create if it doesn't exist) the security/currency // Get (and attempt to create if it doesn't exist) the security/currency
// trading account for the supplied security/currency // trading account for the supplied security/currency
func GetTradingAccount(transaction *gorp.Transaction, userid int64, securityid int64) (*Account, error) { func GetTradingAccount(tx *Tx, userid int64, securityid int64) (*Account, error) {
var tradingAccount Account var tradingAccount Account
var account Account var account Account
user, err := GetUserTx(transaction, userid) user, err := GetUser(tx, userid)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -206,12 +194,12 @@ func GetTradingAccount(transaction *gorp.Transaction, userid int64, securityid i
tradingAccount.ParentAccountId = -1 tradingAccount.ParentAccountId = -1
// Find/create the top-level trading account // Find/create the top-level trading account
ta, err := GetCreateAccountTx(transaction, tradingAccount) ta, err := GetCreateAccount(tx, tradingAccount)
if err != nil { if err != nil {
return nil, err return nil, err
} }
security, err := GetSecurityTx(transaction, securityid, userid) security, err := GetSecurity(tx, securityid, userid)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -222,7 +210,7 @@ func GetTradingAccount(transaction *gorp.Transaction, userid int64, securityid i
account.SecurityId = securityid account.SecurityId = securityid
account.Type = Trading account.Type = Trading
a, err := GetCreateAccountTx(transaction, account) a, err := GetCreateAccount(tx, account)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -232,14 +220,14 @@ func GetTradingAccount(transaction *gorp.Transaction, userid int64, securityid i
// Get (and attempt to create if it doesn't exist) the security/currency // Get (and attempt to create if it doesn't exist) the security/currency
// imbalance account for the supplied security/currency // imbalance account for the supplied security/currency
func GetImbalanceAccount(transaction *gorp.Transaction, userid int64, securityid int64) (*Account, error) { func GetImbalanceAccount(tx *Tx, userid int64, securityid int64) (*Account, error) {
var imbalanceAccount Account var imbalanceAccount Account
var account Account var account Account
xxxtemplate := FindSecurityTemplate("XXX", Currency) xxxtemplate := FindSecurityTemplate("XXX", Currency)
if xxxtemplate == nil { if xxxtemplate == nil {
return nil, errors.New("Couldn't find XXX security template") return nil, errors.New("Couldn't find XXX security template")
} }
xxxsecurity, err := ImportGetCreateSecurity(transaction, userid, xxxtemplate) xxxsecurity, err := ImportGetCreateSecurity(tx, userid, xxxtemplate)
if err != nil { if err != nil {
return nil, errors.New("Couldn't create XXX security") return nil, errors.New("Couldn't create XXX security")
} }
@ -251,12 +239,12 @@ func GetImbalanceAccount(transaction *gorp.Transaction, userid int64, securityid
imbalanceAccount.Type = Bank imbalanceAccount.Type = Bank
// Find/create the top-level trading account // Find/create the top-level trading account
ia, err := GetCreateAccountTx(transaction, imbalanceAccount) ia, err := GetCreateAccount(tx, imbalanceAccount)
if err != nil { if err != nil {
return nil, err return nil, err
} }
security, err := GetSecurityTx(transaction, securityid, userid) security, err := GetSecurity(tx, securityid, userid)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -267,7 +255,7 @@ func GetImbalanceAccount(transaction *gorp.Transaction, userid int64, securityid
account.SecurityId = securityid account.SecurityId = securityid
account.Type = Bank account.Type = Bank
a, err := GetCreateAccountTx(transaction, account) a, err := GetCreateAccount(tx, account)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -330,7 +318,7 @@ func insertUpdateAccount(tx *Tx, a *Account, insert bool) error {
return err return err
} }
} else { } else {
oldacct, err := GetAccountTx(tx, a.AccountId, a.UserId) oldacct, err := GetAccount(tx, a.AccountId, a.UserId)
if err != nil { if err != nil {
return err return err
} }

View File

@ -407,7 +407,7 @@ func GnucashImportHandler(r *http.Request, tx *Tx) ResponseWriterWriter {
account.ParentAccountId = accountMap[account.ParentAccountId] account.ParentAccountId = accountMap[account.ParentAccountId]
} }
account.SecurityId = securityMap[account.SecurityId] account.SecurityId = securityMap[account.SecurityId]
a, err := GetCreateAccountTx(tx, account) a, err := GetCreateAccount(tx, account)
if err != nil { if err != nil {
log.Print(err) log.Print(err)
return NewError(999 /*Internal Error*/) return NewError(999 /*Internal Error*/)
@ -436,7 +436,7 @@ func GnucashImportHandler(r *http.Request, tx *Tx) ResponseWriterWriter {
} }
split.AccountId = acctId split.AccountId = acctId
exists, err := split.AlreadyImportedTx(tx) exists, err := split.AlreadyImported(tx)
if err != nil { if err != nil {
log.Print("Error checking if split was already imported:", err) log.Print("Error checking if split was already imported:", err)
return NewError(999 /*Internal Error*/) return NewError(999 /*Internal Error*/)
@ -445,7 +445,7 @@ func GnucashImportHandler(r *http.Request, tx *Tx) ResponseWriterWriter {
} }
} }
if !already_imported { if !already_imported {
err := InsertTransactionTx(tx, &transaction, user) err := InsertTransaction(tx, &transaction, user)
if err != nil { if err != nil {
log.Print(err) log.Print(err)
return NewError(999 /*Internal Error*/) return NewError(999 /*Internal Error*/)

View File

@ -37,7 +37,7 @@ func ofxImportHelper(tx *Tx, r io.Reader, user *User, accountid int64) ResponseW
} }
// Return Account with this Id // Return Account with this Id
account, err := GetAccountTx(tx, accountid, user.UserId) account, err := GetAccount(tx, accountid, user.UserId)
if err != nil { if err != nil {
log.Print(err) log.Print(err)
return NewError(3 /*Invalid Request*/) return NewError(3 /*Invalid Request*/)
@ -117,7 +117,7 @@ func ofxImportHelper(tx *Tx, r io.Reader, user *User, accountid int64) ResponseW
SecurityId: sec.SecurityId, SecurityId: sec.SecurityId,
Type: account.Type, Type: account.Type,
} }
subaccount, err := GetCreateAccountTx(tx, *subaccount) subaccount, err := GetCreateAccount(tx, *subaccount)
if err != nil { if err != nil {
log.Print(err) log.Print(err)
return NewError(999 /*Internal Error*/) return NewError(999 /*Internal Error*/)
@ -137,7 +137,7 @@ func ofxImportHelper(tx *Tx, r io.Reader, user *User, accountid int64) ResponseW
} }
} }
imbalances, err := transaction.GetImbalancesTx(tx) imbalances, err := transaction.GetImbalances(tx)
if err != nil { if err != nil {
log.Print(err) log.Print(err)
return NewError(999 /*Internal Error*/) return NewError(999 /*Internal Error*/)
@ -157,7 +157,7 @@ func ofxImportHelper(tx *Tx, r io.Reader, user *User, accountid int64) ResponseW
split := new(Split) split := new(Split)
r := new(big.Rat) r := new(big.Rat)
r.Neg(&imbalance) r.Neg(&imbalance)
security, err := GetSecurityTx(tx, imbalanced_security, user.UserId) security, err := GetSecurity(tx, imbalanced_security, user.UserId)
if err != nil { if err != nil {
log.Print(err) log.Print(err)
return NewError(999 /*Internal Error*/) return NewError(999 /*Internal Error*/)
@ -185,7 +185,7 @@ func ofxImportHelper(tx *Tx, r io.Reader, user *User, accountid int64) ResponseW
split.SecurityId = -1 split.SecurityId = -1
} }
exists, err := split.AlreadyImportedTx(tx) exists, err := split.AlreadyImported(tx)
if err != nil { if err != nil {
log.Print("Error checking if split was already imported:", err) log.Print("Error checking if split was already imported:", err)
return NewError(999 /*Internal Error*/) return NewError(999 /*Internal Error*/)
@ -200,7 +200,7 @@ func ofxImportHelper(tx *Tx, r io.Reader, user *User, accountid int64) ResponseW
} }
for _, transaction := range transactions { for _, transaction := range transactions {
err := InsertTransactionTx(tx, &transaction, user) err := InsertTransaction(tx, &transaction, user)
if err != nil { if err != nil {
log.Print(err) log.Print(err)
return NewError(999 /*Internal Error*/) return NewError(999 /*Internal Error*/)

View File

@ -1,7 +1,6 @@
package handlers package handlers
import ( import (
"gopkg.in/gorp.v1"
"time" "time"
) )
@ -14,18 +13,18 @@ type Price struct {
RemoteId string // unique ID from source, for detecting duplicates RemoteId string // unique ID from source, for detecting duplicates
} }
func InsertPriceTx(transaction *gorp.Transaction, p *Price) error { func InsertPrice(tx *Tx, p *Price) error {
err := transaction.Insert(p) err := tx.Insert(p)
if err != nil { if err != nil {
return err return err
} }
return nil return nil
} }
func CreatePriceIfNotExist(transaction *gorp.Transaction, price *Price) error { func CreatePriceIfNotExist(tx *Tx, price *Price) error {
if len(price.RemoteId) == 0 { if len(price.RemoteId) == 0 {
// Always create a new price if we can't match on the RemoteId // Always create a new price if we can't match on the RemoteId
err := InsertPriceTx(transaction, price) err := InsertPrice(tx, price)
if err != nil { if err != nil {
return err return err
} }
@ -34,7 +33,7 @@ func CreatePriceIfNotExist(transaction *gorp.Transaction, price *Price) error {
var prices []*Price var prices []*Price
_, err := transaction.Select(&prices, "SELECT * from prices where SecurityId=? AND CurrencyId=? AND Date=? AND Value=?", price.SecurityId, price.CurrencyId, price.Date, price.Value) _, err := tx.Select(&prices, "SELECT * from prices where SecurityId=? AND CurrencyId=? AND Date=? AND Value=?", price.SecurityId, price.CurrencyId, price.Date, price.Value)
if err != nil { if err != nil {
return err return err
} }
@ -43,7 +42,7 @@ func CreatePriceIfNotExist(transaction *gorp.Transaction, price *Price) error {
return nil // price already exists return nil // price already exists
} }
err = InsertPriceTx(transaction, price) err = InsertPrice(tx, price)
if err != nil { if err != nil {
return err return err
} }
@ -51,9 +50,9 @@ func CreatePriceIfNotExist(transaction *gorp.Transaction, price *Price) error {
} }
// Return the latest price for security in currency units before date // Return the latest price for security in currency units before date
func GetLatestPrice(transaction *gorp.Transaction, security, currency *Security, date *time.Time) (*Price, error) { func GetLatestPrice(tx *Tx, security, currency *Security, date *time.Time) (*Price, error) {
var p Price var p Price
err := transaction.SelectOne(&p, "SELECT * from prices where SecurityId=? AND CurrencyId=? AND Date <= ? ORDER BY Date DESC LIMIT 1", security.SecurityId, currency.SecurityId, date) err := tx.SelectOne(&p, "SELECT * from prices where SecurityId=? AND CurrencyId=? AND Date <= ? ORDER BY Date DESC LIMIT 1", security.SecurityId, currency.SecurityId, date)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -61,9 +60,9 @@ func GetLatestPrice(transaction *gorp.Transaction, security, currency *Security,
} }
// Return the earliest price for security in currency units after date // Return the earliest price for security in currency units after date
func GetEarliestPrice(transaction *gorp.Transaction, security, currency *Security, date *time.Time) (*Price, error) { func GetEarliestPrice(tx *Tx, security, currency *Security, date *time.Time) (*Price, error) {
var p Price var p Price
err := transaction.SelectOne(&p, "SELECT * from prices where SecurityId=? AND CurrencyId=? AND Date >= ? ORDER BY Date ASC LIMIT 1", security.SecurityId, currency.SecurityId, date) err := tx.SelectOne(&p, "SELECT * from prices where SecurityId=? AND CurrencyId=? AND Date >= ? ORDER BY Date ASC LIMIT 1", security.SecurityId, currency.SecurityId, date)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -71,9 +70,9 @@ func GetEarliestPrice(transaction *gorp.Transaction, security, currency *Securit
} }
// Return the price for security in currency closest to date // Return the price for security in currency closest to date
func GetClosestPriceTx(transaction *gorp.Transaction, security, currency *Security, date *time.Time) (*Price, error) { func GetClosestPrice(tx *Tx, security, currency *Security, date *time.Time) (*Price, error) {
earliest, _ := GetEarliestPrice(transaction, security, currency, date) earliest, _ := GetEarliestPrice(tx, security, currency, date)
latest, err := GetLatestPrice(transaction, security, currency, date) latest, err := GetLatestPrice(tx, security, currency, date)
// Return early if either earliest or latest are invalid // Return early if either earliest or latest are invalid
if earliest == nil { if earliest == nil {
@ -90,7 +89,3 @@ func GetClosestPriceTx(transaction *gorp.Transaction, security, currency *Securi
return earliest, nil return earliest, nil
} }
} }
func GetClosestPrice(tx *Tx, security, currency *Security, date *time.Time) (*Price, error) {
return GetClosestPriceTx(tx, security, currency, date)
}

View File

@ -5,7 +5,6 @@ package handlers
import ( import (
"encoding/json" "encoding/json"
"errors" "errors"
"gopkg.in/gorp.v1"
"log" "log"
"net/http" "net/http"
"net/url" "net/url"
@ -113,16 +112,6 @@ func GetSecurity(tx *Tx, securityid int64, userid int64) (*Security, error) {
return &s, nil return &s, nil
} }
func GetSecurityTx(transaction *gorp.Transaction, securityid int64, userid int64) (*Security, error) {
var s Security
err := transaction.SelectOne(&s, "SELECT * from securities where UserId=? AND SecurityId=?", userid, securityid)
if err != nil {
return nil, err
}
return &s, nil
}
func GetSecurities(tx *Tx, userid int64) (*[]*Security, error) { func GetSecurities(tx *Tx, userid int64) (*[]*Security, error) {
var securities []*Security var securities []*Security
@ -141,16 +130,8 @@ func InsertSecurity(tx *Tx, s *Security) error {
return nil return nil
} }
func InsertSecurityTx(transaction *gorp.Transaction, s *Security) error {
err := transaction.Insert(s)
if err != nil {
return err
}
return nil
}
func UpdateSecurity(tx *Tx, s *Security) (err error) { func UpdateSecurity(tx *Tx, s *Security) (err error) {
user, err := GetUserTx(tx, s.UserId) user, err := GetUser(tx, s.UserId)
if err != nil { if err != nil {
return return
} else if user.DefaultCurrency == s.SecurityId && s.Type != Currency { } else if user.DefaultCurrency == s.SecurityId && s.Type != Currency {
@ -184,7 +165,7 @@ func DeleteSecurity(tx *Tx, s *Security) error {
return SecurityInUseError{"One or more accounts still use this security"} return SecurityInUseError{"One or more accounts still use this security"}
} }
user, err := GetUserTx(tx, s.UserId) user, err := GetUser(tx, s.UserId)
if err != nil { if err != nil {
return err return err
} else if user.DefaultCurrency == s.SecurityId { } else if user.DefaultCurrency == s.SecurityId {
@ -209,11 +190,11 @@ func DeleteSecurity(tx *Tx, s *Security) error {
return nil return nil
} }
func ImportGetCreateSecurity(transaction *gorp.Transaction, userid int64, security *Security) (*Security, error) { func ImportGetCreateSecurity(tx *Tx, userid int64, security *Security) (*Security, error) {
security.UserId = userid security.UserId = userid
if len(security.AlternateId) == 0 { if len(security.AlternateId) == 0 {
// Always create a new local security if we can't match on the AlternateId // Always create a new local security if we can't match on the AlternateId
err := InsertSecurityTx(transaction, security) err := InsertSecurity(tx, security)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -222,7 +203,7 @@ func ImportGetCreateSecurity(transaction *gorp.Transaction, userid int64, securi
var securities []*Security var securities []*Security
_, err := transaction.Select(&securities, "SELECT * from securities where UserId=? AND Type=? AND AlternateId=? AND Precision=?", userid, security.Type, security.AlternateId, security.Precision) _, err := tx.Select(&securities, "SELECT * from securities where UserId=? AND Type=? AND AlternateId=? AND Precision=?", userid, security.Type, security.AlternateId, security.Precision)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -254,7 +235,7 @@ func ImportGetCreateSecurity(transaction *gorp.Transaction, userid int64, securi
} }
// If there wasn't even one security in the list, make a new one // If there wasn't even one security in the list, make a new one
err = InsertSecurityTx(transaction, security) err = InsertSecurity(tx, security)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -263,7 +244,7 @@ func ImportGetCreateSecurity(transaction *gorp.Transaction, userid int64, securi
} }
func SecurityHandler(r *http.Request, tx *Tx) ResponseWriterWriter { func SecurityHandler(r *http.Request, tx *Tx) ResponseWriterWriter {
user, err := GetUserFromSessionTx(tx, r) user, err := GetUserFromSession(tx, r)
if err != nil { if err != nil {
return NewError(1 /*Not Signed In*/) return NewError(1 /*Not Signed In*/)
} }
@ -282,7 +263,7 @@ func SecurityHandler(r *http.Request, tx *Tx) ResponseWriterWriter {
security.SecurityId = -1 security.SecurityId = -1
security.UserId = user.UserId security.UserId = user.UserId
err = InsertSecurityTx(tx, &security) err = InsertSecurity(tx, &security)
if err != nil { if err != nil {
log.Print(err) log.Print(err)
return NewError(999 /*Internal Error*/) return NewError(999 /*Internal Error*/)
@ -306,7 +287,7 @@ func SecurityHandler(r *http.Request, tx *Tx) ResponseWriterWriter {
sl.Securities = securities sl.Securities = securities
return &sl return &sl
} else { } else {
security, err := GetSecurityTx(tx, securityid, user.UserId) security, err := GetSecurity(tx, securityid, user.UserId)
if err != nil { if err != nil {
return NewError(3 /*Invalid Request*/) return NewError(3 /*Invalid Request*/)
} }
@ -339,7 +320,7 @@ func SecurityHandler(r *http.Request, tx *Tx) ResponseWriterWriter {
return &security return &security
} else if r.Method == "DELETE" { } else if r.Method == "DELETE" {
security, err := GetSecurityTx(tx, securityid, user.UserId) security, err := GetSecurity(tx, securityid, user.UserId)
if err != nil { if err != nil {
return NewError(3 /*Invalid Request*/) return NewError(3 /*Invalid Request*/)
} }

View File

@ -44,24 +44,8 @@ func GetSession(tx *Tx, r *http.Request) (*Session, error) {
return &s, nil return &s, nil
} }
func GetSessionTx(tx *Tx, r *http.Request) (*Session, error) {
var s Session
cookie, err := r.Cookie("moneygo-session")
if err != nil {
return nil, fmt.Errorf("moneygo-session cookie not set")
}
s.SessionSecret = cookie.Value
err = tx.SelectOne(&s, "SELECT * from sessions where SessionSecret=?", s.SessionSecret)
if err != nil {
return nil, err
}
return &s, nil
}
func DeleteSessionIfExists(tx *Tx, r *http.Request) error { func DeleteSessionIfExists(tx *Tx, r *http.Request) error {
session, err := GetSessionTx(tx, r) session, err := GetSession(tx, r)
if err == nil { if err == nil {
_, err := tx.Delete(session) _, err := tx.Delete(session)
if err != nil { if err != nil {
@ -153,7 +137,7 @@ func SessionHandler(r *http.Request, tx *Tx) ResponseWriterWriter {
} }
return sessionwriter return sessionwriter
} else if r.Method == "GET" { } else if r.Method == "GET" {
s, err := GetSessionTx(tx, r) s, err := GetSession(tx, r)
if err != nil { if err != nil {
return NewError(1 /*Not Signed In*/) return NewError(1 /*Not Signed In*/)
} }

View File

@ -4,7 +4,6 @@ import (
"encoding/json" "encoding/json"
"errors" "errors"
"fmt" "fmt"
"gopkg.in/gorp.v1"
"log" "log"
"math/big" "math/big"
"net/http" "net/http"
@ -78,8 +77,8 @@ func (s *Split) Valid() bool {
return err == nil return err == nil
} }
func (s *Split) AlreadyImportedTx(transaction *gorp.Transaction) (bool, error) { func (s *Split) AlreadyImported(tx *Tx) (bool, error) {
count, err := transaction.SelectInt("SELECT COUNT(*) from splits where RemoteId=? and AccountId=?", s.RemoteId, s.AccountId) count, err := tx.SelectInt("SELECT COUNT(*) from splits where RemoteId=? and AccountId=?", s.RemoteId, s.AccountId)
return count == 1, err return count == 1, err
} }
@ -134,7 +133,7 @@ func (t *Transaction) Valid() bool {
// Return a map of security ID's to big.Rat's containing the amount that // Return a map of security ID's to big.Rat's containing the amount that
// security is imbalanced by // security is imbalanced by
func (t *Transaction) GetImbalancesTx(transaction *gorp.Transaction) (map[int64]big.Rat, error) { func (t *Transaction) GetImbalances(tx *Tx) (map[int64]big.Rat, error) {
sums := make(map[int64]big.Rat) sums := make(map[int64]big.Rat)
if !t.Valid() { if !t.Valid() {
@ -146,7 +145,7 @@ func (t *Transaction) GetImbalancesTx(transaction *gorp.Transaction) (map[int64]
if t.Splits[i].AccountId != -1 { if t.Splits[i].AccountId != -1 {
var err error var err error
var account *Account var account *Account
account, err = GetAccountTx(transaction, t.Splits[i].AccountId, t.UserId) account, err = GetAccount(tx, t.Splits[i].AccountId, t.UserId)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -162,10 +161,10 @@ func (t *Transaction) GetImbalancesTx(transaction *gorp.Transaction) (map[int64]
// Returns true if all securities contained in this transaction are balanced, // Returns true if all securities contained in this transaction are balanced,
// false otherwise // false otherwise
func (t *Transaction) Balanced(transaction *gorp.Transaction) (bool, error) { func (t *Transaction) Balanced(tx *Tx) (bool, error) {
var zero big.Rat var zero big.Rat
sums, err := t.GetImbalancesTx(transaction) sums, err := t.GetImbalances(tx)
if err != nil { if err != nil {
return false, err return false, err
} }
@ -214,7 +213,7 @@ func GetTransactions(tx *Tx, userid int64) (*[]Transaction, error) {
func incrementAccountVersions(tx *Tx, user *User, accountids []int64) error { func incrementAccountVersions(tx *Tx, user *User, accountids []int64) error {
for i := range accountids { for i := range accountids {
account, err := GetAccountTx(tx, accountids[i], user.UserId) account, err := GetAccount(tx, accountids[i], user.UserId)
if err != nil { if err != nil {
return err return err
} }
@ -236,7 +235,7 @@ func (ame AccountMissingError) Error() string {
return "Account missing" return "Account missing"
} }
func InsertTransactionTx(tx *Tx, t *Transaction, user *User) error { func InsertTransaction(tx *Tx, t *Transaction, user *User) error {
// Map of any accounts with transaction splits being added // Map of any accounts with transaction splits being added
a_map := make(map[int64]bool) a_map := make(map[int64]bool)
for i := range t.Splits { for i := range t.Splits {
@ -286,16 +285,7 @@ func InsertTransactionTx(tx *Tx, t *Transaction, user *User) error {
return nil return nil
} }
func InsertTransaction(tx *Tx, t *Transaction, user *User) error { func UpdateTransaction(tx *Tx, t *Transaction, user *User) error {
err := InsertTransactionTx(tx, t, user)
if err != nil {
return err
}
return nil
}
func UpdateTransactionTx(tx *Tx, t *Transaction, user *User) error {
var existing_splits []*Split var existing_splits []*Split
_, err := tx.Select(&existing_splits, "SELECT * from splits where TransactionId=?", t.TransactionId) _, err := tx.Select(&existing_splits, "SELECT * from splits where TransactionId=?", t.TransactionId)
@ -431,13 +421,13 @@ func TransactionHandler(r *http.Request, tx *Tx) ResponseWriterWriter {
for i := range transaction.Splits { for i := range transaction.Splits {
transaction.Splits[i].SplitId = -1 transaction.Splits[i].SplitId = -1
_, err := GetAccountTx(tx, transaction.Splits[i].AccountId, user.UserId) _, err := GetAccount(tx, transaction.Splits[i].AccountId, user.UserId)
if err != nil { if err != nil {
return NewError(3 /*Invalid Request*/) return NewError(3 /*Invalid Request*/)
} }
} }
err = InsertTransactionTx(tx, &transaction, user) err = InsertTransaction(tx, &transaction, user)
if err != nil { if err != nil {
if _, ok := err.(AccountMissingError); ok { if _, ok := err.(AccountMissingError); ok {
return NewError(3 /*Invalid Request*/) return NewError(3 /*Invalid Request*/)
@ -497,13 +487,13 @@ func TransactionHandler(r *http.Request, tx *Tx) ResponseWriterWriter {
} }
for i := range transaction.Splits { for i := range transaction.Splits {
_, err := GetAccountTx(tx, transaction.Splits[i].AccountId, user.UserId) _, err := GetAccount(tx, transaction.Splits[i].AccountId, user.UserId)
if err != nil { if err != nil {
return NewError(3 /*Invalid Request*/) return NewError(3 /*Invalid Request*/)
} }
} }
err = UpdateTransactionTx(tx, &transaction, user) err = UpdateTransaction(tx, &transaction, user)
if err != nil { if err != nil {
log.Print(err) log.Print(err)
return NewError(999 /*Internal Error*/) return NewError(999 /*Internal Error*/)
@ -533,10 +523,10 @@ func TransactionHandler(r *http.Request, tx *Tx) ResponseWriterWriter {
return NewError(3 /*Invalid Request*/) return NewError(3 /*Invalid Request*/)
} }
func TransactionsBalanceDifference(transaction *gorp.Transaction, accountid int64, transactions []Transaction) (*big.Rat, error) { func TransactionsBalanceDifference(tx *Tx, accountid int64, transactions []Transaction) (*big.Rat, error) {
var pageDifference, tmp big.Rat var pageDifference, tmp big.Rat
for i := range transactions { for i := range transactions {
_, err := transaction.Select(&transactions[i].Splits, "SELECT * FROM splits where TransactionId=?", transactions[i].TransactionId) _, err := tx.Select(&transactions[i].Splits, "SELECT * FROM splits where TransactionId=?", transactions[i].TransactionId)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -649,7 +639,7 @@ func GetAccountTransactions(tx *Tx, user *User, accountid int64, sort string, pa
sqloffset = fmt.Sprintf(" OFFSET %d", page*limit) sqloffset = fmt.Sprintf(" OFFSET %d", page*limit)
} }
account, err := GetAccountTx(tx, accountid, user.UserId) account, err := GetAccount(tx, accountid, user.UserId)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -673,7 +663,7 @@ func GetAccountTransactions(tx *Tx, user *User, accountid int64, sort string, pa
} }
atl.TotalTransactions = count atl.TotalTransactions = count
security, err := GetSecurityTx(tx, atl.Account.SecurityId, user.UserId) security, err := GetSecurity(tx, atl.Account.SecurityId, user.UserId)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View File

@ -56,16 +56,6 @@ func GetUser(tx *Tx, userid int64) (*User, error) {
return &u, nil return &u, nil
} }
func GetUserTx(tx *Tx, userid int64) (*User, error) {
var u User
err := tx.SelectOne(&u, "SELECT * from users where UserId=?", userid)
if err != nil {
return nil, err
}
return &u, nil
}
func GetUserByUsername(tx *Tx, username string) (*User, error) { func GetUserByUsername(tx *Tx, username string) (*User, error) {
var u User var u User
@ -100,7 +90,7 @@ func InsertUser(tx *Tx, u *User) error {
security = *security_template security = *security_template
security.UserId = u.UserId security.UserId = u.UserId
err = InsertSecurityTx(tx, &security) err = InsertSecurity(tx, &security)
if err != nil { if err != nil {
return err return err
} }
@ -125,16 +115,8 @@ func GetUserFromSession(tx *Tx, r *http.Request) (*User, error) {
return GetUser(tx, s.UserId) return GetUser(tx, s.UserId)
} }
func GetUserFromSessionTx(tx *Tx, r *http.Request) (*User, error) {
s, err := GetSessionTx(tx, r)
if err != nil {
return nil, err
}
return GetUserTx(tx, s.UserId)
}
func UpdateUser(tx *Tx, u *User) error { func UpdateUser(tx *Tx, u *User) error {
security, err := GetSecurityTx(tx, u.DefaultCurrency, u.UserId) security, err := GetSecurity(tx, u.DefaultCurrency, u.UserId)
if err != nil { if err != nil {
return err return err
} else if security.UserId != u.UserId || security.SecurityId != u.DefaultCurrency { } else if security.UserId != u.UserId || security.SecurityId != u.DefaultCurrency {