mirror of
https://github.com/aclindsa/moneygo.git
synced 2024-12-27 07:52:28 -05:00
Pass DB as a closure instead of a global variable
This is part of an ongoing attempt to restructure the code to make it more 'testable'.
This commit is contained in:
parent
9abafa50b2
commit
156b9aaf0c
48
accounts.go
48
accounts.go
@ -124,10 +124,10 @@ func (al *AccountList) Write(w http.ResponseWriter) error {
|
|||||||
return enc.Encode(al)
|
return enc.Encode(al)
|
||||||
}
|
}
|
||||||
|
|
||||||
func GetAccount(accountid int64, userid int64) (*Account, error) {
|
func GetAccount(db *DB, accountid int64, userid int64) (*Account, error) {
|
||||||
var a Account
|
var a Account
|
||||||
|
|
||||||
err := DB.SelectOne(&a, "SELECT * from accounts where UserId=? AND AccountId=?", userid, accountid)
|
err := db.SelectOne(&a, "SELECT * from accounts where UserId=? AND AccountId=?", userid, accountid)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@ -145,10 +145,10 @@ func GetAccountTx(transaction *gorp.Transaction, accountid int64, userid int64)
|
|||||||
return &a, nil
|
return &a, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func GetAccounts(userid int64) (*[]Account, error) {
|
func GetAccounts(db *DB, userid int64) (*[]Account, error) {
|
||||||
var accounts []Account
|
var accounts []Account
|
||||||
|
|
||||||
_, err := DB.Select(&accounts, "SELECT * from accounts where UserId=?", userid)
|
_, err := db.Select(&accounts, "SELECT * from accounts where UserId=?", userid)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@ -276,8 +276,8 @@ func (pame ParentAccountMissingError) Error() string {
|
|||||||
return "Parent account missing"
|
return "Parent account missing"
|
||||||
}
|
}
|
||||||
|
|
||||||
func insertUpdateAccount(a *Account, insert bool) error {
|
func insertUpdateAccount(db *DB, a *Account, insert bool) error {
|
||||||
transaction, err := DB.Begin()
|
transaction, err := db.Begin()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@ -329,16 +329,16 @@ func insertUpdateAccount(a *Account, insert bool) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func InsertAccount(a *Account) error {
|
func InsertAccount(db *DB, a *Account) error {
|
||||||
return insertUpdateAccount(a, true)
|
return insertUpdateAccount(db, a, true)
|
||||||
}
|
}
|
||||||
|
|
||||||
func UpdateAccount(a *Account) error {
|
func UpdateAccount(db *DB, a *Account) error {
|
||||||
return insertUpdateAccount(a, false)
|
return insertUpdateAccount(db, a, false)
|
||||||
}
|
}
|
||||||
|
|
||||||
func DeleteAccount(a *Account) error {
|
func DeleteAccount(db *DB, a *Account) error {
|
||||||
transaction, err := DB.Begin()
|
transaction, err := db.Begin()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@ -385,8 +385,8 @@ func DeleteAccount(a *Account) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func AccountHandler(w http.ResponseWriter, r *http.Request) {
|
func AccountHandler(w http.ResponseWriter, r *http.Request, db *DB) {
|
||||||
user, err := GetUserFromSession(r)
|
user, err := GetUserFromSession(db, r)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
WriteError(w, 1 /*Not Signed In*/)
|
WriteError(w, 1 /*Not Signed In*/)
|
||||||
return
|
return
|
||||||
@ -405,7 +405,7 @@ func AccountHandler(w http.ResponseWriter, r *http.Request) {
|
|||||||
log.Print(err)
|
log.Print(err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
AccountImportHandler(w, r, user, accountid, importtype)
|
AccountImportHandler(db, w, r, user, accountid, importtype)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -425,7 +425,7 @@ func AccountHandler(w http.ResponseWriter, r *http.Request) {
|
|||||||
account.UserId = user.UserId
|
account.UserId = user.UserId
|
||||||
account.AccountVersion = 0
|
account.AccountVersion = 0
|
||||||
|
|
||||||
security, err := GetSecurity(account.SecurityId, user.UserId)
|
security, err := GetSecurity(db, account.SecurityId, user.UserId)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
WriteError(w, 999 /*Internal Error*/)
|
WriteError(w, 999 /*Internal Error*/)
|
||||||
log.Print(err)
|
log.Print(err)
|
||||||
@ -436,7 +436,7 @@ func AccountHandler(w http.ResponseWriter, r *http.Request) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
err = InsertAccount(&account)
|
err = InsertAccount(db, &account)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if _, ok := err.(ParentAccountMissingError); ok {
|
if _, ok := err.(ParentAccountMissingError); ok {
|
||||||
WriteError(w, 3 /*Invalid Request*/)
|
WriteError(w, 3 /*Invalid Request*/)
|
||||||
@ -461,7 +461,7 @@ func AccountHandler(w http.ResponseWriter, r *http.Request) {
|
|||||||
if err != nil || n != 1 {
|
if err != nil || n != 1 {
|
||||||
//Return all Accounts
|
//Return all Accounts
|
||||||
var al AccountList
|
var al AccountList
|
||||||
accounts, err := GetAccounts(user.UserId)
|
accounts, err := GetAccounts(db, user.UserId)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
WriteError(w, 999 /*Internal Error*/)
|
WriteError(w, 999 /*Internal Error*/)
|
||||||
log.Print(err)
|
log.Print(err)
|
||||||
@ -478,12 +478,12 @@ func AccountHandler(w http.ResponseWriter, r *http.Request) {
|
|||||||
// if URL looks like /account/[0-9]+/transactions, use the account
|
// if URL looks like /account/[0-9]+/transactions, use the account
|
||||||
// transaction handler
|
// transaction handler
|
||||||
if accountTransactionsRE.MatchString(r.URL.Path) {
|
if accountTransactionsRE.MatchString(r.URL.Path) {
|
||||||
AccountTransactionsHandler(w, r, user, accountid)
|
AccountTransactionsHandler(db, w, r, user, accountid)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// Return Account with this Id
|
// Return Account with this Id
|
||||||
account, err := GetAccount(accountid, user.UserId)
|
account, err := GetAccount(db, accountid, user.UserId)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
WriteError(w, 3 /*Invalid Request*/)
|
WriteError(w, 3 /*Invalid Request*/)
|
||||||
return
|
return
|
||||||
@ -517,7 +517,7 @@ func AccountHandler(w http.ResponseWriter, r *http.Request) {
|
|||||||
}
|
}
|
||||||
account.UserId = user.UserId
|
account.UserId = user.UserId
|
||||||
|
|
||||||
security, err := GetSecurity(account.SecurityId, user.UserId)
|
security, err := GetSecurity(db, account.SecurityId, user.UserId)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
WriteError(w, 999 /*Internal Error*/)
|
WriteError(w, 999 /*Internal Error*/)
|
||||||
log.Print(err)
|
log.Print(err)
|
||||||
@ -528,7 +528,7 @@ func AccountHandler(w http.ResponseWriter, r *http.Request) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
err = UpdateAccount(&account)
|
err = UpdateAccount(db, &account)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
WriteError(w, 999 /*Internal Error*/)
|
WriteError(w, 999 /*Internal Error*/)
|
||||||
log.Print(err)
|
log.Print(err)
|
||||||
@ -542,13 +542,13 @@ func AccountHandler(w http.ResponseWriter, r *http.Request) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
} else if r.Method == "DELETE" {
|
} else if r.Method == "DELETE" {
|
||||||
account, err := GetAccount(accountid, user.UserId)
|
account, err := GetAccount(db, accountid, user.UserId)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
WriteError(w, 3 /*Invalid Request*/)
|
WriteError(w, 3 /*Invalid Request*/)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
err = DeleteAccount(account)
|
err = DeleteAccount(db, account)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
WriteError(w, 999 /*Internal Error*/)
|
WriteError(w, 999 /*Internal Error*/)
|
||||||
log.Print(err)
|
log.Print(err)
|
||||||
|
@ -15,14 +15,19 @@ func luaContextGetAccounts(L *lua.LState) (map[int64]*Account, error) {
|
|||||||
|
|
||||||
ctx := L.Context()
|
ctx := L.Context()
|
||||||
|
|
||||||
account_map, ok := ctx.Value(accountsContextKey).(map[int64]*Account)
|
db, ok := ctx.Value(dbContextKey).(*DB)
|
||||||
|
if !ok {
|
||||||
|
return nil, errors.New("Couldn't find DB in lua's Context")
|
||||||
|
}
|
||||||
|
|
||||||
|
account_map, ok = ctx.Value(accountsContextKey).(map[int64]*Account)
|
||||||
if !ok {
|
if !ok {
|
||||||
user, ok := ctx.Value(userContextKey).(*User)
|
user, ok := ctx.Value(userContextKey).(*User)
|
||||||
if !ok {
|
if !ok {
|
||||||
return nil, errors.New("Couldn't find User in lua's Context")
|
return nil, errors.New("Couldn't find User in lua's Context")
|
||||||
}
|
}
|
||||||
|
|
||||||
accounts, err := GetAccounts(user.UserId)
|
accounts, err := GetAccounts(db, user.UserId)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@ -144,6 +149,10 @@ func luaAccountBalance(L *lua.LState) int {
|
|||||||
a := luaCheckAccount(L, 1)
|
a := luaCheckAccount(L, 1)
|
||||||
|
|
||||||
ctx := L.Context()
|
ctx := L.Context()
|
||||||
|
db, ok := ctx.Value(dbContextKey).(*DB)
|
||||||
|
if !ok {
|
||||||
|
panic("Couldn't find DB in lua's Context")
|
||||||
|
}
|
||||||
user, ok := ctx.Value(userContextKey).(*User)
|
user, ok := ctx.Value(userContextKey).(*User)
|
||||||
if !ok {
|
if !ok {
|
||||||
panic("Couldn't find User in lua's Context")
|
panic("Couldn't find User in lua's Context")
|
||||||
@ -162,12 +171,12 @@ func luaAccountBalance(L *lua.LState) int {
|
|||||||
if date != nil {
|
if date != nil {
|
||||||
end := luaWeakCheckTime(L, 3)
|
end := luaWeakCheckTime(L, 3)
|
||||||
if end != nil {
|
if end != nil {
|
||||||
rat, err = GetAccountBalanceDateRange(user, a.AccountId, date, end)
|
rat, err = GetAccountBalanceDateRange(db, user, a.AccountId, date, end)
|
||||||
} else {
|
} else {
|
||||||
rat, err = GetAccountBalanceDate(user, a.AccountId, date)
|
rat, err = GetAccountBalanceDate(db, user, a.AccountId, date)
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
rat, err = GetAccountBalance(user, a.AccountId)
|
rat, err = GetAccountBalance(db, user, a.AccountId)
|
||||||
}
|
}
|
||||||
if err != nil {
|
if err != nil {
|
||||||
panic("Failed to GetAccountBalance:" + err.Error())
|
panic("Failed to GetAccountBalance:" + err.Error())
|
||||||
|
14
db.go
14
db.go
@ -2,19 +2,17 @@ package main
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"database/sql"
|
"database/sql"
|
||||||
|
"fmt"
|
||||||
_ "github.com/go-sql-driver/mysql"
|
_ "github.com/go-sql-driver/mysql"
|
||||||
_ "github.com/lib/pq"
|
_ "github.com/lib/pq"
|
||||||
_ "github.com/mattn/go-sqlite3"
|
_ "github.com/mattn/go-sqlite3"
|
||||||
"gopkg.in/gorp.v1"
|
"gopkg.in/gorp.v1"
|
||||||
"log"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
var DB *gorp.DbMap
|
func initDB(cfg *Config) (*gorp.DbMap, error) {
|
||||||
|
|
||||||
func initDB(cfg *Config) {
|
|
||||||
db, err := sql.Open(cfg.MoneyGo.DBType.String(), cfg.MoneyGo.DSN)
|
db, err := sql.Open(cfg.MoneyGo.DBType.String(), cfg.MoneyGo.DSN)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Fatal(err)
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
var dialect gorp.Dialect
|
var dialect gorp.Dialect
|
||||||
@ -28,7 +26,7 @@ func initDB(cfg *Config) {
|
|||||||
} else if cfg.MoneyGo.DBType == Postgres {
|
} else if cfg.MoneyGo.DBType == Postgres {
|
||||||
dialect = gorp.PostgresDialect{}
|
dialect = gorp.PostgresDialect{}
|
||||||
} else {
|
} else {
|
||||||
log.Fatalf("Don't know gorp dialect to go with '%s' DB type", cfg.MoneyGo.DBType.String())
|
return nil, fmt.Errorf("Don't know gorp dialect to go with '%s' DB type", cfg.MoneyGo.DBType.String())
|
||||||
}
|
}
|
||||||
|
|
||||||
dbmap := &gorp.DbMap{Db: db, Dialect: dialect}
|
dbmap := &gorp.DbMap{Db: db, Dialect: dialect}
|
||||||
@ -43,8 +41,8 @@ func initDB(cfg *Config) {
|
|||||||
|
|
||||||
err = dbmap.CreateTablesIfNotExists()
|
err = dbmap.CreateTablesIfNotExists()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Fatal(err)
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
DB = dbmap
|
return dbmap, nil
|
||||||
}
|
}
|
||||||
|
@ -308,8 +308,8 @@ func ImportGnucash(r io.Reader) (*GnucashImport, error) {
|
|||||||
return &gncimport, nil
|
return &gncimport, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func GnucashImportHandler(w http.ResponseWriter, r *http.Request) {
|
func GnucashImportHandler(w http.ResponseWriter, r *http.Request, db *DB) {
|
||||||
user, err := GetUserFromSession(r)
|
user, err := GetUserFromSession(db, r)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
WriteError(w, 1 /*Not Signed In*/)
|
WriteError(w, 1 /*Not Signed In*/)
|
||||||
return
|
return
|
||||||
@ -365,7 +365,7 @@ func GnucashImportHandler(w http.ResponseWriter, r *http.Request) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
sqltransaction, err := DB.Begin()
|
sqltransaction, err := db.Begin()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
WriteError(w, 999 /*Internal Error*/)
|
WriteError(w, 999 /*Internal Error*/)
|
||||||
log.Print(err)
|
log.Print(err)
|
||||||
|
20
imports.go
20
imports.go
@ -22,7 +22,7 @@ func (od *OFXDownload) Read(json_str string) error {
|
|||||||
return dec.Decode(od)
|
return dec.Decode(od)
|
||||||
}
|
}
|
||||||
|
|
||||||
func ofxImportHelper(r io.Reader, w http.ResponseWriter, user *User, accountid int64) {
|
func ofxImportHelper(db *DB, r io.Reader, w http.ResponseWriter, user *User, accountid int64) {
|
||||||
itl, err := ImportOFX(r)
|
itl, err := ImportOFX(r)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -38,7 +38,7 @@ func ofxImportHelper(r io.Reader, w http.ResponseWriter, user *User, accountid i
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
sqltransaction, err := DB.Begin()
|
sqltransaction, err := db.Begin()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
WriteError(w, 999 /*Internal Error*/)
|
WriteError(w, 999 /*Internal Error*/)
|
||||||
log.Print(err)
|
log.Print(err)
|
||||||
@ -258,7 +258,7 @@ func ofxImportHelper(r io.Reader, w http.ResponseWriter, user *User, accountid i
|
|||||||
WriteSuccess(w)
|
WriteSuccess(w)
|
||||||
}
|
}
|
||||||
|
|
||||||
func OFXImportHandler(w http.ResponseWriter, r *http.Request, user *User, accountid int64) {
|
func OFXImportHandler(db *DB, w http.ResponseWriter, r *http.Request, user *User, accountid int64) {
|
||||||
download_json := r.PostFormValue("ofxdownload")
|
download_json := r.PostFormValue("ofxdownload")
|
||||||
if download_json == "" {
|
if download_json == "" {
|
||||||
log.Print("download_json")
|
log.Print("download_json")
|
||||||
@ -274,7 +274,7 @@ func OFXImportHandler(w http.ResponseWriter, r *http.Request, user *User, accoun
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
account, err := GetAccount(accountid, user.UserId)
|
account, err := GetAccount(db, accountid, user.UserId)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Print("GetAccount")
|
log.Print("GetAccount")
|
||||||
WriteError(w, 3 /*Invalid Request*/)
|
WriteError(w, 3 /*Invalid Request*/)
|
||||||
@ -367,10 +367,10 @@ func OFXImportHandler(w http.ResponseWriter, r *http.Request, user *User, accoun
|
|||||||
}
|
}
|
||||||
defer response.Body.Close()
|
defer response.Body.Close()
|
||||||
|
|
||||||
ofxImportHelper(response.Body, w, user, accountid)
|
ofxImportHelper(db, response.Body, w, user, accountid)
|
||||||
}
|
}
|
||||||
|
|
||||||
func OFXFileImportHandler(w http.ResponseWriter, r *http.Request, user *User, accountid int64) {
|
func OFXFileImportHandler(db *DB, w http.ResponseWriter, r *http.Request, user *User, accountid int64) {
|
||||||
multipartReader, err := r.MultipartReader()
|
multipartReader, err := r.MultipartReader()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
WriteError(w, 3 /*Invalid Request*/)
|
WriteError(w, 3 /*Invalid Request*/)
|
||||||
@ -390,19 +390,19 @@ func OFXFileImportHandler(w http.ResponseWriter, r *http.Request, user *User, ac
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
ofxImportHelper(part, w, user, accountid)
|
ofxImportHelper(db, part, w, user, accountid)
|
||||||
}
|
}
|
||||||
|
|
||||||
/*
|
/*
|
||||||
* Assumes the User is a valid, signed-in user, but accountid has not yet been validated
|
* Assumes the User is a valid, signed-in user, but accountid has not yet been validated
|
||||||
*/
|
*/
|
||||||
func AccountImportHandler(w http.ResponseWriter, r *http.Request, user *User, accountid int64, importtype string) {
|
func AccountImportHandler(db *DB, w http.ResponseWriter, r *http.Request, user *User, accountid int64, importtype string) {
|
||||||
|
|
||||||
switch importtype {
|
switch importtype {
|
||||||
case "ofx":
|
case "ofx":
|
||||||
OFXImportHandler(w, r, user, accountid)
|
OFXImportHandler(db, w, r, user, accountid)
|
||||||
case "ofxfile":
|
case "ofxfile":
|
||||||
OFXFileImportHandler(w, r, user, accountid)
|
OFXFileImportHandler(db, w, r, user, accountid)
|
||||||
default:
|
default:
|
||||||
WriteError(w, 3 /*Invalid Request*/)
|
WriteError(w, 3 /*Invalid Request*/)
|
||||||
}
|
}
|
||||||
|
44
main.go
44
main.go
@ -4,6 +4,7 @@ package main
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"flag"
|
"flag"
|
||||||
|
"gopkg.in/gorp.v1"
|
||||||
"log"
|
"log"
|
||||||
"net"
|
"net"
|
||||||
"net/http"
|
"net/http"
|
||||||
@ -43,8 +44,6 @@ func init() {
|
|||||||
|
|
||||||
// Setup the logging flags to be printed
|
// Setup the logging flags to be printed
|
||||||
log.SetFlags(log.Ldate | log.Ltime | log.Lshortfile)
|
log.SetFlags(log.Ldate | log.Ltime | log.Lshortfile)
|
||||||
|
|
||||||
initDB(config)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func rootHandler(w http.ResponseWriter, r *http.Request) {
|
func rootHandler(w http.ResponseWriter, r *http.Request) {
|
||||||
@ -55,18 +54,39 @@ func staticHandler(w http.ResponseWriter, r *http.Request) {
|
|||||||
http.ServeFile(w, r, path.Join(config.MoneyGo.Basedir, r.URL.Path))
|
http.ServeFile(w, r, path.Join(config.MoneyGo.Basedir, r.URL.Path))
|
||||||
}
|
}
|
||||||
|
|
||||||
func main() {
|
// Create a closure over db, allowing the handlers to look like a
|
||||||
|
// http.HandlerFunc
|
||||||
|
type DB = gorp.DbMap
|
||||||
|
type DBHandler func(http.ResponseWriter, *http.Request, *DB)
|
||||||
|
|
||||||
|
func DBHandlerFunc(h DBHandler, db *DB) http.HandlerFunc {
|
||||||
|
return func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
h(w, r, db)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func GetHandler(db *DB) http.Handler {
|
||||||
servemux := http.NewServeMux()
|
servemux := http.NewServeMux()
|
||||||
servemux.HandleFunc("/", rootHandler)
|
servemux.HandleFunc("/", rootHandler)
|
||||||
servemux.HandleFunc("/static/", staticHandler)
|
servemux.HandleFunc("/static/", staticHandler)
|
||||||
servemux.HandleFunc("/session/", SessionHandler)
|
servemux.HandleFunc("/session/", DBHandlerFunc(SessionHandler, db))
|
||||||
servemux.HandleFunc("/user/", UserHandler)
|
servemux.HandleFunc("/user/", DBHandlerFunc(UserHandler, db))
|
||||||
servemux.HandleFunc("/security/", SecurityHandler)
|
servemux.HandleFunc("/security/", DBHandlerFunc(SecurityHandler, db))
|
||||||
servemux.HandleFunc("/securitytemplate/", SecurityTemplateHandler)
|
servemux.HandleFunc("/securitytemplate/", SecurityTemplateHandler)
|
||||||
servemux.HandleFunc("/account/", AccountHandler)
|
servemux.HandleFunc("/account/", DBHandlerFunc(AccountHandler, db))
|
||||||
servemux.HandleFunc("/transaction/", TransactionHandler)
|
servemux.HandleFunc("/transaction/", DBHandlerFunc(TransactionHandler, db))
|
||||||
servemux.HandleFunc("/import/gnucash", GnucashImportHandler)
|
servemux.HandleFunc("/import/gnucash", DBHandlerFunc(GnucashImportHandler, db))
|
||||||
servemux.HandleFunc("/report/", ReportHandler)
|
servemux.HandleFunc("/report/", DBHandlerFunc(ReportHandler, db))
|
||||||
|
|
||||||
|
return servemux
|
||||||
|
}
|
||||||
|
|
||||||
|
func main() {
|
||||||
|
database, err := initDB(config)
|
||||||
|
if err != nil {
|
||||||
|
log.Fatal(err)
|
||||||
|
}
|
||||||
|
handler := GetHandler(database)
|
||||||
|
|
||||||
listener, err := net.Listen("tcp", ":"+strconv.Itoa(config.MoneyGo.Port))
|
listener, err := net.Listen("tcp", ":"+strconv.Itoa(config.MoneyGo.Port))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -75,8 +95,8 @@ func main() {
|
|||||||
|
|
||||||
log.Printf("Serving on port %d out of directory: %s", config.MoneyGo.Port, config.MoneyGo.Basedir)
|
log.Printf("Serving on port %d out of directory: %s", config.MoneyGo.Port, config.MoneyGo.Basedir)
|
||||||
if config.MoneyGo.Fcgi {
|
if config.MoneyGo.Fcgi {
|
||||||
fcgi.Serve(listener, servemux)
|
fcgi.Serve(listener, handler)
|
||||||
} else {
|
} else {
|
||||||
http.Serve(listener, servemux)
|
http.Serve(listener, handler)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
13
prices.go
13
prices.go
@ -1,8 +1,6 @@
|
|||||||
package main
|
package main
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
|
||||||
"github.com/FlashBoys/go-finance"
|
|
||||||
"gopkg.in/gorp.v1"
|
"gopkg.in/gorp.v1"
|
||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
@ -93,8 +91,8 @@ func GetClosestPriceTx(transaction *gorp.Transaction, security, currency *Securi
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func GetClosestPrice(security, currency *Security, date *time.Time) (*Price, error) {
|
func GetClosestPrice(db *DB, security, currency *Security, date *time.Time) (*Price, error) {
|
||||||
transaction, err := DB.Begin()
|
transaction, err := db.Begin()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@ -113,10 +111,3 @@ func GetClosestPrice(security, currency *Security, date *time.Time) (*Price, err
|
|||||||
|
|
||||||
return price, nil
|
return price, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func init() {
|
|
||||||
q, err := finance.GetQuote("BRK-A")
|
|
||||||
if err == nil {
|
|
||||||
fmt.Printf("%+v", q)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
48
reports.go
48
reports.go
@ -27,6 +27,7 @@ const (
|
|||||||
accountsContextKey
|
accountsContextKey
|
||||||
securitiesContextKey
|
securitiesContextKey
|
||||||
balanceContextKey
|
balanceContextKey
|
||||||
|
dbContextKey
|
||||||
)
|
)
|
||||||
|
|
||||||
const luaTimeoutSeconds time.Duration = 30 // maximum time a lua request can run for
|
const luaTimeoutSeconds time.Duration = 30 // maximum time a lua request can run for
|
||||||
@ -76,36 +77,36 @@ func (r *Tabulation) Write(w http.ResponseWriter) error {
|
|||||||
return enc.Encode(r)
|
return enc.Encode(r)
|
||||||
}
|
}
|
||||||
|
|
||||||
func GetReport(reportid int64, userid int64) (*Report, error) {
|
func GetReport(db *DB, reportid int64, userid int64) (*Report, error) {
|
||||||
var r Report
|
var r Report
|
||||||
|
|
||||||
err := DB.SelectOne(&r, "SELECT * from reports where UserId=? AND ReportId=?", userid, reportid)
|
err := db.SelectOne(&r, "SELECT * from reports where UserId=? AND ReportId=?", userid, reportid)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
return &r, nil
|
return &r, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func GetReports(userid int64) (*[]Report, error) {
|
func GetReports(db *DB, userid int64) (*[]Report, error) {
|
||||||
var reports []Report
|
var reports []Report
|
||||||
|
|
||||||
_, err := DB.Select(&reports, "SELECT * from reports where UserId=?", userid)
|
_, err := db.Select(&reports, "SELECT * from reports where UserId=?", userid)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
return &reports, nil
|
return &reports, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func InsertReport(r *Report) error {
|
func InsertReport(db *DB, r *Report) error {
|
||||||
err := DB.Insert(r)
|
err := db.Insert(r)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func UpdateReport(r *Report) error {
|
func UpdateReport(db *DB, r *Report) error {
|
||||||
count, err := DB.Update(r)
|
count, err := db.Update(r)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@ -115,8 +116,8 @@ func UpdateReport(r *Report) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func DeleteReport(r *Report) error {
|
func DeleteReport(db *DB, r *Report) error {
|
||||||
count, err := DB.Delete(r)
|
count, err := db.Delete(r)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@ -126,13 +127,14 @@ func DeleteReport(r *Report) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func runReport(user *User, report *Report) (*Tabulation, error) {
|
func runReport(db *DB, user *User, report *Report) (*Tabulation, error) {
|
||||||
// Create a new LState without opening the default libs for security
|
// Create a new LState without opening the default libs for security
|
||||||
L := lua.NewState(lua.Options{SkipOpenLibs: true})
|
L := lua.NewState(lua.Options{SkipOpenLibs: true})
|
||||||
defer L.Close()
|
defer L.Close()
|
||||||
|
|
||||||
// Create a new context holding the current user with a timeout
|
// Create a new context holding the current user with a timeout
|
||||||
ctx := context.WithValue(context.Background(), userContextKey, user)
|
ctx := context.WithValue(context.Background(), userContextKey, user)
|
||||||
|
ctx = context.WithValue(ctx, dbContextKey, db)
|
||||||
ctx, cancel := context.WithTimeout(ctx, luaTimeoutSeconds*time.Second)
|
ctx, cancel := context.WithTimeout(ctx, luaTimeoutSeconds*time.Second)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
L.SetContext(ctx)
|
L.SetContext(ctx)
|
||||||
@ -189,14 +191,14 @@ func runReport(user *User, report *Report) (*Tabulation, error) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func ReportTabulationHandler(w http.ResponseWriter, r *http.Request, user *User, reportid int64) {
|
func ReportTabulationHandler(db *DB, w http.ResponseWriter, r *http.Request, user *User, reportid int64) {
|
||||||
report, err := GetReport(reportid, user.UserId)
|
report, err := GetReport(db, reportid, user.UserId)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
WriteError(w, 3 /*Invalid Request*/)
|
WriteError(w, 3 /*Invalid Request*/)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
tabulation, err := runReport(user, report)
|
tabulation, err := runReport(db, user, report)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
// TODO handle different failure cases differently
|
// TODO handle different failure cases differently
|
||||||
log.Print("runReport returned:", err)
|
log.Print("runReport returned:", err)
|
||||||
@ -214,8 +216,8 @@ func ReportTabulationHandler(w http.ResponseWriter, r *http.Request, user *User,
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func ReportHandler(w http.ResponseWriter, r *http.Request) {
|
func ReportHandler(w http.ResponseWriter, r *http.Request, db *DB) {
|
||||||
user, err := GetUserFromSession(r)
|
user, err := GetUserFromSession(db, r)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
WriteError(w, 1 /*Not Signed In*/)
|
WriteError(w, 1 /*Not Signed In*/)
|
||||||
return
|
return
|
||||||
@ -237,7 +239,7 @@ func ReportHandler(w http.ResponseWriter, r *http.Request) {
|
|||||||
report.ReportId = -1
|
report.ReportId = -1
|
||||||
report.UserId = user.UserId
|
report.UserId = user.UserId
|
||||||
|
|
||||||
err = InsertReport(&report)
|
err = InsertReport(db, &report)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
WriteError(w, 999 /*Internal Error*/)
|
WriteError(w, 999 /*Internal Error*/)
|
||||||
log.Print(err)
|
log.Print(err)
|
||||||
@ -260,7 +262,7 @@ func ReportHandler(w http.ResponseWriter, r *http.Request) {
|
|||||||
log.Print(err)
|
log.Print(err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
ReportTabulationHandler(w, r, user, reportid)
|
ReportTabulationHandler(db, w, r, user, reportid)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -269,7 +271,7 @@ func ReportHandler(w http.ResponseWriter, r *http.Request) {
|
|||||||
if err != nil || n != 1 {
|
if err != nil || n != 1 {
|
||||||
//Return all Reports
|
//Return all Reports
|
||||||
var rl ReportList
|
var rl ReportList
|
||||||
reports, err := GetReports(user.UserId)
|
reports, err := GetReports(db, user.UserId)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
WriteError(w, 999 /*Internal Error*/)
|
WriteError(w, 999 /*Internal Error*/)
|
||||||
log.Print(err)
|
log.Print(err)
|
||||||
@ -284,7 +286,7 @@ func ReportHandler(w http.ResponseWriter, r *http.Request) {
|
|||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
// Return Report with this Id
|
// Return Report with this Id
|
||||||
report, err := GetReport(reportid, user.UserId)
|
report, err := GetReport(db, reportid, user.UserId)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
WriteError(w, 3 /*Invalid Request*/)
|
WriteError(w, 3 /*Invalid Request*/)
|
||||||
return
|
return
|
||||||
@ -319,7 +321,7 @@ func ReportHandler(w http.ResponseWriter, r *http.Request) {
|
|||||||
}
|
}
|
||||||
report.UserId = user.UserId
|
report.UserId = user.UserId
|
||||||
|
|
||||||
err = UpdateReport(&report)
|
err = UpdateReport(db, &report)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
WriteError(w, 999 /*Internal Error*/)
|
WriteError(w, 999 /*Internal Error*/)
|
||||||
log.Print(err)
|
log.Print(err)
|
||||||
@ -333,13 +335,13 @@ func ReportHandler(w http.ResponseWriter, r *http.Request) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
} else if r.Method == "DELETE" {
|
} else if r.Method == "DELETE" {
|
||||||
report, err := GetReport(reportid, user.UserId)
|
report, err := GetReport(db, reportid, user.UserId)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
WriteError(w, 3 /*Invalid Request*/)
|
WriteError(w, 3 /*Invalid Request*/)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
err = DeleteReport(report)
|
err = DeleteReport(db, report)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
WriteError(w, 999 /*Internal Error*/)
|
WriteError(w, 999 /*Internal Error*/)
|
||||||
log.Print(err)
|
log.Print(err)
|
||||||
|
@ -96,10 +96,10 @@ func FindCurrencyTemplate(iso4217 int64) *Security {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func GetSecurity(securityid int64, userid int64) (*Security, error) {
|
func GetSecurity(db *DB, securityid int64, userid int64) (*Security, error) {
|
||||||
var s Security
|
var s Security
|
||||||
|
|
||||||
err := DB.SelectOne(&s, "SELECT * from securities where UserId=? AND SecurityId=?", userid, securityid)
|
err := db.SelectOne(&s, "SELECT * from securities where UserId=? AND SecurityId=?", userid, securityid)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@ -116,18 +116,18 @@ func GetSecurityTx(transaction *gorp.Transaction, securityid int64, userid int64
|
|||||||
return &s, nil
|
return &s, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func GetSecurities(userid int64) (*[]*Security, error) {
|
func GetSecurities(db *DB, userid int64) (*[]*Security, error) {
|
||||||
var securities []*Security
|
var securities []*Security
|
||||||
|
|
||||||
_, err := DB.Select(&securities, "SELECT * from securities where UserId=?", userid)
|
_, err := db.Select(&securities, "SELECT * from securities where UserId=?", userid)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
return &securities, nil
|
return &securities, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func InsertSecurity(s *Security) error {
|
func InsertSecurity(db *DB, s *Security) error {
|
||||||
err := DB.Insert(s)
|
err := db.Insert(s)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@ -142,8 +142,8 @@ func InsertSecurityTx(transaction *gorp.Transaction, s *Security) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func UpdateSecurity(s *Security) error {
|
func UpdateSecurity(db *DB, s *Security) error {
|
||||||
transaction, err := DB.Begin()
|
transaction, err := db.Begin()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@ -176,8 +176,8 @@ func UpdateSecurity(s *Security) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func DeleteSecurity(s *Security) error {
|
func DeleteSecurity(db *DB, s *Security) error {
|
||||||
transaction, err := DB.Begin()
|
transaction, err := db.Begin()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@ -279,8 +279,8 @@ func ImportGetCreateSecurity(transaction *gorp.Transaction, userid int64, securi
|
|||||||
return security, nil
|
return security, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func SecurityHandler(w http.ResponseWriter, r *http.Request) {
|
func SecurityHandler(w http.ResponseWriter, r *http.Request, db *DB) {
|
||||||
user, err := GetUserFromSession(r)
|
user, err := GetUserFromSession(db, r)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
WriteError(w, 1 /*Not Signed In*/)
|
WriteError(w, 1 /*Not Signed In*/)
|
||||||
return
|
return
|
||||||
@ -302,7 +302,7 @@ func SecurityHandler(w http.ResponseWriter, r *http.Request) {
|
|||||||
security.SecurityId = -1
|
security.SecurityId = -1
|
||||||
security.UserId = user.UserId
|
security.UserId = user.UserId
|
||||||
|
|
||||||
err = InsertSecurity(&security)
|
err = InsertSecurity(db, &security)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
WriteError(w, 999 /*Internal Error*/)
|
WriteError(w, 999 /*Internal Error*/)
|
||||||
log.Print(err)
|
log.Print(err)
|
||||||
@ -324,7 +324,7 @@ func SecurityHandler(w http.ResponseWriter, r *http.Request) {
|
|||||||
//Return all securities
|
//Return all securities
|
||||||
var sl SecurityList
|
var sl SecurityList
|
||||||
|
|
||||||
securities, err := GetSecurities(user.UserId)
|
securities, err := GetSecurities(db, user.UserId)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
WriteError(w, 999 /*Internal Error*/)
|
WriteError(w, 999 /*Internal Error*/)
|
||||||
log.Print(err)
|
log.Print(err)
|
||||||
@ -339,7 +339,7 @@ func SecurityHandler(w http.ResponseWriter, r *http.Request) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
security, err := GetSecurity(securityid, user.UserId)
|
security, err := GetSecurity(db, securityid, user.UserId)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
WriteError(w, 3 /*Invalid Request*/)
|
WriteError(w, 3 /*Invalid Request*/)
|
||||||
return
|
return
|
||||||
@ -373,7 +373,7 @@ func SecurityHandler(w http.ResponseWriter, r *http.Request) {
|
|||||||
}
|
}
|
||||||
security.UserId = user.UserId
|
security.UserId = user.UserId
|
||||||
|
|
||||||
err = UpdateSecurity(&security)
|
err = UpdateSecurity(db, &security)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
WriteError(w, 999 /*Internal Error*/)
|
WriteError(w, 999 /*Internal Error*/)
|
||||||
log.Print(err)
|
log.Print(err)
|
||||||
@ -387,13 +387,13 @@ func SecurityHandler(w http.ResponseWriter, r *http.Request) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
} else if r.Method == "DELETE" {
|
} else if r.Method == "DELETE" {
|
||||||
security, err := GetSecurity(securityid, user.UserId)
|
security, err := GetSecurity(db, securityid, user.UserId)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
WriteError(w, 3 /*Invalid Request*/)
|
WriteError(w, 3 /*Invalid Request*/)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
err = DeleteSecurity(security)
|
err = DeleteSecurity(db, security)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
WriteError(w, 999 /*Internal Error*/)
|
WriteError(w, 999 /*Internal Error*/)
|
||||||
log.Print(err)
|
log.Print(err)
|
||||||
|
@ -13,14 +13,19 @@ func luaContextGetSecurities(L *lua.LState) (map[int64]*Security, error) {
|
|||||||
|
|
||||||
ctx := L.Context()
|
ctx := L.Context()
|
||||||
|
|
||||||
security_map, ok := ctx.Value(securitiesContextKey).(map[int64]*Security)
|
db, ok := ctx.Value(dbContextKey).(*DB)
|
||||||
|
if !ok {
|
||||||
|
return nil, errors.New("Couldn't find DB in lua's Context")
|
||||||
|
}
|
||||||
|
|
||||||
|
security_map, ok = ctx.Value(securitiesContextKey).(map[int64]*Security)
|
||||||
if !ok {
|
if !ok {
|
||||||
user, ok := ctx.Value(userContextKey).(*User)
|
user, ok := ctx.Value(userContextKey).(*User)
|
||||||
if !ok {
|
if !ok {
|
||||||
return nil, errors.New("Couldn't find User in lua's Context")
|
return nil, errors.New("Couldn't find User in lua's Context")
|
||||||
}
|
}
|
||||||
|
|
||||||
securities, err := GetSecurities(user.UserId)
|
securities, err := GetSecurities(db, user.UserId)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@ -149,7 +154,13 @@ func luaClosestPrice(L *lua.LState) int {
|
|||||||
c := luaCheckSecurity(L, 2)
|
c := luaCheckSecurity(L, 2)
|
||||||
date := luaCheckTime(L, 3)
|
date := luaCheckTime(L, 3)
|
||||||
|
|
||||||
p, err := GetClosestPrice(s, c, date)
|
ctx := L.Context()
|
||||||
|
db, ok := ctx.Value(dbContextKey).(*DB)
|
||||||
|
if !ok {
|
||||||
|
panic("Couldn't find DB in lua's Context")
|
||||||
|
}
|
||||||
|
|
||||||
|
p, err := GetClosestPrice(db, s, c, date)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
L.Push(lua.LNil)
|
L.Push(lua.LNil)
|
||||||
} else {
|
} else {
|
||||||
|
27
sessions.go
27
sessions.go
@ -22,7 +22,7 @@ func (s *Session) Write(w http.ResponseWriter) error {
|
|||||||
return enc.Encode(s)
|
return enc.Encode(s)
|
||||||
}
|
}
|
||||||
|
|
||||||
func GetSession(r *http.Request) (*Session, error) {
|
func GetSession(db *DB, r *http.Request) (*Session, error) {
|
||||||
var s Session
|
var s Session
|
||||||
|
|
||||||
cookie, err := r.Cookie("moneygo-session")
|
cookie, err := r.Cookie("moneygo-session")
|
||||||
@ -31,17 +31,18 @@ func GetSession(r *http.Request) (*Session, error) {
|
|||||||
}
|
}
|
||||||
s.SessionSecret = cookie.Value
|
s.SessionSecret = cookie.Value
|
||||||
|
|
||||||
err = DB.SelectOne(&s, "SELECT * from sessions where SessionSecret=?", s.SessionSecret)
|
err = db.SelectOne(&s, "SELECT * from sessions where SessionSecret=?", s.SessionSecret)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
return &s, nil
|
return &s, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func DeleteSessionIfExists(r *http.Request) {
|
func DeleteSessionIfExists(db *DB, r *http.Request) {
|
||||||
session, err := GetSession(r)
|
// TODO do this in one transaction
|
||||||
|
session, err := GetSession(db, r)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
DB.Delete(session)
|
db.Delete(session)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -53,7 +54,7 @@ func NewSessionCookie() (string, error) {
|
|||||||
return base64.StdEncoding.EncodeToString(bits), nil
|
return base64.StdEncoding.EncodeToString(bits), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewSession(w http.ResponseWriter, r *http.Request, userid int64) (*Session, error) {
|
func NewSession(db *DB, w http.ResponseWriter, r *http.Request, userid int64) (*Session, error) {
|
||||||
s := Session{}
|
s := Session{}
|
||||||
|
|
||||||
session_secret, err := NewSessionCookie()
|
session_secret, err := NewSessionCookie()
|
||||||
@ -75,14 +76,14 @@ func NewSession(w http.ResponseWriter, r *http.Request, userid int64) (*Session,
|
|||||||
s.SessionSecret = session_secret
|
s.SessionSecret = session_secret
|
||||||
s.UserId = userid
|
s.UserId = userid
|
||||||
|
|
||||||
err = DB.Insert(&s)
|
err = db.Insert(&s)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
return &s, nil
|
return &s, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func SessionHandler(w http.ResponseWriter, r *http.Request) {
|
func SessionHandler(w http.ResponseWriter, r *http.Request, db *DB) {
|
||||||
if r.Method == "POST" || r.Method == "PUT" {
|
if r.Method == "POST" || r.Method == "PUT" {
|
||||||
user_json := r.PostFormValue("user")
|
user_json := r.PostFormValue("user")
|
||||||
if user_json == "" {
|
if user_json == "" {
|
||||||
@ -97,7 +98,7 @@ func SessionHandler(w http.ResponseWriter, r *http.Request) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
dbuser, err := GetUserByUsername(user.Username)
|
dbuser, err := GetUserByUsername(db, user.Username)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
WriteError(w, 2 /*Unauthorized Access*/)
|
WriteError(w, 2 /*Unauthorized Access*/)
|
||||||
return
|
return
|
||||||
@ -109,9 +110,9 @@ func SessionHandler(w http.ResponseWriter, r *http.Request) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
DeleteSessionIfExists(r)
|
DeleteSessionIfExists(db, r)
|
||||||
|
|
||||||
session, err := NewSession(w, r, dbuser.UserId)
|
session, err := NewSession(db, w, r, dbuser.UserId)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
WriteError(w, 999 /*Internal Error*/)
|
WriteError(w, 999 /*Internal Error*/)
|
||||||
return
|
return
|
||||||
@ -124,7 +125,7 @@ func SessionHandler(w http.ResponseWriter, r *http.Request) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
} else if r.Method == "GET" {
|
} else if r.Method == "GET" {
|
||||||
s, err := GetSession(r)
|
s, err := GetSession(db, r)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
WriteError(w, 1 /*Not Signed In*/)
|
WriteError(w, 1 /*Not Signed In*/)
|
||||||
return
|
return
|
||||||
@ -132,7 +133,7 @@ func SessionHandler(w http.ResponseWriter, r *http.Request) {
|
|||||||
|
|
||||||
s.Write(w)
|
s.Write(w)
|
||||||
} else if r.Method == "DELETE" {
|
} else if r.Method == "DELETE" {
|
||||||
DeleteSessionIfExists(r)
|
DeleteSessionIfExists(db, r)
|
||||||
WriteSuccess(w)
|
WriteSuccess(w)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
135
transactions.go
135
transactions.go
@ -146,11 +146,7 @@ func (t *Transaction) GetImbalancesTx(transaction *gorp.Transaction) (map[int64]
|
|||||||
if t.Splits[i].AccountId != -1 {
|
if t.Splits[i].AccountId != -1 {
|
||||||
var err error
|
var err error
|
||||||
var account *Account
|
var account *Account
|
||||||
if transaction != nil {
|
|
||||||
account, err = GetAccountTx(transaction, t.Splits[i].AccountId, t.UserId)
|
account, err = GetAccountTx(transaction, t.Splits[i].AccountId, t.UserId)
|
||||||
} else {
|
|
||||||
account, err = GetAccount(t.Splits[i].AccountId, t.UserId)
|
|
||||||
}
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@ -164,16 +160,12 @@ func (t *Transaction) GetImbalancesTx(transaction *gorp.Transaction) (map[int64]
|
|||||||
return sums, nil
|
return sums, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *Transaction) GetImbalances() (map[int64]big.Rat, error) {
|
|
||||||
return t.GetImbalancesTx(nil)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Returns true if all securities contained in this transaction are balanced,
|
// Returns true if all securities contained in this transaction are balanced,
|
||||||
// false otherwise
|
// false otherwise
|
||||||
func (t *Transaction) Balanced() (bool, error) {
|
func (t *Transaction) Balanced(transaction *gorp.Transaction) (bool, error) {
|
||||||
var zero big.Rat
|
var zero big.Rat
|
||||||
|
|
||||||
sums, err := t.GetImbalances()
|
sums, err := t.GetImbalancesTx(transaction)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return false, err
|
return false, err
|
||||||
}
|
}
|
||||||
@ -186,21 +178,23 @@ func (t *Transaction) Balanced() (bool, error) {
|
|||||||
return true, nil
|
return true, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func GetTransaction(transactionid int64, userid int64) (*Transaction, error) {
|
func GetTransaction(db *DB, transactionid int64, userid int64) (*Transaction, error) {
|
||||||
var t Transaction
|
var t Transaction
|
||||||
|
|
||||||
transaction, err := DB.Begin()
|
transaction, err := db.Begin()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
err = transaction.SelectOne(&t, "SELECT * from transactions where UserId=? AND TransactionId=?", userid, transactionid)
|
err = transaction.SelectOne(&t, "SELECT * from transactions where UserId=? AND TransactionId=?", userid, transactionid)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
transaction.Rollback()
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
_, err = transaction.Select(&t.Splits, "SELECT * from splits where TransactionId=?", transactionid)
|
_, err = transaction.Select(&t.Splits, "SELECT * from splits where TransactionId=?", transactionid)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
transaction.Rollback()
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -213,10 +207,10 @@ func GetTransaction(transactionid int64, userid int64) (*Transaction, error) {
|
|||||||
return &t, nil
|
return &t, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func GetTransactions(userid int64) (*[]Transaction, error) {
|
func GetTransactions(db *DB, userid int64) (*[]Transaction, error) {
|
||||||
var transactions []Transaction
|
var transactions []Transaction
|
||||||
|
|
||||||
transaction, err := DB.Begin()
|
transaction, err := db.Begin()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@ -316,8 +310,8 @@ func InsertTransactionTx(transaction *gorp.Transaction, t *Transaction, user *Us
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func InsertTransaction(t *Transaction, user *User) error {
|
func InsertTransaction(db *DB, t *Transaction, user *User) error {
|
||||||
transaction, err := DB.Begin()
|
transaction, err := db.Begin()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@ -337,17 +331,11 @@ func InsertTransaction(t *Transaction, user *User) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func UpdateTransaction(t *Transaction, user *User) error {
|
func UpdateTransactionTx(transaction *gorp.Transaction, t *Transaction, user *User) error {
|
||||||
transaction, err := DB.Begin()
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
var existing_splits []*Split
|
var existing_splits []*Split
|
||||||
|
|
||||||
_, err = transaction.Select(&existing_splits, "SELECT * from splits where TransactionId=?", t.TransactionId)
|
_, err := transaction.Select(&existing_splits, "SELECT * from splits where TransactionId=?", t.TransactionId)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
transaction.Rollback()
|
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -367,11 +355,9 @@ func UpdateTransaction(t *Transaction, user *User) error {
|
|||||||
if ok {
|
if ok {
|
||||||
count, err := transaction.Update(t.Splits[i])
|
count, err := transaction.Update(t.Splits[i])
|
||||||
if err != nil {
|
if err != nil {
|
||||||
transaction.Rollback()
|
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
if count != 1 {
|
if count != 1 {
|
||||||
transaction.Rollback()
|
|
||||||
return errors.New("Updated more than one transaction split")
|
return errors.New("Updated more than one transaction split")
|
||||||
}
|
}
|
||||||
delete(s_map, t.Splits[i].SplitId)
|
delete(s_map, t.Splits[i].SplitId)
|
||||||
@ -379,7 +365,6 @@ func UpdateTransaction(t *Transaction, user *User) error {
|
|||||||
t.Splits[i].SplitId = -1
|
t.Splits[i].SplitId = -1
|
||||||
err := transaction.Insert(t.Splits[i])
|
err := transaction.Insert(t.Splits[i])
|
||||||
if err != nil {
|
if err != nil {
|
||||||
transaction.Rollback()
|
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -397,7 +382,6 @@ func UpdateTransaction(t *Transaction, user *User) error {
|
|||||||
if ok {
|
if ok {
|
||||||
_, err := transaction.Delete(existing_splits[i])
|
_, err := transaction.Delete(existing_splits[i])
|
||||||
if err != nil {
|
if err != nil {
|
||||||
transaction.Rollback()
|
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -410,31 +394,22 @@ func UpdateTransaction(t *Transaction, user *User) error {
|
|||||||
}
|
}
|
||||||
err = incrementAccountVersions(transaction, user, a_ids)
|
err = incrementAccountVersions(transaction, user, a_ids)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
transaction.Rollback()
|
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
count, err := transaction.Update(t)
|
count, err := transaction.Update(t)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
transaction.Rollback()
|
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
if count != 1 {
|
if count != 1 {
|
||||||
transaction.Rollback()
|
|
||||||
return errors.New("Updated more than one transaction")
|
return errors.New("Updated more than one transaction")
|
||||||
}
|
}
|
||||||
|
|
||||||
err = transaction.Commit()
|
|
||||||
if err != nil {
|
|
||||||
transaction.Rollback()
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func DeleteTransaction(t *Transaction, user *User) error {
|
func DeleteTransaction(db *DB, t *Transaction, user *User) error {
|
||||||
transaction, err := DB.Begin()
|
transaction, err := db.Begin()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@ -477,8 +452,8 @@ func DeleteTransaction(t *Transaction, user *User) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func TransactionHandler(w http.ResponseWriter, r *http.Request) {
|
func TransactionHandler(w http.ResponseWriter, r *http.Request, db *DB) {
|
||||||
user, err := GetUserFromSession(r)
|
user, err := GetUserFromSession(db, r)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
WriteError(w, 1 /*Not Signed In*/)
|
WriteError(w, 1 /*Not Signed In*/)
|
||||||
return
|
return
|
||||||
@ -500,27 +475,37 @@ func TransactionHandler(w http.ResponseWriter, r *http.Request) {
|
|||||||
transaction.TransactionId = -1
|
transaction.TransactionId = -1
|
||||||
transaction.UserId = user.UserId
|
transaction.UserId = user.UserId
|
||||||
|
|
||||||
balanced, err := transaction.Balanced()
|
sqltx, err := db.Begin()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
WriteError(w, 999 /*Internal Error*/)
|
WriteError(w, 999 /*Internal Error*/)
|
||||||
log.Print(err)
|
log.Print(err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
balanced, err := transaction.Balanced(sqltx)
|
||||||
|
if err != nil {
|
||||||
|
sqltx.Rollback()
|
||||||
|
WriteError(w, 999 /*Internal Error*/)
|
||||||
|
log.Print(err)
|
||||||
|
return
|
||||||
|
}
|
||||||
if !transaction.Valid() || !balanced {
|
if !transaction.Valid() || !balanced {
|
||||||
|
sqltx.Rollback()
|
||||||
WriteError(w, 3 /*Invalid Request*/)
|
WriteError(w, 3 /*Invalid Request*/)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
for i := range transaction.Splits {
|
for i := range transaction.Splits {
|
||||||
transaction.Splits[i].SplitId = -1
|
transaction.Splits[i].SplitId = -1
|
||||||
_, err := GetAccount(transaction.Splits[i].AccountId, user.UserId)
|
_, err := GetAccountTx(sqltx, transaction.Splits[i].AccountId, user.UserId)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
sqltx.Rollback()
|
||||||
WriteError(w, 3 /*Invalid Request*/)
|
WriteError(w, 3 /*Invalid Request*/)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
err = InsertTransaction(&transaction, user)
|
err = InsertTransactionTx(sqltx, &transaction, user)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if _, ok := err.(AccountMissingError); ok {
|
if _, ok := err.(AccountMissingError); ok {
|
||||||
WriteError(w, 3 /*Invalid Request*/)
|
WriteError(w, 3 /*Invalid Request*/)
|
||||||
@ -528,6 +513,15 @@ func TransactionHandler(w http.ResponseWriter, r *http.Request) {
|
|||||||
WriteError(w, 999 /*Internal Error*/)
|
WriteError(w, 999 /*Internal Error*/)
|
||||||
log.Print(err)
|
log.Print(err)
|
||||||
}
|
}
|
||||||
|
sqltx.Rollback()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
err = sqltx.Commit()
|
||||||
|
if err != nil {
|
||||||
|
sqltx.Rollback()
|
||||||
|
WriteError(w, 999 /*Internal Error*/)
|
||||||
|
log.Print(err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -543,7 +537,7 @@ func TransactionHandler(w http.ResponseWriter, r *http.Request) {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
//Return all Transactions
|
//Return all Transactions
|
||||||
var al TransactionList
|
var al TransactionList
|
||||||
transactions, err := GetTransactions(user.UserId)
|
transactions, err := GetTransactions(db, user.UserId)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
WriteError(w, 999 /*Internal Error*/)
|
WriteError(w, 999 /*Internal Error*/)
|
||||||
log.Print(err)
|
log.Print(err)
|
||||||
@ -558,7 +552,7 @@ func TransactionHandler(w http.ResponseWriter, r *http.Request) {
|
|||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
//Return Transaction with this Id
|
//Return Transaction with this Id
|
||||||
transaction, err := GetTransaction(transactionid, user.UserId)
|
transaction, err := GetTransaction(db, transactionid, user.UserId)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
WriteError(w, 3 /*Invalid Request*/)
|
WriteError(w, 3 /*Invalid Request*/)
|
||||||
return
|
return
|
||||||
@ -591,27 +585,46 @@ func TransactionHandler(w http.ResponseWriter, r *http.Request) {
|
|||||||
}
|
}
|
||||||
transaction.UserId = user.UserId
|
transaction.UserId = user.UserId
|
||||||
|
|
||||||
balanced, err := transaction.Balanced()
|
sqltx, err := db.Begin()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
WriteError(w, 999 /*Internal Error*/)
|
WriteError(w, 999 /*Internal Error*/)
|
||||||
log.Print(err)
|
log.Print(err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
balanced, err := transaction.Balanced(sqltx)
|
||||||
|
if err != nil {
|
||||||
|
sqltx.Rollback()
|
||||||
|
WriteError(w, 999 /*Internal Error*/)
|
||||||
|
log.Print(err)
|
||||||
|
return
|
||||||
|
}
|
||||||
if !transaction.Valid() || !balanced {
|
if !transaction.Valid() || !balanced {
|
||||||
|
sqltx.Rollback()
|
||||||
WriteError(w, 3 /*Invalid Request*/)
|
WriteError(w, 3 /*Invalid Request*/)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
for i := range transaction.Splits {
|
for i := range transaction.Splits {
|
||||||
_, err := GetAccount(transaction.Splits[i].AccountId, user.UserId)
|
_, err := GetAccountTx(sqltx, transaction.Splits[i].AccountId, user.UserId)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
sqltx.Rollback()
|
||||||
WriteError(w, 3 /*Invalid Request*/)
|
WriteError(w, 3 /*Invalid Request*/)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
err = UpdateTransaction(&transaction, user)
|
err = UpdateTransactionTx(sqltx, &transaction, user)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
sqltx.Rollback()
|
||||||
|
WriteError(w, 999 /*Internal Error*/)
|
||||||
|
log.Print(err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
err = sqltx.Commit()
|
||||||
|
if err != nil {
|
||||||
|
sqltx.Rollback()
|
||||||
WriteError(w, 999 /*Internal Error*/)
|
WriteError(w, 999 /*Internal Error*/)
|
||||||
log.Print(err)
|
log.Print(err)
|
||||||
return
|
return
|
||||||
@ -630,13 +643,13 @@ func TransactionHandler(w http.ResponseWriter, r *http.Request) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
transaction, err := GetTransaction(transactionid, user.UserId)
|
transaction, err := GetTransaction(db, transactionid, user.UserId)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
WriteError(w, 3 /*Invalid Request*/)
|
WriteError(w, 3 /*Invalid Request*/)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
err = DeleteTransaction(transaction, user)
|
err = DeleteTransaction(db, transaction, user)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
WriteError(w, 999 /*Internal Error*/)
|
WriteError(w, 999 /*Internal Error*/)
|
||||||
log.Print(err)
|
log.Print(err)
|
||||||
@ -672,9 +685,9 @@ func TransactionsBalanceDifference(transaction *gorp.Transaction, accountid int6
|
|||||||
return &pageDifference, nil
|
return &pageDifference, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func GetAccountBalance(user *User, accountid int64) (*big.Rat, error) {
|
func GetAccountBalance(db *DB, user *User, accountid int64) (*big.Rat, error) {
|
||||||
var splits []Split
|
var splits []Split
|
||||||
transaction, err := DB.Begin()
|
transaction, err := db.Begin()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@ -707,9 +720,9 @@ func GetAccountBalance(user *User, accountid int64) (*big.Rat, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Assumes accountid is valid and is owned by the current user
|
// Assumes accountid is valid and is owned by the current user
|
||||||
func GetAccountBalanceDate(user *User, accountid int64, date *time.Time) (*big.Rat, error) {
|
func GetAccountBalanceDate(db *DB, user *User, accountid int64, date *time.Time) (*big.Rat, error) {
|
||||||
var splits []Split
|
var splits []Split
|
||||||
transaction, err := DB.Begin()
|
transaction, err := db.Begin()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@ -741,9 +754,9 @@ func GetAccountBalanceDate(user *User, accountid int64, date *time.Time) (*big.R
|
|||||||
return &balance, nil
|
return &balance, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func GetAccountBalanceDateRange(user *User, accountid int64, begin, end *time.Time) (*big.Rat, error) {
|
func GetAccountBalanceDateRange(db *DB, user *User, accountid int64, begin, end *time.Time) (*big.Rat, error) {
|
||||||
var splits []Split
|
var splits []Split
|
||||||
transaction, err := DB.Begin()
|
transaction, err := db.Begin()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@ -775,11 +788,11 @@ func GetAccountBalanceDateRange(user *User, accountid int64, begin, end *time.Ti
|
|||||||
return &balance, nil
|
return &balance, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func GetAccountTransactions(user *User, accountid int64, sort string, page uint64, limit uint64) (*AccountTransactionsList, error) {
|
func GetAccountTransactions(db *DB, user *User, accountid int64, sort string, page uint64, limit uint64) (*AccountTransactionsList, error) {
|
||||||
var transactions []Transaction
|
var transactions []Transaction
|
||||||
var atl AccountTransactionsList
|
var atl AccountTransactionsList
|
||||||
|
|
||||||
transaction, err := DB.Begin()
|
transaction, err := db.Begin()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@ -878,7 +891,7 @@ func GetAccountTransactions(user *User, accountid int64, sort string, page uint6
|
|||||||
|
|
||||||
// Return only those transactions which have at least one split pertaining to
|
// Return only those transactions which have at least one split pertaining to
|
||||||
// an account
|
// an account
|
||||||
func AccountTransactionsHandler(w http.ResponseWriter, r *http.Request,
|
func AccountTransactionsHandler(db *DB, w http.ResponseWriter, r *http.Request,
|
||||||
user *User, accountid int64) {
|
user *User, accountid int64) {
|
||||||
|
|
||||||
var page uint64 = 0
|
var page uint64 = 0
|
||||||
@ -916,7 +929,7 @@ func AccountTransactionsHandler(w http.ResponseWriter, r *http.Request,
|
|||||||
sort = sortstring
|
sort = sortstring
|
||||||
}
|
}
|
||||||
|
|
||||||
accountTransactions, err := GetAccountTransactions(user, accountid, sort, page, limit)
|
accountTransactions, err := GetAccountTransactions(db, user, accountid, sort, page, limit)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
WriteError(w, 999 /*Internal Error*/)
|
WriteError(w, 999 /*Internal Error*/)
|
||||||
log.Print(err)
|
log.Print(err)
|
||||||
|
32
users.go
32
users.go
@ -47,10 +47,10 @@ func (u *User) HashPassword() {
|
|||||||
u.Password = ""
|
u.Password = ""
|
||||||
}
|
}
|
||||||
|
|
||||||
func GetUser(userid int64) (*User, error) {
|
func GetUser(db *DB, userid int64) (*User, error) {
|
||||||
var u User
|
var u User
|
||||||
|
|
||||||
err := DB.SelectOne(&u, "SELECT * from users where UserId=?", userid)
|
err := db.SelectOne(&u, "SELECT * from users where UserId=?", userid)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@ -67,18 +67,18 @@ func GetUserTx(transaction *gorp.Transaction, userid int64) (*User, error) {
|
|||||||
return &u, nil
|
return &u, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func GetUserByUsername(username string) (*User, error) {
|
func GetUserByUsername(db *DB, username string) (*User, error) {
|
||||||
var u User
|
var u User
|
||||||
|
|
||||||
err := DB.SelectOne(&u, "SELECT * from users where Username=?", username)
|
err := db.SelectOne(&u, "SELECT * from users where Username=?", username)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
return &u, nil
|
return &u, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func InsertUser(u *User) error {
|
func InsertUser(db *DB, u *User) error {
|
||||||
transaction, err := DB.Begin()
|
transaction, err := db.Begin()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@ -136,16 +136,16 @@ func InsertUser(u *User) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func GetUserFromSession(r *http.Request) (*User, error) {
|
func GetUserFromSession(db *DB, r *http.Request) (*User, error) {
|
||||||
s, err := GetSession(r)
|
s, err := GetSession(db, r)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
return GetUser(s.UserId)
|
return GetUser(db, s.UserId)
|
||||||
}
|
}
|
||||||
|
|
||||||
func UpdateUser(u *User) error {
|
func UpdateUser(db *DB, u *User) error {
|
||||||
transaction, err := DB.Begin()
|
transaction, err := db.Begin()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@ -180,7 +180,7 @@ func UpdateUser(u *User) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func UserHandler(w http.ResponseWriter, r *http.Request) {
|
func UserHandler(w http.ResponseWriter, r *http.Request, db *DB) {
|
||||||
if r.Method == "POST" {
|
if r.Method == "POST" {
|
||||||
user_json := r.PostFormValue("user")
|
user_json := r.PostFormValue("user")
|
||||||
if user_json == "" {
|
if user_json == "" {
|
||||||
@ -197,7 +197,7 @@ func UserHandler(w http.ResponseWriter, r *http.Request) {
|
|||||||
user.UserId = -1
|
user.UserId = -1
|
||||||
user.HashPassword()
|
user.HashPassword()
|
||||||
|
|
||||||
err = InsertUser(&user)
|
err = InsertUser(db, &user)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if _, ok := err.(UserExistsError); ok {
|
if _, ok := err.(UserExistsError); ok {
|
||||||
WriteError(w, 4 /*User Exists*/)
|
WriteError(w, 4 /*User Exists*/)
|
||||||
@ -216,7 +216,7 @@ func UserHandler(w http.ResponseWriter, r *http.Request) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
user, err := GetUserFromSession(r)
|
user, err := GetUserFromSession(db, r)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
WriteError(w, 1 /*Not Signed In*/)
|
WriteError(w, 1 /*Not Signed In*/)
|
||||||
return
|
return
|
||||||
@ -264,7 +264,7 @@ func UserHandler(w http.ResponseWriter, r *http.Request) {
|
|||||||
user.PasswordHash = old_pwhash
|
user.PasswordHash = old_pwhash
|
||||||
}
|
}
|
||||||
|
|
||||||
err = UpdateUser(user)
|
err = UpdateUser(db, user)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
WriteError(w, 999 /*Internal Error*/)
|
WriteError(w, 999 /*Internal Error*/)
|
||||||
log.Print(err)
|
log.Print(err)
|
||||||
@ -278,7 +278,7 @@ func UserHandler(w http.ResponseWriter, r *http.Request) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
} else if r.Method == "DELETE" {
|
} else if r.Method == "DELETE" {
|
||||||
count, err := DB.Delete(&user)
|
count, err := db.Delete(&user)
|
||||||
if count != 1 || err != nil {
|
if count != 1 || err != nil {
|
||||||
WriteError(w, 999 /*Internal Error*/)
|
WriteError(w, 999 /*Internal Error*/)
|
||||||
log.Print(err)
|
log.Print(err)
|
||||||
|
Loading…
Reference in New Issue
Block a user