1
0
mirror of https://github.com/aclindsa/moneygo.git synced 2024-12-27 07:52:28 -05:00

Pass DB as a closure instead of a global variable

This is part of an ongoing attempt to restructure the code to make it
more 'testable'.
This commit is contained in:
Aaron Lindsay 2017-10-04 08:05:51 -04:00
parent 9abafa50b2
commit 156b9aaf0c
13 changed files with 253 additions and 208 deletions

View File

@ -124,10 +124,10 @@ func (al *AccountList) Write(w http.ResponseWriter) error {
return enc.Encode(al) return enc.Encode(al)
} }
func GetAccount(accountid int64, userid int64) (*Account, error) { func GetAccount(db *DB, accountid int64, userid int64) (*Account, error) {
var a Account var a Account
err := DB.SelectOne(&a, "SELECT * from accounts where UserId=? AND AccountId=?", userid, accountid) err := db.SelectOne(&a, "SELECT * from accounts where UserId=? AND AccountId=?", userid, accountid)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -145,10 +145,10 @@ func GetAccountTx(transaction *gorp.Transaction, accountid int64, userid int64)
return &a, nil return &a, nil
} }
func GetAccounts(userid int64) (*[]Account, error) { func GetAccounts(db *DB, userid int64) (*[]Account, error) {
var accounts []Account var accounts []Account
_, err := DB.Select(&accounts, "SELECT * from accounts where UserId=?", userid) _, err := db.Select(&accounts, "SELECT * from accounts where UserId=?", userid)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -276,8 +276,8 @@ func (pame ParentAccountMissingError) Error() string {
return "Parent account missing" return "Parent account missing"
} }
func insertUpdateAccount(a *Account, insert bool) error { func insertUpdateAccount(db *DB, a *Account, insert bool) error {
transaction, err := DB.Begin() transaction, err := db.Begin()
if err != nil { if err != nil {
return err return err
} }
@ -329,16 +329,16 @@ func insertUpdateAccount(a *Account, insert bool) error {
return nil return nil
} }
func InsertAccount(a *Account) error { func InsertAccount(db *DB, a *Account) error {
return insertUpdateAccount(a, true) return insertUpdateAccount(db, a, true)
} }
func UpdateAccount(a *Account) error { func UpdateAccount(db *DB, a *Account) error {
return insertUpdateAccount(a, false) return insertUpdateAccount(db, a, false)
} }
func DeleteAccount(a *Account) error { func DeleteAccount(db *DB, a *Account) error {
transaction, err := DB.Begin() transaction, err := db.Begin()
if err != nil { if err != nil {
return err return err
} }
@ -385,8 +385,8 @@ func DeleteAccount(a *Account) error {
return nil return nil
} }
func AccountHandler(w http.ResponseWriter, r *http.Request) { func AccountHandler(w http.ResponseWriter, r *http.Request, db *DB) {
user, err := GetUserFromSession(r) user, err := GetUserFromSession(db, r)
if err != nil { if err != nil {
WriteError(w, 1 /*Not Signed In*/) WriteError(w, 1 /*Not Signed In*/)
return return
@ -405,7 +405,7 @@ func AccountHandler(w http.ResponseWriter, r *http.Request) {
log.Print(err) log.Print(err)
return return
} }
AccountImportHandler(w, r, user, accountid, importtype) AccountImportHandler(db, w, r, user, accountid, importtype)
return return
} }
@ -425,7 +425,7 @@ func AccountHandler(w http.ResponseWriter, r *http.Request) {
account.UserId = user.UserId account.UserId = user.UserId
account.AccountVersion = 0 account.AccountVersion = 0
security, err := GetSecurity(account.SecurityId, user.UserId) security, err := GetSecurity(db, account.SecurityId, user.UserId)
if err != nil { if err != nil {
WriteError(w, 999 /*Internal Error*/) WriteError(w, 999 /*Internal Error*/)
log.Print(err) log.Print(err)
@ -436,7 +436,7 @@ func AccountHandler(w http.ResponseWriter, r *http.Request) {
return return
} }
err = InsertAccount(&account) err = InsertAccount(db, &account)
if err != nil { if err != nil {
if _, ok := err.(ParentAccountMissingError); ok { if _, ok := err.(ParentAccountMissingError); ok {
WriteError(w, 3 /*Invalid Request*/) WriteError(w, 3 /*Invalid Request*/)
@ -461,7 +461,7 @@ func AccountHandler(w http.ResponseWriter, r *http.Request) {
if err != nil || n != 1 { if err != nil || n != 1 {
//Return all Accounts //Return all Accounts
var al AccountList var al AccountList
accounts, err := GetAccounts(user.UserId) accounts, err := GetAccounts(db, user.UserId)
if err != nil { if err != nil {
WriteError(w, 999 /*Internal Error*/) WriteError(w, 999 /*Internal Error*/)
log.Print(err) log.Print(err)
@ -478,12 +478,12 @@ func AccountHandler(w http.ResponseWriter, r *http.Request) {
// if URL looks like /account/[0-9]+/transactions, use the account // if URL looks like /account/[0-9]+/transactions, use the account
// transaction handler // transaction handler
if accountTransactionsRE.MatchString(r.URL.Path) { if accountTransactionsRE.MatchString(r.URL.Path) {
AccountTransactionsHandler(w, r, user, accountid) AccountTransactionsHandler(db, w, r, user, accountid)
return return
} }
// Return Account with this Id // Return Account with this Id
account, err := GetAccount(accountid, user.UserId) account, err := GetAccount(db, accountid, user.UserId)
if err != nil { if err != nil {
WriteError(w, 3 /*Invalid Request*/) WriteError(w, 3 /*Invalid Request*/)
return return
@ -517,7 +517,7 @@ func AccountHandler(w http.ResponseWriter, r *http.Request) {
} }
account.UserId = user.UserId account.UserId = user.UserId
security, err := GetSecurity(account.SecurityId, user.UserId) security, err := GetSecurity(db, account.SecurityId, user.UserId)
if err != nil { if err != nil {
WriteError(w, 999 /*Internal Error*/) WriteError(w, 999 /*Internal Error*/)
log.Print(err) log.Print(err)
@ -528,7 +528,7 @@ func AccountHandler(w http.ResponseWriter, r *http.Request) {
return return
} }
err = UpdateAccount(&account) err = UpdateAccount(db, &account)
if err != nil { if err != nil {
WriteError(w, 999 /*Internal Error*/) WriteError(w, 999 /*Internal Error*/)
log.Print(err) log.Print(err)
@ -542,13 +542,13 @@ func AccountHandler(w http.ResponseWriter, r *http.Request) {
return return
} }
} else if r.Method == "DELETE" { } else if r.Method == "DELETE" {
account, err := GetAccount(accountid, user.UserId) account, err := GetAccount(db, accountid, user.UserId)
if err != nil { if err != nil {
WriteError(w, 3 /*Invalid Request*/) WriteError(w, 3 /*Invalid Request*/)
return return
} }
err = DeleteAccount(account) err = DeleteAccount(db, account)
if err != nil { if err != nil {
WriteError(w, 999 /*Internal Error*/) WriteError(w, 999 /*Internal Error*/)
log.Print(err) log.Print(err)

View File

@ -15,14 +15,19 @@ func luaContextGetAccounts(L *lua.LState) (map[int64]*Account, error) {
ctx := L.Context() ctx := L.Context()
account_map, ok := ctx.Value(accountsContextKey).(map[int64]*Account) db, ok := ctx.Value(dbContextKey).(*DB)
if !ok {
return nil, errors.New("Couldn't find DB in lua's Context")
}
account_map, ok = ctx.Value(accountsContextKey).(map[int64]*Account)
if !ok { if !ok {
user, ok := ctx.Value(userContextKey).(*User) user, ok := ctx.Value(userContextKey).(*User)
if !ok { if !ok {
return nil, errors.New("Couldn't find User in lua's Context") return nil, errors.New("Couldn't find User in lua's Context")
} }
accounts, err := GetAccounts(user.UserId) accounts, err := GetAccounts(db, user.UserId)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -144,6 +149,10 @@ func luaAccountBalance(L *lua.LState) int {
a := luaCheckAccount(L, 1) a := luaCheckAccount(L, 1)
ctx := L.Context() ctx := L.Context()
db, ok := ctx.Value(dbContextKey).(*DB)
if !ok {
panic("Couldn't find DB in lua's Context")
}
user, ok := ctx.Value(userContextKey).(*User) user, ok := ctx.Value(userContextKey).(*User)
if !ok { if !ok {
panic("Couldn't find User in lua's Context") panic("Couldn't find User in lua's Context")
@ -162,12 +171,12 @@ func luaAccountBalance(L *lua.LState) int {
if date != nil { if date != nil {
end := luaWeakCheckTime(L, 3) end := luaWeakCheckTime(L, 3)
if end != nil { if end != nil {
rat, err = GetAccountBalanceDateRange(user, a.AccountId, date, end) rat, err = GetAccountBalanceDateRange(db, user, a.AccountId, date, end)
} else { } else {
rat, err = GetAccountBalanceDate(user, a.AccountId, date) rat, err = GetAccountBalanceDate(db, user, a.AccountId, date)
} }
} else { } else {
rat, err = GetAccountBalance(user, a.AccountId) rat, err = GetAccountBalance(db, user, a.AccountId)
} }
if err != nil { if err != nil {
panic("Failed to GetAccountBalance:" + err.Error()) panic("Failed to GetAccountBalance:" + err.Error())

14
db.go
View File

@ -2,19 +2,17 @@ package main
import ( import (
"database/sql" "database/sql"
"fmt"
_ "github.com/go-sql-driver/mysql" _ "github.com/go-sql-driver/mysql"
_ "github.com/lib/pq" _ "github.com/lib/pq"
_ "github.com/mattn/go-sqlite3" _ "github.com/mattn/go-sqlite3"
"gopkg.in/gorp.v1" "gopkg.in/gorp.v1"
"log"
) )
var DB *gorp.DbMap func initDB(cfg *Config) (*gorp.DbMap, error) {
func initDB(cfg *Config) {
db, err := sql.Open(cfg.MoneyGo.DBType.String(), cfg.MoneyGo.DSN) db, err := sql.Open(cfg.MoneyGo.DBType.String(), cfg.MoneyGo.DSN)
if err != nil { if err != nil {
log.Fatal(err) return nil, err
} }
var dialect gorp.Dialect var dialect gorp.Dialect
@ -28,7 +26,7 @@ func initDB(cfg *Config) {
} else if cfg.MoneyGo.DBType == Postgres { } else if cfg.MoneyGo.DBType == Postgres {
dialect = gorp.PostgresDialect{} dialect = gorp.PostgresDialect{}
} else { } else {
log.Fatalf("Don't know gorp dialect to go with '%s' DB type", cfg.MoneyGo.DBType.String()) return nil, fmt.Errorf("Don't know gorp dialect to go with '%s' DB type", cfg.MoneyGo.DBType.String())
} }
dbmap := &gorp.DbMap{Db: db, Dialect: dialect} dbmap := &gorp.DbMap{Db: db, Dialect: dialect}
@ -43,8 +41,8 @@ func initDB(cfg *Config) {
err = dbmap.CreateTablesIfNotExists() err = dbmap.CreateTablesIfNotExists()
if err != nil { if err != nil {
log.Fatal(err) return nil, err
} }
DB = dbmap return dbmap, nil
} }

View File

@ -308,8 +308,8 @@ func ImportGnucash(r io.Reader) (*GnucashImport, error) {
return &gncimport, nil return &gncimport, nil
} }
func GnucashImportHandler(w http.ResponseWriter, r *http.Request) { func GnucashImportHandler(w http.ResponseWriter, r *http.Request, db *DB) {
user, err := GetUserFromSession(r) user, err := GetUserFromSession(db, r)
if err != nil { if err != nil {
WriteError(w, 1 /*Not Signed In*/) WriteError(w, 1 /*Not Signed In*/)
return return
@ -365,7 +365,7 @@ func GnucashImportHandler(w http.ResponseWriter, r *http.Request) {
return return
} }
sqltransaction, err := DB.Begin() sqltransaction, err := db.Begin()
if err != nil { if err != nil {
WriteError(w, 999 /*Internal Error*/) WriteError(w, 999 /*Internal Error*/)
log.Print(err) log.Print(err)

View File

@ -22,7 +22,7 @@ func (od *OFXDownload) Read(json_str string) error {
return dec.Decode(od) return dec.Decode(od)
} }
func ofxImportHelper(r io.Reader, w http.ResponseWriter, user *User, accountid int64) { func ofxImportHelper(db *DB, r io.Reader, w http.ResponseWriter, user *User, accountid int64) {
itl, err := ImportOFX(r) itl, err := ImportOFX(r)
if err != nil { if err != nil {
@ -38,7 +38,7 @@ func ofxImportHelper(r io.Reader, w http.ResponseWriter, user *User, accountid i
return return
} }
sqltransaction, err := DB.Begin() sqltransaction, err := db.Begin()
if err != nil { if err != nil {
WriteError(w, 999 /*Internal Error*/) WriteError(w, 999 /*Internal Error*/)
log.Print(err) log.Print(err)
@ -258,7 +258,7 @@ func ofxImportHelper(r io.Reader, w http.ResponseWriter, user *User, accountid i
WriteSuccess(w) WriteSuccess(w)
} }
func OFXImportHandler(w http.ResponseWriter, r *http.Request, user *User, accountid int64) { func OFXImportHandler(db *DB, w http.ResponseWriter, r *http.Request, user *User, accountid int64) {
download_json := r.PostFormValue("ofxdownload") download_json := r.PostFormValue("ofxdownload")
if download_json == "" { if download_json == "" {
log.Print("download_json") log.Print("download_json")
@ -274,7 +274,7 @@ func OFXImportHandler(w http.ResponseWriter, r *http.Request, user *User, accoun
return return
} }
account, err := GetAccount(accountid, user.UserId) account, err := GetAccount(db, accountid, user.UserId)
if err != nil { if err != nil {
log.Print("GetAccount") log.Print("GetAccount")
WriteError(w, 3 /*Invalid Request*/) WriteError(w, 3 /*Invalid Request*/)
@ -367,10 +367,10 @@ func OFXImportHandler(w http.ResponseWriter, r *http.Request, user *User, accoun
} }
defer response.Body.Close() defer response.Body.Close()
ofxImportHelper(response.Body, w, user, accountid) ofxImportHelper(db, response.Body, w, user, accountid)
} }
func OFXFileImportHandler(w http.ResponseWriter, r *http.Request, user *User, accountid int64) { func OFXFileImportHandler(db *DB, w http.ResponseWriter, r *http.Request, user *User, accountid int64) {
multipartReader, err := r.MultipartReader() multipartReader, err := r.MultipartReader()
if err != nil { if err != nil {
WriteError(w, 3 /*Invalid Request*/) WriteError(w, 3 /*Invalid Request*/)
@ -390,19 +390,19 @@ func OFXFileImportHandler(w http.ResponseWriter, r *http.Request, user *User, ac
return return
} }
ofxImportHelper(part, w, user, accountid) ofxImportHelper(db, part, w, user, accountid)
} }
/* /*
* Assumes the User is a valid, signed-in user, but accountid has not yet been validated * Assumes the User is a valid, signed-in user, but accountid has not yet been validated
*/ */
func AccountImportHandler(w http.ResponseWriter, r *http.Request, user *User, accountid int64, importtype string) { func AccountImportHandler(db *DB, w http.ResponseWriter, r *http.Request, user *User, accountid int64, importtype string) {
switch importtype { switch importtype {
case "ofx": case "ofx":
OFXImportHandler(w, r, user, accountid) OFXImportHandler(db, w, r, user, accountid)
case "ofxfile": case "ofxfile":
OFXFileImportHandler(w, r, user, accountid) OFXFileImportHandler(db, w, r, user, accountid)
default: default:
WriteError(w, 3 /*Invalid Request*/) WriteError(w, 3 /*Invalid Request*/)
} }

44
main.go
View File

@ -4,6 +4,7 @@ package main
import ( import (
"flag" "flag"
"gopkg.in/gorp.v1"
"log" "log"
"net" "net"
"net/http" "net/http"
@ -43,8 +44,6 @@ func init() {
// Setup the logging flags to be printed // Setup the logging flags to be printed
log.SetFlags(log.Ldate | log.Ltime | log.Lshortfile) log.SetFlags(log.Ldate | log.Ltime | log.Lshortfile)
initDB(config)
} }
func rootHandler(w http.ResponseWriter, r *http.Request) { func rootHandler(w http.ResponseWriter, r *http.Request) {
@ -55,18 +54,39 @@ func staticHandler(w http.ResponseWriter, r *http.Request) {
http.ServeFile(w, r, path.Join(config.MoneyGo.Basedir, r.URL.Path)) http.ServeFile(w, r, path.Join(config.MoneyGo.Basedir, r.URL.Path))
} }
func main() { // Create a closure over db, allowing the handlers to look like a
// http.HandlerFunc
type DB = gorp.DbMap
type DBHandler func(http.ResponseWriter, *http.Request, *DB)
func DBHandlerFunc(h DBHandler, db *DB) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
h(w, r, db)
}
}
func GetHandler(db *DB) http.Handler {
servemux := http.NewServeMux() servemux := http.NewServeMux()
servemux.HandleFunc("/", rootHandler) servemux.HandleFunc("/", rootHandler)
servemux.HandleFunc("/static/", staticHandler) servemux.HandleFunc("/static/", staticHandler)
servemux.HandleFunc("/session/", SessionHandler) servemux.HandleFunc("/session/", DBHandlerFunc(SessionHandler, db))
servemux.HandleFunc("/user/", UserHandler) servemux.HandleFunc("/user/", DBHandlerFunc(UserHandler, db))
servemux.HandleFunc("/security/", SecurityHandler) servemux.HandleFunc("/security/", DBHandlerFunc(SecurityHandler, db))
servemux.HandleFunc("/securitytemplate/", SecurityTemplateHandler) servemux.HandleFunc("/securitytemplate/", SecurityTemplateHandler)
servemux.HandleFunc("/account/", AccountHandler) servemux.HandleFunc("/account/", DBHandlerFunc(AccountHandler, db))
servemux.HandleFunc("/transaction/", TransactionHandler) servemux.HandleFunc("/transaction/", DBHandlerFunc(TransactionHandler, db))
servemux.HandleFunc("/import/gnucash", GnucashImportHandler) servemux.HandleFunc("/import/gnucash", DBHandlerFunc(GnucashImportHandler, db))
servemux.HandleFunc("/report/", ReportHandler) servemux.HandleFunc("/report/", DBHandlerFunc(ReportHandler, db))
return servemux
}
func main() {
database, err := initDB(config)
if err != nil {
log.Fatal(err)
}
handler := GetHandler(database)
listener, err := net.Listen("tcp", ":"+strconv.Itoa(config.MoneyGo.Port)) listener, err := net.Listen("tcp", ":"+strconv.Itoa(config.MoneyGo.Port))
if err != nil { if err != nil {
@ -75,8 +95,8 @@ func main() {
log.Printf("Serving on port %d out of directory: %s", config.MoneyGo.Port, config.MoneyGo.Basedir) log.Printf("Serving on port %d out of directory: %s", config.MoneyGo.Port, config.MoneyGo.Basedir)
if config.MoneyGo.Fcgi { if config.MoneyGo.Fcgi {
fcgi.Serve(listener, servemux) fcgi.Serve(listener, handler)
} else { } else {
http.Serve(listener, servemux) http.Serve(listener, handler)
} }
} }

View File

@ -1,8 +1,6 @@
package main package main
import ( import (
"fmt"
"github.com/FlashBoys/go-finance"
"gopkg.in/gorp.v1" "gopkg.in/gorp.v1"
"time" "time"
) )
@ -93,8 +91,8 @@ func GetClosestPriceTx(transaction *gorp.Transaction, security, currency *Securi
} }
} }
func GetClosestPrice(security, currency *Security, date *time.Time) (*Price, error) { func GetClosestPrice(db *DB, security, currency *Security, date *time.Time) (*Price, error) {
transaction, err := DB.Begin() transaction, err := db.Begin()
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -113,10 +111,3 @@ func GetClosestPrice(security, currency *Security, date *time.Time) (*Price, err
return price, nil return price, nil
} }
func init() {
q, err := finance.GetQuote("BRK-A")
if err == nil {
fmt.Printf("%+v", q)
}
}

View File

@ -27,6 +27,7 @@ const (
accountsContextKey accountsContextKey
securitiesContextKey securitiesContextKey
balanceContextKey balanceContextKey
dbContextKey
) )
const luaTimeoutSeconds time.Duration = 30 // maximum time a lua request can run for const luaTimeoutSeconds time.Duration = 30 // maximum time a lua request can run for
@ -76,36 +77,36 @@ func (r *Tabulation) Write(w http.ResponseWriter) error {
return enc.Encode(r) return enc.Encode(r)
} }
func GetReport(reportid int64, userid int64) (*Report, error) { func GetReport(db *DB, reportid int64, userid int64) (*Report, error) {
var r Report var r Report
err := DB.SelectOne(&r, "SELECT * from reports where UserId=? AND ReportId=?", userid, reportid) err := db.SelectOne(&r, "SELECT * from reports where UserId=? AND ReportId=?", userid, reportid)
if err != nil { if err != nil {
return nil, err return nil, err
} }
return &r, nil return &r, nil
} }
func GetReports(userid int64) (*[]Report, error) { func GetReports(db *DB, userid int64) (*[]Report, error) {
var reports []Report var reports []Report
_, err := DB.Select(&reports, "SELECT * from reports where UserId=?", userid) _, err := db.Select(&reports, "SELECT * from reports where UserId=?", userid)
if err != nil { if err != nil {
return nil, err return nil, err
} }
return &reports, nil return &reports, nil
} }
func InsertReport(r *Report) error { func InsertReport(db *DB, r *Report) error {
err := DB.Insert(r) err := db.Insert(r)
if err != nil { if err != nil {
return err return err
} }
return nil return nil
} }
func UpdateReport(r *Report) error { func UpdateReport(db *DB, r *Report) error {
count, err := DB.Update(r) count, err := db.Update(r)
if err != nil { if err != nil {
return err return err
} }
@ -115,8 +116,8 @@ func UpdateReport(r *Report) error {
return nil return nil
} }
func DeleteReport(r *Report) error { func DeleteReport(db *DB, r *Report) error {
count, err := DB.Delete(r) count, err := db.Delete(r)
if err != nil { if err != nil {
return err return err
} }
@ -126,13 +127,14 @@ func DeleteReport(r *Report) error {
return nil return nil
} }
func runReport(user *User, report *Report) (*Tabulation, error) { func runReport(db *DB, user *User, report *Report) (*Tabulation, error) {
// Create a new LState without opening the default libs for security // Create a new LState without opening the default libs for security
L := lua.NewState(lua.Options{SkipOpenLibs: true}) L := lua.NewState(lua.Options{SkipOpenLibs: true})
defer L.Close() defer L.Close()
// Create a new context holding the current user with a timeout // Create a new context holding the current user with a timeout
ctx := context.WithValue(context.Background(), userContextKey, user) ctx := context.WithValue(context.Background(), userContextKey, user)
ctx = context.WithValue(ctx, dbContextKey, db)
ctx, cancel := context.WithTimeout(ctx, luaTimeoutSeconds*time.Second) ctx, cancel := context.WithTimeout(ctx, luaTimeoutSeconds*time.Second)
defer cancel() defer cancel()
L.SetContext(ctx) L.SetContext(ctx)
@ -189,14 +191,14 @@ func runReport(user *User, report *Report) (*Tabulation, error) {
} }
} }
func ReportTabulationHandler(w http.ResponseWriter, r *http.Request, user *User, reportid int64) { func ReportTabulationHandler(db *DB, w http.ResponseWriter, r *http.Request, user *User, reportid int64) {
report, err := GetReport(reportid, user.UserId) report, err := GetReport(db, reportid, user.UserId)
if err != nil { if err != nil {
WriteError(w, 3 /*Invalid Request*/) WriteError(w, 3 /*Invalid Request*/)
return return
} }
tabulation, err := runReport(user, report) tabulation, err := runReport(db, user, report)
if err != nil { if err != nil {
// TODO handle different failure cases differently // TODO handle different failure cases differently
log.Print("runReport returned:", err) log.Print("runReport returned:", err)
@ -214,8 +216,8 @@ func ReportTabulationHandler(w http.ResponseWriter, r *http.Request, user *User,
} }
} }
func ReportHandler(w http.ResponseWriter, r *http.Request) { func ReportHandler(w http.ResponseWriter, r *http.Request, db *DB) {
user, err := GetUserFromSession(r) user, err := GetUserFromSession(db, r)
if err != nil { if err != nil {
WriteError(w, 1 /*Not Signed In*/) WriteError(w, 1 /*Not Signed In*/)
return return
@ -237,7 +239,7 @@ func ReportHandler(w http.ResponseWriter, r *http.Request) {
report.ReportId = -1 report.ReportId = -1
report.UserId = user.UserId report.UserId = user.UserId
err = InsertReport(&report) err = InsertReport(db, &report)
if err != nil { if err != nil {
WriteError(w, 999 /*Internal Error*/) WriteError(w, 999 /*Internal Error*/)
log.Print(err) log.Print(err)
@ -260,7 +262,7 @@ func ReportHandler(w http.ResponseWriter, r *http.Request) {
log.Print(err) log.Print(err)
return return
} }
ReportTabulationHandler(w, r, user, reportid) ReportTabulationHandler(db, w, r, user, reportid)
return return
} }
@ -269,7 +271,7 @@ func ReportHandler(w http.ResponseWriter, r *http.Request) {
if err != nil || n != 1 { if err != nil || n != 1 {
//Return all Reports //Return all Reports
var rl ReportList var rl ReportList
reports, err := GetReports(user.UserId) reports, err := GetReports(db, user.UserId)
if err != nil { if err != nil {
WriteError(w, 999 /*Internal Error*/) WriteError(w, 999 /*Internal Error*/)
log.Print(err) log.Print(err)
@ -284,7 +286,7 @@ func ReportHandler(w http.ResponseWriter, r *http.Request) {
} }
} else { } else {
// Return Report with this Id // Return Report with this Id
report, err := GetReport(reportid, user.UserId) report, err := GetReport(db, reportid, user.UserId)
if err != nil { if err != nil {
WriteError(w, 3 /*Invalid Request*/) WriteError(w, 3 /*Invalid Request*/)
return return
@ -319,7 +321,7 @@ func ReportHandler(w http.ResponseWriter, r *http.Request) {
} }
report.UserId = user.UserId report.UserId = user.UserId
err = UpdateReport(&report) err = UpdateReport(db, &report)
if err != nil { if err != nil {
WriteError(w, 999 /*Internal Error*/) WriteError(w, 999 /*Internal Error*/)
log.Print(err) log.Print(err)
@ -333,13 +335,13 @@ func ReportHandler(w http.ResponseWriter, r *http.Request) {
return return
} }
} else if r.Method == "DELETE" { } else if r.Method == "DELETE" {
report, err := GetReport(reportid, user.UserId) report, err := GetReport(db, reportid, user.UserId)
if err != nil { if err != nil {
WriteError(w, 3 /*Invalid Request*/) WriteError(w, 3 /*Invalid Request*/)
return return
} }
err = DeleteReport(report) err = DeleteReport(db, report)
if err != nil { if err != nil {
WriteError(w, 999 /*Internal Error*/) WriteError(w, 999 /*Internal Error*/)
log.Print(err) log.Print(err)

View File

@ -96,10 +96,10 @@ func FindCurrencyTemplate(iso4217 int64) *Security {
return nil return nil
} }
func GetSecurity(securityid int64, userid int64) (*Security, error) { func GetSecurity(db *DB, securityid int64, userid int64) (*Security, error) {
var s Security var s Security
err := DB.SelectOne(&s, "SELECT * from securities where UserId=? AND SecurityId=?", userid, securityid) err := db.SelectOne(&s, "SELECT * from securities where UserId=? AND SecurityId=?", userid, securityid)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -116,18 +116,18 @@ func GetSecurityTx(transaction *gorp.Transaction, securityid int64, userid int64
return &s, nil return &s, nil
} }
func GetSecurities(userid int64) (*[]*Security, error) { func GetSecurities(db *DB, userid int64) (*[]*Security, error) {
var securities []*Security var securities []*Security
_, err := DB.Select(&securities, "SELECT * from securities where UserId=?", userid) _, err := db.Select(&securities, "SELECT * from securities where UserId=?", userid)
if err != nil { if err != nil {
return nil, err return nil, err
} }
return &securities, nil return &securities, nil
} }
func InsertSecurity(s *Security) error { func InsertSecurity(db *DB, s *Security) error {
err := DB.Insert(s) err := db.Insert(s)
if err != nil { if err != nil {
return err return err
} }
@ -142,8 +142,8 @@ func InsertSecurityTx(transaction *gorp.Transaction, s *Security) error {
return nil return nil
} }
func UpdateSecurity(s *Security) error { func UpdateSecurity(db *DB, s *Security) error {
transaction, err := DB.Begin() transaction, err := db.Begin()
if err != nil { if err != nil {
return err return err
} }
@ -176,8 +176,8 @@ func UpdateSecurity(s *Security) error {
return nil return nil
} }
func DeleteSecurity(s *Security) error { func DeleteSecurity(db *DB, s *Security) error {
transaction, err := DB.Begin() transaction, err := db.Begin()
if err != nil { if err != nil {
return err return err
} }
@ -279,8 +279,8 @@ func ImportGetCreateSecurity(transaction *gorp.Transaction, userid int64, securi
return security, nil return security, nil
} }
func SecurityHandler(w http.ResponseWriter, r *http.Request) { func SecurityHandler(w http.ResponseWriter, r *http.Request, db *DB) {
user, err := GetUserFromSession(r) user, err := GetUserFromSession(db, r)
if err != nil { if err != nil {
WriteError(w, 1 /*Not Signed In*/) WriteError(w, 1 /*Not Signed In*/)
return return
@ -302,7 +302,7 @@ func SecurityHandler(w http.ResponseWriter, r *http.Request) {
security.SecurityId = -1 security.SecurityId = -1
security.UserId = user.UserId security.UserId = user.UserId
err = InsertSecurity(&security) err = InsertSecurity(db, &security)
if err != nil { if err != nil {
WriteError(w, 999 /*Internal Error*/) WriteError(w, 999 /*Internal Error*/)
log.Print(err) log.Print(err)
@ -324,7 +324,7 @@ func SecurityHandler(w http.ResponseWriter, r *http.Request) {
//Return all securities //Return all securities
var sl SecurityList var sl SecurityList
securities, err := GetSecurities(user.UserId) securities, err := GetSecurities(db, user.UserId)
if err != nil { if err != nil {
WriteError(w, 999 /*Internal Error*/) WriteError(w, 999 /*Internal Error*/)
log.Print(err) log.Print(err)
@ -339,7 +339,7 @@ func SecurityHandler(w http.ResponseWriter, r *http.Request) {
return return
} }
} else { } else {
security, err := GetSecurity(securityid, user.UserId) security, err := GetSecurity(db, securityid, user.UserId)
if err != nil { if err != nil {
WriteError(w, 3 /*Invalid Request*/) WriteError(w, 3 /*Invalid Request*/)
return return
@ -373,7 +373,7 @@ func SecurityHandler(w http.ResponseWriter, r *http.Request) {
} }
security.UserId = user.UserId security.UserId = user.UserId
err = UpdateSecurity(&security) err = UpdateSecurity(db, &security)
if err != nil { if err != nil {
WriteError(w, 999 /*Internal Error*/) WriteError(w, 999 /*Internal Error*/)
log.Print(err) log.Print(err)
@ -387,13 +387,13 @@ func SecurityHandler(w http.ResponseWriter, r *http.Request) {
return return
} }
} else if r.Method == "DELETE" { } else if r.Method == "DELETE" {
security, err := GetSecurity(securityid, user.UserId) security, err := GetSecurity(db, securityid, user.UserId)
if err != nil { if err != nil {
WriteError(w, 3 /*Invalid Request*/) WriteError(w, 3 /*Invalid Request*/)
return return
} }
err = DeleteSecurity(security) err = DeleteSecurity(db, security)
if err != nil { if err != nil {
WriteError(w, 999 /*Internal Error*/) WriteError(w, 999 /*Internal Error*/)
log.Print(err) log.Print(err)

View File

@ -13,14 +13,19 @@ func luaContextGetSecurities(L *lua.LState) (map[int64]*Security, error) {
ctx := L.Context() ctx := L.Context()
security_map, ok := ctx.Value(securitiesContextKey).(map[int64]*Security) db, ok := ctx.Value(dbContextKey).(*DB)
if !ok {
return nil, errors.New("Couldn't find DB in lua's Context")
}
security_map, ok = ctx.Value(securitiesContextKey).(map[int64]*Security)
if !ok { if !ok {
user, ok := ctx.Value(userContextKey).(*User) user, ok := ctx.Value(userContextKey).(*User)
if !ok { if !ok {
return nil, errors.New("Couldn't find User in lua's Context") return nil, errors.New("Couldn't find User in lua's Context")
} }
securities, err := GetSecurities(user.UserId) securities, err := GetSecurities(db, user.UserId)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -149,7 +154,13 @@ func luaClosestPrice(L *lua.LState) int {
c := luaCheckSecurity(L, 2) c := luaCheckSecurity(L, 2)
date := luaCheckTime(L, 3) date := luaCheckTime(L, 3)
p, err := GetClosestPrice(s, c, date) ctx := L.Context()
db, ok := ctx.Value(dbContextKey).(*DB)
if !ok {
panic("Couldn't find DB in lua's Context")
}
p, err := GetClosestPrice(db, s, c, date)
if err != nil { if err != nil {
L.Push(lua.LNil) L.Push(lua.LNil)
} else { } else {

View File

@ -22,7 +22,7 @@ func (s *Session) Write(w http.ResponseWriter) error {
return enc.Encode(s) return enc.Encode(s)
} }
func GetSession(r *http.Request) (*Session, error) { func GetSession(db *DB, r *http.Request) (*Session, error) {
var s Session var s Session
cookie, err := r.Cookie("moneygo-session") cookie, err := r.Cookie("moneygo-session")
@ -31,17 +31,18 @@ func GetSession(r *http.Request) (*Session, error) {
} }
s.SessionSecret = cookie.Value s.SessionSecret = cookie.Value
err = DB.SelectOne(&s, "SELECT * from sessions where SessionSecret=?", s.SessionSecret) err = db.SelectOne(&s, "SELECT * from sessions where SessionSecret=?", s.SessionSecret)
if err != nil { if err != nil {
return nil, err return nil, err
} }
return &s, nil return &s, nil
} }
func DeleteSessionIfExists(r *http.Request) { func DeleteSessionIfExists(db *DB, r *http.Request) {
session, err := GetSession(r) // TODO do this in one transaction
session, err := GetSession(db, r)
if err == nil { if err == nil {
DB.Delete(session) db.Delete(session)
} }
} }
@ -53,7 +54,7 @@ func NewSessionCookie() (string, error) {
return base64.StdEncoding.EncodeToString(bits), nil return base64.StdEncoding.EncodeToString(bits), nil
} }
func NewSession(w http.ResponseWriter, r *http.Request, userid int64) (*Session, error) { func NewSession(db *DB, w http.ResponseWriter, r *http.Request, userid int64) (*Session, error) {
s := Session{} s := Session{}
session_secret, err := NewSessionCookie() session_secret, err := NewSessionCookie()
@ -75,14 +76,14 @@ func NewSession(w http.ResponseWriter, r *http.Request, userid int64) (*Session,
s.SessionSecret = session_secret s.SessionSecret = session_secret
s.UserId = userid s.UserId = userid
err = DB.Insert(&s) err = db.Insert(&s)
if err != nil { if err != nil {
return nil, err return nil, err
} }
return &s, nil return &s, nil
} }
func SessionHandler(w http.ResponseWriter, r *http.Request) { func SessionHandler(w http.ResponseWriter, r *http.Request, db *DB) {
if r.Method == "POST" || r.Method == "PUT" { if r.Method == "POST" || r.Method == "PUT" {
user_json := r.PostFormValue("user") user_json := r.PostFormValue("user")
if user_json == "" { if user_json == "" {
@ -97,7 +98,7 @@ func SessionHandler(w http.ResponseWriter, r *http.Request) {
return return
} }
dbuser, err := GetUserByUsername(user.Username) dbuser, err := GetUserByUsername(db, user.Username)
if err != nil { if err != nil {
WriteError(w, 2 /*Unauthorized Access*/) WriteError(w, 2 /*Unauthorized Access*/)
return return
@ -109,9 +110,9 @@ func SessionHandler(w http.ResponseWriter, r *http.Request) {
return return
} }
DeleteSessionIfExists(r) DeleteSessionIfExists(db, r)
session, err := NewSession(w, r, dbuser.UserId) session, err := NewSession(db, w, r, dbuser.UserId)
if err != nil { if err != nil {
WriteError(w, 999 /*Internal Error*/) WriteError(w, 999 /*Internal Error*/)
return return
@ -124,7 +125,7 @@ func SessionHandler(w http.ResponseWriter, r *http.Request) {
return return
} }
} else if r.Method == "GET" { } else if r.Method == "GET" {
s, err := GetSession(r) s, err := GetSession(db, r)
if err != nil { if err != nil {
WriteError(w, 1 /*Not Signed In*/) WriteError(w, 1 /*Not Signed In*/)
return return
@ -132,7 +133,7 @@ func SessionHandler(w http.ResponseWriter, r *http.Request) {
s.Write(w) s.Write(w)
} else if r.Method == "DELETE" { } else if r.Method == "DELETE" {
DeleteSessionIfExists(r) DeleteSessionIfExists(db, r)
WriteSuccess(w) WriteSuccess(w)
} }
} }

View File

@ -146,11 +146,7 @@ func (t *Transaction) GetImbalancesTx(transaction *gorp.Transaction) (map[int64]
if t.Splits[i].AccountId != -1 { if t.Splits[i].AccountId != -1 {
var err error var err error
var account *Account var account *Account
if transaction != nil {
account, err = GetAccountTx(transaction, t.Splits[i].AccountId, t.UserId) account, err = GetAccountTx(transaction, t.Splits[i].AccountId, t.UserId)
} else {
account, err = GetAccount(t.Splits[i].AccountId, t.UserId)
}
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -164,16 +160,12 @@ func (t *Transaction) GetImbalancesTx(transaction *gorp.Transaction) (map[int64]
return sums, nil return sums, nil
} }
func (t *Transaction) GetImbalances() (map[int64]big.Rat, error) {
return t.GetImbalancesTx(nil)
}
// Returns true if all securities contained in this transaction are balanced, // Returns true if all securities contained in this transaction are balanced,
// false otherwise // false otherwise
func (t *Transaction) Balanced() (bool, error) { func (t *Transaction) Balanced(transaction *gorp.Transaction) (bool, error) {
var zero big.Rat var zero big.Rat
sums, err := t.GetImbalances() sums, err := t.GetImbalancesTx(transaction)
if err != nil { if err != nil {
return false, err return false, err
} }
@ -186,21 +178,23 @@ func (t *Transaction) Balanced() (bool, error) {
return true, nil return true, nil
} }
func GetTransaction(transactionid int64, userid int64) (*Transaction, error) { func GetTransaction(db *DB, transactionid int64, userid int64) (*Transaction, error) {
var t Transaction var t Transaction
transaction, err := DB.Begin() transaction, err := db.Begin()
if err != nil { if err != nil {
return nil, err return nil, err
} }
err = transaction.SelectOne(&t, "SELECT * from transactions where UserId=? AND TransactionId=?", userid, transactionid) err = transaction.SelectOne(&t, "SELECT * from transactions where UserId=? AND TransactionId=?", userid, transactionid)
if err != nil { if err != nil {
transaction.Rollback()
return nil, err return nil, err
} }
_, err = transaction.Select(&t.Splits, "SELECT * from splits where TransactionId=?", transactionid) _, err = transaction.Select(&t.Splits, "SELECT * from splits where TransactionId=?", transactionid)
if err != nil { if err != nil {
transaction.Rollback()
return nil, err return nil, err
} }
@ -213,10 +207,10 @@ func GetTransaction(transactionid int64, userid int64) (*Transaction, error) {
return &t, nil return &t, nil
} }
func GetTransactions(userid int64) (*[]Transaction, error) { func GetTransactions(db *DB, userid int64) (*[]Transaction, error) {
var transactions []Transaction var transactions []Transaction
transaction, err := DB.Begin() transaction, err := db.Begin()
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -316,8 +310,8 @@ func InsertTransactionTx(transaction *gorp.Transaction, t *Transaction, user *Us
return nil return nil
} }
func InsertTransaction(t *Transaction, user *User) error { func InsertTransaction(db *DB, t *Transaction, user *User) error {
transaction, err := DB.Begin() transaction, err := db.Begin()
if err != nil { if err != nil {
return err return err
} }
@ -337,17 +331,11 @@ func InsertTransaction(t *Transaction, user *User) error {
return nil return nil
} }
func UpdateTransaction(t *Transaction, user *User) error { func UpdateTransactionTx(transaction *gorp.Transaction, t *Transaction, user *User) error {
transaction, err := DB.Begin()
if err != nil {
return err
}
var existing_splits []*Split var existing_splits []*Split
_, err = transaction.Select(&existing_splits, "SELECT * from splits where TransactionId=?", t.TransactionId) _, err := transaction.Select(&existing_splits, "SELECT * from splits where TransactionId=?", t.TransactionId)
if err != nil { if err != nil {
transaction.Rollback()
return err return err
} }
@ -367,11 +355,9 @@ func UpdateTransaction(t *Transaction, user *User) error {
if ok { if ok {
count, err := transaction.Update(t.Splits[i]) count, err := transaction.Update(t.Splits[i])
if err != nil { if err != nil {
transaction.Rollback()
return err return err
} }
if count != 1 { if count != 1 {
transaction.Rollback()
return errors.New("Updated more than one transaction split") return errors.New("Updated more than one transaction split")
} }
delete(s_map, t.Splits[i].SplitId) delete(s_map, t.Splits[i].SplitId)
@ -379,7 +365,6 @@ func UpdateTransaction(t *Transaction, user *User) error {
t.Splits[i].SplitId = -1 t.Splits[i].SplitId = -1
err := transaction.Insert(t.Splits[i]) err := transaction.Insert(t.Splits[i])
if err != nil { if err != nil {
transaction.Rollback()
return err return err
} }
} }
@ -397,7 +382,6 @@ func UpdateTransaction(t *Transaction, user *User) error {
if ok { if ok {
_, err := transaction.Delete(existing_splits[i]) _, err := transaction.Delete(existing_splits[i])
if err != nil { if err != nil {
transaction.Rollback()
return err return err
} }
} }
@ -410,31 +394,22 @@ func UpdateTransaction(t *Transaction, user *User) error {
} }
err = incrementAccountVersions(transaction, user, a_ids) err = incrementAccountVersions(transaction, user, a_ids)
if err != nil { if err != nil {
transaction.Rollback()
return err return err
} }
count, err := transaction.Update(t) count, err := transaction.Update(t)
if err != nil { if err != nil {
transaction.Rollback()
return err return err
} }
if count != 1 { if count != 1 {
transaction.Rollback()
return errors.New("Updated more than one transaction") return errors.New("Updated more than one transaction")
} }
err = transaction.Commit()
if err != nil {
transaction.Rollback()
return err
}
return nil return nil
} }
func DeleteTransaction(t *Transaction, user *User) error { func DeleteTransaction(db *DB, t *Transaction, user *User) error {
transaction, err := DB.Begin() transaction, err := db.Begin()
if err != nil { if err != nil {
return err return err
} }
@ -477,8 +452,8 @@ func DeleteTransaction(t *Transaction, user *User) error {
return nil return nil
} }
func TransactionHandler(w http.ResponseWriter, r *http.Request) { func TransactionHandler(w http.ResponseWriter, r *http.Request, db *DB) {
user, err := GetUserFromSession(r) user, err := GetUserFromSession(db, r)
if err != nil { if err != nil {
WriteError(w, 1 /*Not Signed In*/) WriteError(w, 1 /*Not Signed In*/)
return return
@ -500,27 +475,37 @@ func TransactionHandler(w http.ResponseWriter, r *http.Request) {
transaction.TransactionId = -1 transaction.TransactionId = -1
transaction.UserId = user.UserId transaction.UserId = user.UserId
balanced, err := transaction.Balanced() sqltx, err := db.Begin()
if err != nil { if err != nil {
WriteError(w, 999 /*Internal Error*/) WriteError(w, 999 /*Internal Error*/)
log.Print(err) log.Print(err)
return return
} }
balanced, err := transaction.Balanced(sqltx)
if err != nil {
sqltx.Rollback()
WriteError(w, 999 /*Internal Error*/)
log.Print(err)
return
}
if !transaction.Valid() || !balanced { if !transaction.Valid() || !balanced {
sqltx.Rollback()
WriteError(w, 3 /*Invalid Request*/) WriteError(w, 3 /*Invalid Request*/)
return return
} }
for i := range transaction.Splits { for i := range transaction.Splits {
transaction.Splits[i].SplitId = -1 transaction.Splits[i].SplitId = -1
_, err := GetAccount(transaction.Splits[i].AccountId, user.UserId) _, err := GetAccountTx(sqltx, transaction.Splits[i].AccountId, user.UserId)
if err != nil { if err != nil {
sqltx.Rollback()
WriteError(w, 3 /*Invalid Request*/) WriteError(w, 3 /*Invalid Request*/)
return return
} }
} }
err = InsertTransaction(&transaction, user) err = InsertTransactionTx(sqltx, &transaction, user)
if err != nil { if err != nil {
if _, ok := err.(AccountMissingError); ok { if _, ok := err.(AccountMissingError); ok {
WriteError(w, 3 /*Invalid Request*/) WriteError(w, 3 /*Invalid Request*/)
@ -528,6 +513,15 @@ func TransactionHandler(w http.ResponseWriter, r *http.Request) {
WriteError(w, 999 /*Internal Error*/) WriteError(w, 999 /*Internal Error*/)
log.Print(err) log.Print(err)
} }
sqltx.Rollback()
return
}
err = sqltx.Commit()
if err != nil {
sqltx.Rollback()
WriteError(w, 999 /*Internal Error*/)
log.Print(err)
return return
} }
@ -543,7 +537,7 @@ func TransactionHandler(w http.ResponseWriter, r *http.Request) {
if err != nil { if err != nil {
//Return all Transactions //Return all Transactions
var al TransactionList var al TransactionList
transactions, err := GetTransactions(user.UserId) transactions, err := GetTransactions(db, user.UserId)
if err != nil { if err != nil {
WriteError(w, 999 /*Internal Error*/) WriteError(w, 999 /*Internal Error*/)
log.Print(err) log.Print(err)
@ -558,7 +552,7 @@ func TransactionHandler(w http.ResponseWriter, r *http.Request) {
} }
} else { } else {
//Return Transaction with this Id //Return Transaction with this Id
transaction, err := GetTransaction(transactionid, user.UserId) transaction, err := GetTransaction(db, transactionid, user.UserId)
if err != nil { if err != nil {
WriteError(w, 3 /*Invalid Request*/) WriteError(w, 3 /*Invalid Request*/)
return return
@ -591,27 +585,46 @@ func TransactionHandler(w http.ResponseWriter, r *http.Request) {
} }
transaction.UserId = user.UserId transaction.UserId = user.UserId
balanced, err := transaction.Balanced() sqltx, err := db.Begin()
if err != nil { if err != nil {
WriteError(w, 999 /*Internal Error*/) WriteError(w, 999 /*Internal Error*/)
log.Print(err) log.Print(err)
return return
} }
balanced, err := transaction.Balanced(sqltx)
if err != nil {
sqltx.Rollback()
WriteError(w, 999 /*Internal Error*/)
log.Print(err)
return
}
if !transaction.Valid() || !balanced { if !transaction.Valid() || !balanced {
sqltx.Rollback()
WriteError(w, 3 /*Invalid Request*/) WriteError(w, 3 /*Invalid Request*/)
return return
} }
for i := range transaction.Splits { for i := range transaction.Splits {
_, err := GetAccount(transaction.Splits[i].AccountId, user.UserId) _, err := GetAccountTx(sqltx, transaction.Splits[i].AccountId, user.UserId)
if err != nil { if err != nil {
sqltx.Rollback()
WriteError(w, 3 /*Invalid Request*/) WriteError(w, 3 /*Invalid Request*/)
return return
} }
} }
err = UpdateTransaction(&transaction, user) err = UpdateTransactionTx(sqltx, &transaction, user)
if err != nil { if err != nil {
sqltx.Rollback()
WriteError(w, 999 /*Internal Error*/)
log.Print(err)
return
}
err = sqltx.Commit()
if err != nil {
sqltx.Rollback()
WriteError(w, 999 /*Internal Error*/) WriteError(w, 999 /*Internal Error*/)
log.Print(err) log.Print(err)
return return
@ -630,13 +643,13 @@ func TransactionHandler(w http.ResponseWriter, r *http.Request) {
return return
} }
transaction, err := GetTransaction(transactionid, user.UserId) transaction, err := GetTransaction(db, transactionid, user.UserId)
if err != nil { if err != nil {
WriteError(w, 3 /*Invalid Request*/) WriteError(w, 3 /*Invalid Request*/)
return return
} }
err = DeleteTransaction(transaction, user) err = DeleteTransaction(db, transaction, user)
if err != nil { if err != nil {
WriteError(w, 999 /*Internal Error*/) WriteError(w, 999 /*Internal Error*/)
log.Print(err) log.Print(err)
@ -672,9 +685,9 @@ func TransactionsBalanceDifference(transaction *gorp.Transaction, accountid int6
return &pageDifference, nil return &pageDifference, nil
} }
func GetAccountBalance(user *User, accountid int64) (*big.Rat, error) { func GetAccountBalance(db *DB, user *User, accountid int64) (*big.Rat, error) {
var splits []Split var splits []Split
transaction, err := DB.Begin() transaction, err := db.Begin()
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -707,9 +720,9 @@ func GetAccountBalance(user *User, accountid int64) (*big.Rat, error) {
} }
// Assumes accountid is valid and is owned by the current user // Assumes accountid is valid and is owned by the current user
func GetAccountBalanceDate(user *User, accountid int64, date *time.Time) (*big.Rat, error) { func GetAccountBalanceDate(db *DB, user *User, accountid int64, date *time.Time) (*big.Rat, error) {
var splits []Split var splits []Split
transaction, err := DB.Begin() transaction, err := db.Begin()
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -741,9 +754,9 @@ func GetAccountBalanceDate(user *User, accountid int64, date *time.Time) (*big.R
return &balance, nil return &balance, nil
} }
func GetAccountBalanceDateRange(user *User, accountid int64, begin, end *time.Time) (*big.Rat, error) { func GetAccountBalanceDateRange(db *DB, user *User, accountid int64, begin, end *time.Time) (*big.Rat, error) {
var splits []Split var splits []Split
transaction, err := DB.Begin() transaction, err := db.Begin()
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -775,11 +788,11 @@ func GetAccountBalanceDateRange(user *User, accountid int64, begin, end *time.Ti
return &balance, nil return &balance, nil
} }
func GetAccountTransactions(user *User, accountid int64, sort string, page uint64, limit uint64) (*AccountTransactionsList, error) { func GetAccountTransactions(db *DB, user *User, accountid int64, sort string, page uint64, limit uint64) (*AccountTransactionsList, error) {
var transactions []Transaction var transactions []Transaction
var atl AccountTransactionsList var atl AccountTransactionsList
transaction, err := DB.Begin() transaction, err := db.Begin()
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -878,7 +891,7 @@ func GetAccountTransactions(user *User, accountid int64, sort string, page uint6
// Return only those transactions which have at least one split pertaining to // Return only those transactions which have at least one split pertaining to
// an account // an account
func AccountTransactionsHandler(w http.ResponseWriter, r *http.Request, func AccountTransactionsHandler(db *DB, w http.ResponseWriter, r *http.Request,
user *User, accountid int64) { user *User, accountid int64) {
var page uint64 = 0 var page uint64 = 0
@ -916,7 +929,7 @@ func AccountTransactionsHandler(w http.ResponseWriter, r *http.Request,
sort = sortstring sort = sortstring
} }
accountTransactions, err := GetAccountTransactions(user, accountid, sort, page, limit) accountTransactions, err := GetAccountTransactions(db, user, accountid, sort, page, limit)
if err != nil { if err != nil {
WriteError(w, 999 /*Internal Error*/) WriteError(w, 999 /*Internal Error*/)
log.Print(err) log.Print(err)

View File

@ -47,10 +47,10 @@ func (u *User) HashPassword() {
u.Password = "" u.Password = ""
} }
func GetUser(userid int64) (*User, error) { func GetUser(db *DB, userid int64) (*User, error) {
var u User var u User
err := DB.SelectOne(&u, "SELECT * from users where UserId=?", userid) err := db.SelectOne(&u, "SELECT * from users where UserId=?", userid)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -67,18 +67,18 @@ func GetUserTx(transaction *gorp.Transaction, userid int64) (*User, error) {
return &u, nil return &u, nil
} }
func GetUserByUsername(username string) (*User, error) { func GetUserByUsername(db *DB, username string) (*User, error) {
var u User var u User
err := DB.SelectOne(&u, "SELECT * from users where Username=?", username) err := db.SelectOne(&u, "SELECT * from users where Username=?", username)
if err != nil { if err != nil {
return nil, err return nil, err
} }
return &u, nil return &u, nil
} }
func InsertUser(u *User) error { func InsertUser(db *DB, u *User) error {
transaction, err := DB.Begin() transaction, err := db.Begin()
if err != nil { if err != nil {
return err return err
} }
@ -136,16 +136,16 @@ func InsertUser(u *User) error {
return nil return nil
} }
func GetUserFromSession(r *http.Request) (*User, error) { func GetUserFromSession(db *DB, r *http.Request) (*User, error) {
s, err := GetSession(r) s, err := GetSession(db, r)
if err != nil { if err != nil {
return nil, err return nil, err
} }
return GetUser(s.UserId) return GetUser(db, s.UserId)
} }
func UpdateUser(u *User) error { func UpdateUser(db *DB, u *User) error {
transaction, err := DB.Begin() transaction, err := db.Begin()
if err != nil { if err != nil {
return err return err
} }
@ -180,7 +180,7 @@ func UpdateUser(u *User) error {
return nil return nil
} }
func UserHandler(w http.ResponseWriter, r *http.Request) { func UserHandler(w http.ResponseWriter, r *http.Request, db *DB) {
if r.Method == "POST" { if r.Method == "POST" {
user_json := r.PostFormValue("user") user_json := r.PostFormValue("user")
if user_json == "" { if user_json == "" {
@ -197,7 +197,7 @@ func UserHandler(w http.ResponseWriter, r *http.Request) {
user.UserId = -1 user.UserId = -1
user.HashPassword() user.HashPassword()
err = InsertUser(&user) err = InsertUser(db, &user)
if err != nil { if err != nil {
if _, ok := err.(UserExistsError); ok { if _, ok := err.(UserExistsError); ok {
WriteError(w, 4 /*User Exists*/) WriteError(w, 4 /*User Exists*/)
@ -216,7 +216,7 @@ func UserHandler(w http.ResponseWriter, r *http.Request) {
return return
} }
} else { } else {
user, err := GetUserFromSession(r) user, err := GetUserFromSession(db, r)
if err != nil { if err != nil {
WriteError(w, 1 /*Not Signed In*/) WriteError(w, 1 /*Not Signed In*/)
return return
@ -264,7 +264,7 @@ func UserHandler(w http.ResponseWriter, r *http.Request) {
user.PasswordHash = old_pwhash user.PasswordHash = old_pwhash
} }
err = UpdateUser(user) err = UpdateUser(db, user)
if err != nil { if err != nil {
WriteError(w, 999 /*Internal Error*/) WriteError(w, 999 /*Internal Error*/)
log.Print(err) log.Print(err)
@ -278,7 +278,7 @@ func UserHandler(w http.ResponseWriter, r *http.Request) {
return return
} }
} else if r.Method == "DELETE" { } else if r.Method == "DELETE" {
count, err := DB.Delete(&user) count, err := db.Delete(&user)
if count != 1 || err != nil { if count != 1 || err != nil {
WriteError(w, 999 /*Internal Error*/) WriteError(w, 999 /*Internal Error*/)
log.Print(err) log.Print(err)