Merge pull request #20 from aclindsa/db_updates

Db updates
This commit is contained in:
Aaron Lindsay 2017-10-04 08:08:12 -04:00 committed by GitHub
commit d9ddef250a
13 changed files with 267 additions and 208 deletions

View File

@ -124,10 +124,10 @@ func (al *AccountList) Write(w http.ResponseWriter) error {
return enc.Encode(al)
}
func GetAccount(accountid int64, userid int64) (*Account, error) {
func GetAccount(db *DB, accountid int64, userid int64) (*Account, error) {
var a Account
err := DB.SelectOne(&a, "SELECT * from accounts where UserId=? AND AccountId=?", userid, accountid)
err := db.SelectOne(&a, "SELECT * from accounts where UserId=? AND AccountId=?", userid, accountid)
if err != nil {
return nil, err
}
@ -145,10 +145,10 @@ func GetAccountTx(transaction *gorp.Transaction, accountid int64, userid int64)
return &a, nil
}
func GetAccounts(userid int64) (*[]Account, error) {
func GetAccounts(db *DB, userid int64) (*[]Account, error) {
var accounts []Account
_, err := DB.Select(&accounts, "SELECT * from accounts where UserId=?", userid)
_, err := db.Select(&accounts, "SELECT * from accounts where UserId=?", userid)
if err != nil {
return nil, err
}
@ -276,8 +276,8 @@ func (pame ParentAccountMissingError) Error() string {
return "Parent account missing"
}
func insertUpdateAccount(a *Account, insert bool) error {
transaction, err := DB.Begin()
func insertUpdateAccount(db *DB, a *Account, insert bool) error {
transaction, err := db.Begin()
if err != nil {
return err
}
@ -329,16 +329,16 @@ func insertUpdateAccount(a *Account, insert bool) error {
return nil
}
func InsertAccount(a *Account) error {
return insertUpdateAccount(a, true)
func InsertAccount(db *DB, a *Account) error {
return insertUpdateAccount(db, a, true)
}
func UpdateAccount(a *Account) error {
return insertUpdateAccount(a, false)
func UpdateAccount(db *DB, a *Account) error {
return insertUpdateAccount(db, a, false)
}
func DeleteAccount(a *Account) error {
transaction, err := DB.Begin()
func DeleteAccount(db *DB, a *Account) error {
transaction, err := db.Begin()
if err != nil {
return err
}
@ -385,8 +385,8 @@ func DeleteAccount(a *Account) error {
return nil
}
func AccountHandler(w http.ResponseWriter, r *http.Request) {
user, err := GetUserFromSession(r)
func AccountHandler(w http.ResponseWriter, r *http.Request, db *DB) {
user, err := GetUserFromSession(db, r)
if err != nil {
WriteError(w, 1 /*Not Signed In*/)
return
@ -405,7 +405,7 @@ func AccountHandler(w http.ResponseWriter, r *http.Request) {
log.Print(err)
return
}
AccountImportHandler(w, r, user, accountid, importtype)
AccountImportHandler(db, w, r, user, accountid, importtype)
return
}
@ -425,7 +425,7 @@ func AccountHandler(w http.ResponseWriter, r *http.Request) {
account.UserId = user.UserId
account.AccountVersion = 0
security, err := GetSecurity(account.SecurityId, user.UserId)
security, err := GetSecurity(db, account.SecurityId, user.UserId)
if err != nil {
WriteError(w, 999 /*Internal Error*/)
log.Print(err)
@ -436,7 +436,7 @@ func AccountHandler(w http.ResponseWriter, r *http.Request) {
return
}
err = InsertAccount(&account)
err = InsertAccount(db, &account)
if err != nil {
if _, ok := err.(ParentAccountMissingError); ok {
WriteError(w, 3 /*Invalid Request*/)
@ -461,7 +461,7 @@ func AccountHandler(w http.ResponseWriter, r *http.Request) {
if err != nil || n != 1 {
//Return all Accounts
var al AccountList
accounts, err := GetAccounts(user.UserId)
accounts, err := GetAccounts(db, user.UserId)
if err != nil {
WriteError(w, 999 /*Internal Error*/)
log.Print(err)
@ -478,12 +478,12 @@ func AccountHandler(w http.ResponseWriter, r *http.Request) {
// if URL looks like /account/[0-9]+/transactions, use the account
// transaction handler
if accountTransactionsRE.MatchString(r.URL.Path) {
AccountTransactionsHandler(w, r, user, accountid)
AccountTransactionsHandler(db, w, r, user, accountid)
return
}
// Return Account with this Id
account, err := GetAccount(accountid, user.UserId)
account, err := GetAccount(db, accountid, user.UserId)
if err != nil {
WriteError(w, 3 /*Invalid Request*/)
return
@ -517,7 +517,7 @@ func AccountHandler(w http.ResponseWriter, r *http.Request) {
}
account.UserId = user.UserId
security, err := GetSecurity(account.SecurityId, user.UserId)
security, err := GetSecurity(db, account.SecurityId, user.UserId)
if err != nil {
WriteError(w, 999 /*Internal Error*/)
log.Print(err)
@ -528,7 +528,7 @@ func AccountHandler(w http.ResponseWriter, r *http.Request) {
return
}
err = UpdateAccount(&account)
err = UpdateAccount(db, &account)
if err != nil {
WriteError(w, 999 /*Internal Error*/)
log.Print(err)
@ -542,13 +542,13 @@ func AccountHandler(w http.ResponseWriter, r *http.Request) {
return
}
} else if r.Method == "DELETE" {
account, err := GetAccount(accountid, user.UserId)
account, err := GetAccount(db, accountid, user.UserId)
if err != nil {
WriteError(w, 3 /*Invalid Request*/)
return
}
err = DeleteAccount(account)
err = DeleteAccount(db, account)
if err != nil {
WriteError(w, 999 /*Internal Error*/)
log.Print(err)

View File

@ -15,14 +15,19 @@ func luaContextGetAccounts(L *lua.LState) (map[int64]*Account, error) {
ctx := L.Context()
account_map, ok := ctx.Value(accountsContextKey).(map[int64]*Account)
db, ok := ctx.Value(dbContextKey).(*DB)
if !ok {
return nil, errors.New("Couldn't find DB in lua's Context")
}
account_map, ok = ctx.Value(accountsContextKey).(map[int64]*Account)
if !ok {
user, ok := ctx.Value(userContextKey).(*User)
if !ok {
return nil, errors.New("Couldn't find User in lua's Context")
}
accounts, err := GetAccounts(user.UserId)
accounts, err := GetAccounts(db, user.UserId)
if err != nil {
return nil, err
}
@ -144,6 +149,10 @@ func luaAccountBalance(L *lua.LState) int {
a := luaCheckAccount(L, 1)
ctx := L.Context()
db, ok := ctx.Value(dbContextKey).(*DB)
if !ok {
panic("Couldn't find DB in lua's Context")
}
user, ok := ctx.Value(userContextKey).(*User)
if !ok {
panic("Couldn't find User in lua's Context")
@ -162,12 +171,12 @@ func luaAccountBalance(L *lua.LState) int {
if date != nil {
end := luaWeakCheckTime(L, 3)
if end != nil {
rat, err = GetAccountBalanceDateRange(user, a.AccountId, date, end)
rat, err = GetAccountBalanceDateRange(db, user, a.AccountId, date, end)
} else {
rat, err = GetAccountBalanceDate(user, a.AccountId, date)
rat, err = GetAccountBalanceDate(db, user, a.AccountId, date)
}
} else {
rat, err = GetAccountBalance(user, a.AccountId)
rat, err = GetAccountBalance(db, user, a.AccountId)
}
if err != nil {
panic("Failed to GetAccountBalance:" + err.Error())

28
db.go
View File

@ -2,22 +2,34 @@ package main
import (
"database/sql"
"fmt"
_ "github.com/go-sql-driver/mysql"
_ "github.com/lib/pq"
_ "github.com/mattn/go-sqlite3"
"gopkg.in/gorp.v1"
"log"
)
var DB *gorp.DbMap
func initDB(cfg *Config) {
func initDB(cfg *Config) (*gorp.DbMap, error) {
db, err := sql.Open(cfg.MoneyGo.DBType.String(), cfg.MoneyGo.DSN)
if err != nil {
log.Fatal(err)
return nil, err
}
dbmap := &gorp.DbMap{Db: db, Dialect: gorp.SqliteDialect{}}
var dialect gorp.Dialect
if cfg.MoneyGo.DBType == SQLite {
dialect = gorp.SqliteDialect{}
} else if cfg.MoneyGo.DBType == MySQL {
dialect = gorp.MySQLDialect{
Engine: "InnoDB",
Encoding: "UTF8",
}
} else if cfg.MoneyGo.DBType == Postgres {
dialect = gorp.PostgresDialect{}
} else {
return nil, fmt.Errorf("Don't know gorp dialect to go with '%s' DB type", cfg.MoneyGo.DBType.String())
}
dbmap := &gorp.DbMap{Db: db, Dialect: dialect}
dbmap.AddTableWithName(User{}, "users").SetKeys(true, "UserId")
dbmap.AddTableWithName(Session{}, "sessions").SetKeys(true, "SessionId")
dbmap.AddTableWithName(Account{}, "accounts").SetKeys(true, "AccountId")
@ -29,8 +41,8 @@ func initDB(cfg *Config) {
err = dbmap.CreateTablesIfNotExists()
if err != nil {
log.Fatal(err)
return nil, err
}
DB = dbmap
return dbmap, nil
}

View File

@ -308,8 +308,8 @@ func ImportGnucash(r io.Reader) (*GnucashImport, error) {
return &gncimport, nil
}
func GnucashImportHandler(w http.ResponseWriter, r *http.Request) {
user, err := GetUserFromSession(r)
func GnucashImportHandler(w http.ResponseWriter, r *http.Request, db *DB) {
user, err := GetUserFromSession(db, r)
if err != nil {
WriteError(w, 1 /*Not Signed In*/)
return
@ -365,7 +365,7 @@ func GnucashImportHandler(w http.ResponseWriter, r *http.Request) {
return
}
sqltransaction, err := DB.Begin()
sqltransaction, err := db.Begin()
if err != nil {
WriteError(w, 999 /*Internal Error*/)
log.Print(err)

View File

@ -22,7 +22,7 @@ func (od *OFXDownload) Read(json_str string) error {
return dec.Decode(od)
}
func ofxImportHelper(r io.Reader, w http.ResponseWriter, user *User, accountid int64) {
func ofxImportHelper(db *DB, r io.Reader, w http.ResponseWriter, user *User, accountid int64) {
itl, err := ImportOFX(r)
if err != nil {
@ -38,7 +38,7 @@ func ofxImportHelper(r io.Reader, w http.ResponseWriter, user *User, accountid i
return
}
sqltransaction, err := DB.Begin()
sqltransaction, err := db.Begin()
if err != nil {
WriteError(w, 999 /*Internal Error*/)
log.Print(err)
@ -258,7 +258,7 @@ func ofxImportHelper(r io.Reader, w http.ResponseWriter, user *User, accountid i
WriteSuccess(w)
}
func OFXImportHandler(w http.ResponseWriter, r *http.Request, user *User, accountid int64) {
func OFXImportHandler(db *DB, w http.ResponseWriter, r *http.Request, user *User, accountid int64) {
download_json := r.PostFormValue("ofxdownload")
if download_json == "" {
log.Print("download_json")
@ -274,7 +274,7 @@ func OFXImportHandler(w http.ResponseWriter, r *http.Request, user *User, accoun
return
}
account, err := GetAccount(accountid, user.UserId)
account, err := GetAccount(db, accountid, user.UserId)
if err != nil {
log.Print("GetAccount")
WriteError(w, 3 /*Invalid Request*/)
@ -367,10 +367,10 @@ func OFXImportHandler(w http.ResponseWriter, r *http.Request, user *User, accoun
}
defer response.Body.Close()
ofxImportHelper(response.Body, w, user, accountid)
ofxImportHelper(db, response.Body, w, user, accountid)
}
func OFXFileImportHandler(w http.ResponseWriter, r *http.Request, user *User, accountid int64) {
func OFXFileImportHandler(db *DB, w http.ResponseWriter, r *http.Request, user *User, accountid int64) {
multipartReader, err := r.MultipartReader()
if err != nil {
WriteError(w, 3 /*Invalid Request*/)
@ -390,19 +390,19 @@ func OFXFileImportHandler(w http.ResponseWriter, r *http.Request, user *User, ac
return
}
ofxImportHelper(part, w, user, accountid)
ofxImportHelper(db, part, w, user, accountid)
}
/*
* Assumes the User is a valid, signed-in user, but accountid has not yet been validated
*/
func AccountImportHandler(w http.ResponseWriter, r *http.Request, user *User, accountid int64, importtype string) {
func AccountImportHandler(db *DB, w http.ResponseWriter, r *http.Request, user *User, accountid int64, importtype string) {
switch importtype {
case "ofx":
OFXImportHandler(w, r, user, accountid)
OFXImportHandler(db, w, r, user, accountid)
case "ofxfile":
OFXFileImportHandler(w, r, user, accountid)
OFXFileImportHandler(db, w, r, user, accountid)
default:
WriteError(w, 3 /*Invalid Request*/)
}

44
main.go
View File

@ -4,6 +4,7 @@ package main
import (
"flag"
"gopkg.in/gorp.v1"
"log"
"net"
"net/http"
@ -43,8 +44,6 @@ func init() {
// Setup the logging flags to be printed
log.SetFlags(log.Ldate | log.Ltime | log.Lshortfile)
initDB(config)
}
func rootHandler(w http.ResponseWriter, r *http.Request) {
@ -55,18 +54,39 @@ func staticHandler(w http.ResponseWriter, r *http.Request) {
http.ServeFile(w, r, path.Join(config.MoneyGo.Basedir, r.URL.Path))
}
func main() {
// Create a closure over db, allowing the handlers to look like a
// http.HandlerFunc
type DB = gorp.DbMap
type DBHandler func(http.ResponseWriter, *http.Request, *DB)
func DBHandlerFunc(h DBHandler, db *DB) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
h(w, r, db)
}
}
func GetHandler(db *DB) http.Handler {
servemux := http.NewServeMux()
servemux.HandleFunc("/", rootHandler)
servemux.HandleFunc("/static/", staticHandler)
servemux.HandleFunc("/session/", SessionHandler)
servemux.HandleFunc("/user/", UserHandler)
servemux.HandleFunc("/security/", SecurityHandler)
servemux.HandleFunc("/session/", DBHandlerFunc(SessionHandler, db))
servemux.HandleFunc("/user/", DBHandlerFunc(UserHandler, db))
servemux.HandleFunc("/security/", DBHandlerFunc(SecurityHandler, db))
servemux.HandleFunc("/securitytemplate/", SecurityTemplateHandler)
servemux.HandleFunc("/account/", AccountHandler)
servemux.HandleFunc("/transaction/", TransactionHandler)
servemux.HandleFunc("/import/gnucash", GnucashImportHandler)
servemux.HandleFunc("/report/", ReportHandler)
servemux.HandleFunc("/account/", DBHandlerFunc(AccountHandler, db))
servemux.HandleFunc("/transaction/", DBHandlerFunc(TransactionHandler, db))
servemux.HandleFunc("/import/gnucash", DBHandlerFunc(GnucashImportHandler, db))
servemux.HandleFunc("/report/", DBHandlerFunc(ReportHandler, db))
return servemux
}
func main() {
database, err := initDB(config)
if err != nil {
log.Fatal(err)
}
handler := GetHandler(database)
listener, err := net.Listen("tcp", ":"+strconv.Itoa(config.MoneyGo.Port))
if err != nil {
@ -75,8 +95,8 @@ func main() {
log.Printf("Serving on port %d out of directory: %s", config.MoneyGo.Port, config.MoneyGo.Basedir)
if config.MoneyGo.Fcgi {
fcgi.Serve(listener, servemux)
fcgi.Serve(listener, handler)
} else {
http.Serve(listener, servemux)
http.Serve(listener, handler)
}
}

View File

@ -1,8 +1,6 @@
package main
import (
"fmt"
"github.com/FlashBoys/go-finance"
"gopkg.in/gorp.v1"
"time"
)
@ -93,8 +91,8 @@ func GetClosestPriceTx(transaction *gorp.Transaction, security, currency *Securi
}
}
func GetClosestPrice(security, currency *Security, date *time.Time) (*Price, error) {
transaction, err := DB.Begin()
func GetClosestPrice(db *DB, security, currency *Security, date *time.Time) (*Price, error) {
transaction, err := db.Begin()
if err != nil {
return nil, err
}
@ -113,10 +111,3 @@ func GetClosestPrice(security, currency *Security, date *time.Time) (*Price, err
return price, nil
}
func init() {
q, err := finance.GetQuote("BRK-A")
if err == nil {
fmt.Printf("%+v", q)
}
}

View File

@ -27,6 +27,7 @@ const (
accountsContextKey
securitiesContextKey
balanceContextKey
dbContextKey
)
const luaTimeoutSeconds time.Duration = 30 // maximum time a lua request can run for
@ -76,36 +77,36 @@ func (r *Tabulation) Write(w http.ResponseWriter) error {
return enc.Encode(r)
}
func GetReport(reportid int64, userid int64) (*Report, error) {
func GetReport(db *DB, reportid int64, userid int64) (*Report, error) {
var r Report
err := DB.SelectOne(&r, "SELECT * from reports where UserId=? AND ReportId=?", userid, reportid)
err := db.SelectOne(&r, "SELECT * from reports where UserId=? AND ReportId=?", userid, reportid)
if err != nil {
return nil, err
}
return &r, nil
}
func GetReports(userid int64) (*[]Report, error) {
func GetReports(db *DB, userid int64) (*[]Report, error) {
var reports []Report
_, err := DB.Select(&reports, "SELECT * from reports where UserId=?", userid)
_, err := db.Select(&reports, "SELECT * from reports where UserId=?", userid)
if err != nil {
return nil, err
}
return &reports, nil
}
func InsertReport(r *Report) error {
err := DB.Insert(r)
func InsertReport(db *DB, r *Report) error {
err := db.Insert(r)
if err != nil {
return err
}
return nil
}
func UpdateReport(r *Report) error {
count, err := DB.Update(r)
func UpdateReport(db *DB, r *Report) error {
count, err := db.Update(r)
if err != nil {
return err
}
@ -115,8 +116,8 @@ func UpdateReport(r *Report) error {
return nil
}
func DeleteReport(r *Report) error {
count, err := DB.Delete(r)
func DeleteReport(db *DB, r *Report) error {
count, err := db.Delete(r)
if err != nil {
return err
}
@ -126,13 +127,14 @@ func DeleteReport(r *Report) error {
return nil
}
func runReport(user *User, report *Report) (*Tabulation, error) {
func runReport(db *DB, user *User, report *Report) (*Tabulation, error) {
// Create a new LState without opening the default libs for security
L := lua.NewState(lua.Options{SkipOpenLibs: true})
defer L.Close()
// Create a new context holding the current user with a timeout
ctx := context.WithValue(context.Background(), userContextKey, user)
ctx = context.WithValue(ctx, dbContextKey, db)
ctx, cancel := context.WithTimeout(ctx, luaTimeoutSeconds*time.Second)
defer cancel()
L.SetContext(ctx)
@ -189,14 +191,14 @@ func runReport(user *User, report *Report) (*Tabulation, error) {
}
}
func ReportTabulationHandler(w http.ResponseWriter, r *http.Request, user *User, reportid int64) {
report, err := GetReport(reportid, user.UserId)
func ReportTabulationHandler(db *DB, w http.ResponseWriter, r *http.Request, user *User, reportid int64) {
report, err := GetReport(db, reportid, user.UserId)
if err != nil {
WriteError(w, 3 /*Invalid Request*/)
return
}
tabulation, err := runReport(user, report)
tabulation, err := runReport(db, user, report)
if err != nil {
// TODO handle different failure cases differently
log.Print("runReport returned:", err)
@ -214,8 +216,8 @@ func ReportTabulationHandler(w http.ResponseWriter, r *http.Request, user *User,
}
}
func ReportHandler(w http.ResponseWriter, r *http.Request) {
user, err := GetUserFromSession(r)
func ReportHandler(w http.ResponseWriter, r *http.Request, db *DB) {
user, err := GetUserFromSession(db, r)
if err != nil {
WriteError(w, 1 /*Not Signed In*/)
return
@ -237,7 +239,7 @@ func ReportHandler(w http.ResponseWriter, r *http.Request) {
report.ReportId = -1
report.UserId = user.UserId
err = InsertReport(&report)
err = InsertReport(db, &report)
if err != nil {
WriteError(w, 999 /*Internal Error*/)
log.Print(err)
@ -260,7 +262,7 @@ func ReportHandler(w http.ResponseWriter, r *http.Request) {
log.Print(err)
return
}
ReportTabulationHandler(w, r, user, reportid)
ReportTabulationHandler(db, w, r, user, reportid)
return
}
@ -269,7 +271,7 @@ func ReportHandler(w http.ResponseWriter, r *http.Request) {
if err != nil || n != 1 {
//Return all Reports
var rl ReportList
reports, err := GetReports(user.UserId)
reports, err := GetReports(db, user.UserId)
if err != nil {
WriteError(w, 999 /*Internal Error*/)
log.Print(err)
@ -284,7 +286,7 @@ func ReportHandler(w http.ResponseWriter, r *http.Request) {
}
} else {
// Return Report with this Id
report, err := GetReport(reportid, user.UserId)
report, err := GetReport(db, reportid, user.UserId)
if err != nil {
WriteError(w, 3 /*Invalid Request*/)
return
@ -319,7 +321,7 @@ func ReportHandler(w http.ResponseWriter, r *http.Request) {
}
report.UserId = user.UserId
err = UpdateReport(&report)
err = UpdateReport(db, &report)
if err != nil {
WriteError(w, 999 /*Internal Error*/)
log.Print(err)
@ -333,13 +335,13 @@ func ReportHandler(w http.ResponseWriter, r *http.Request) {
return
}
} else if r.Method == "DELETE" {
report, err := GetReport(reportid, user.UserId)
report, err := GetReport(db, reportid, user.UserId)
if err != nil {
WriteError(w, 3 /*Invalid Request*/)
return
}
err = DeleteReport(report)
err = DeleteReport(db, report)
if err != nil {
WriteError(w, 999 /*Internal Error*/)
log.Print(err)

View File

@ -96,10 +96,10 @@ func FindCurrencyTemplate(iso4217 int64) *Security {
return nil
}
func GetSecurity(securityid int64, userid int64) (*Security, error) {
func GetSecurity(db *DB, securityid int64, userid int64) (*Security, error) {
var s Security
err := DB.SelectOne(&s, "SELECT * from securities where UserId=? AND SecurityId=?", userid, securityid)
err := db.SelectOne(&s, "SELECT * from securities where UserId=? AND SecurityId=?", userid, securityid)
if err != nil {
return nil, err
}
@ -116,18 +116,18 @@ func GetSecurityTx(transaction *gorp.Transaction, securityid int64, userid int64
return &s, nil
}
func GetSecurities(userid int64) (*[]*Security, error) {
func GetSecurities(db *DB, userid int64) (*[]*Security, error) {
var securities []*Security
_, err := DB.Select(&securities, "SELECT * from securities where UserId=?", userid)
_, err := db.Select(&securities, "SELECT * from securities where UserId=?", userid)
if err != nil {
return nil, err
}
return &securities, nil
}
func InsertSecurity(s *Security) error {
err := DB.Insert(s)
func InsertSecurity(db *DB, s *Security) error {
err := db.Insert(s)
if err != nil {
return err
}
@ -142,8 +142,8 @@ func InsertSecurityTx(transaction *gorp.Transaction, s *Security) error {
return nil
}
func UpdateSecurity(s *Security) error {
transaction, err := DB.Begin()
func UpdateSecurity(db *DB, s *Security) error {
transaction, err := db.Begin()
if err != nil {
return err
}
@ -176,8 +176,8 @@ func UpdateSecurity(s *Security) error {
return nil
}
func DeleteSecurity(s *Security) error {
transaction, err := DB.Begin()
func DeleteSecurity(db *DB, s *Security) error {
transaction, err := db.Begin()
if err != nil {
return err
}
@ -279,8 +279,8 @@ func ImportGetCreateSecurity(transaction *gorp.Transaction, userid int64, securi
return security, nil
}
func SecurityHandler(w http.ResponseWriter, r *http.Request) {
user, err := GetUserFromSession(r)
func SecurityHandler(w http.ResponseWriter, r *http.Request, db *DB) {
user, err := GetUserFromSession(db, r)
if err != nil {
WriteError(w, 1 /*Not Signed In*/)
return
@ -302,7 +302,7 @@ func SecurityHandler(w http.ResponseWriter, r *http.Request) {
security.SecurityId = -1
security.UserId = user.UserId
err = InsertSecurity(&security)
err = InsertSecurity(db, &security)
if err != nil {
WriteError(w, 999 /*Internal Error*/)
log.Print(err)
@ -324,7 +324,7 @@ func SecurityHandler(w http.ResponseWriter, r *http.Request) {
//Return all securities
var sl SecurityList
securities, err := GetSecurities(user.UserId)
securities, err := GetSecurities(db, user.UserId)
if err != nil {
WriteError(w, 999 /*Internal Error*/)
log.Print(err)
@ -339,7 +339,7 @@ func SecurityHandler(w http.ResponseWriter, r *http.Request) {
return
}
} else {
security, err := GetSecurity(securityid, user.UserId)
security, err := GetSecurity(db, securityid, user.UserId)
if err != nil {
WriteError(w, 3 /*Invalid Request*/)
return
@ -373,7 +373,7 @@ func SecurityHandler(w http.ResponseWriter, r *http.Request) {
}
security.UserId = user.UserId
err = UpdateSecurity(&security)
err = UpdateSecurity(db, &security)
if err != nil {
WriteError(w, 999 /*Internal Error*/)
log.Print(err)
@ -387,13 +387,13 @@ func SecurityHandler(w http.ResponseWriter, r *http.Request) {
return
}
} else if r.Method == "DELETE" {
security, err := GetSecurity(securityid, user.UserId)
security, err := GetSecurity(db, securityid, user.UserId)
if err != nil {
WriteError(w, 3 /*Invalid Request*/)
return
}
err = DeleteSecurity(security)
err = DeleteSecurity(db, security)
if err != nil {
WriteError(w, 999 /*Internal Error*/)
log.Print(err)

View File

@ -13,14 +13,19 @@ func luaContextGetSecurities(L *lua.LState) (map[int64]*Security, error) {
ctx := L.Context()
security_map, ok := ctx.Value(securitiesContextKey).(map[int64]*Security)
db, ok := ctx.Value(dbContextKey).(*DB)
if !ok {
return nil, errors.New("Couldn't find DB in lua's Context")
}
security_map, ok = ctx.Value(securitiesContextKey).(map[int64]*Security)
if !ok {
user, ok := ctx.Value(userContextKey).(*User)
if !ok {
return nil, errors.New("Couldn't find User in lua's Context")
}
securities, err := GetSecurities(user.UserId)
securities, err := GetSecurities(db, user.UserId)
if err != nil {
return nil, err
}
@ -149,7 +154,13 @@ func luaClosestPrice(L *lua.LState) int {
c := luaCheckSecurity(L, 2)
date := luaCheckTime(L, 3)
p, err := GetClosestPrice(s, c, date)
ctx := L.Context()
db, ok := ctx.Value(dbContextKey).(*DB)
if !ok {
panic("Couldn't find DB in lua's Context")
}
p, err := GetClosestPrice(db, s, c, date)
if err != nil {
L.Push(lua.LNil)
} else {

View File

@ -22,7 +22,7 @@ func (s *Session) Write(w http.ResponseWriter) error {
return enc.Encode(s)
}
func GetSession(r *http.Request) (*Session, error) {
func GetSession(db *DB, r *http.Request) (*Session, error) {
var s Session
cookie, err := r.Cookie("moneygo-session")
@ -31,17 +31,18 @@ func GetSession(r *http.Request) (*Session, error) {
}
s.SessionSecret = cookie.Value
err = DB.SelectOne(&s, "SELECT * from sessions where SessionSecret=?", s.SessionSecret)
err = db.SelectOne(&s, "SELECT * from sessions where SessionSecret=?", s.SessionSecret)
if err != nil {
return nil, err
}
return &s, nil
}
func DeleteSessionIfExists(r *http.Request) {
session, err := GetSession(r)
func DeleteSessionIfExists(db *DB, r *http.Request) {
// TODO do this in one transaction
session, err := GetSession(db, r)
if err == nil {
DB.Delete(session)
db.Delete(session)
}
}
@ -53,7 +54,7 @@ func NewSessionCookie() (string, error) {
return base64.StdEncoding.EncodeToString(bits), nil
}
func NewSession(w http.ResponseWriter, r *http.Request, userid int64) (*Session, error) {
func NewSession(db *DB, w http.ResponseWriter, r *http.Request, userid int64) (*Session, error) {
s := Session{}
session_secret, err := NewSessionCookie()
@ -75,14 +76,14 @@ func NewSession(w http.ResponseWriter, r *http.Request, userid int64) (*Session,
s.SessionSecret = session_secret
s.UserId = userid
err = DB.Insert(&s)
err = db.Insert(&s)
if err != nil {
return nil, err
}
return &s, nil
}
func SessionHandler(w http.ResponseWriter, r *http.Request) {
func SessionHandler(w http.ResponseWriter, r *http.Request, db *DB) {
if r.Method == "POST" || r.Method == "PUT" {
user_json := r.PostFormValue("user")
if user_json == "" {
@ -97,7 +98,7 @@ func SessionHandler(w http.ResponseWriter, r *http.Request) {
return
}
dbuser, err := GetUserByUsername(user.Username)
dbuser, err := GetUserByUsername(db, user.Username)
if err != nil {
WriteError(w, 2 /*Unauthorized Access*/)
return
@ -109,9 +110,9 @@ func SessionHandler(w http.ResponseWriter, r *http.Request) {
return
}
DeleteSessionIfExists(r)
DeleteSessionIfExists(db, r)
session, err := NewSession(w, r, dbuser.UserId)
session, err := NewSession(db, w, r, dbuser.UserId)
if err != nil {
WriteError(w, 999 /*Internal Error*/)
return
@ -124,7 +125,7 @@ func SessionHandler(w http.ResponseWriter, r *http.Request) {
return
}
} else if r.Method == "GET" {
s, err := GetSession(r)
s, err := GetSession(db, r)
if err != nil {
WriteError(w, 1 /*Not Signed In*/)
return
@ -132,7 +133,7 @@ func SessionHandler(w http.ResponseWriter, r *http.Request) {
s.Write(w)
} else if r.Method == "DELETE" {
DeleteSessionIfExists(r)
DeleteSessionIfExists(db, r)
WriteSuccess(w)
}
}

View File

@ -146,11 +146,7 @@ func (t *Transaction) GetImbalancesTx(transaction *gorp.Transaction) (map[int64]
if t.Splits[i].AccountId != -1 {
var err error
var account *Account
if transaction != nil {
account, err = GetAccountTx(transaction, t.Splits[i].AccountId, t.UserId)
} else {
account, err = GetAccount(t.Splits[i].AccountId, t.UserId)
}
account, err = GetAccountTx(transaction, t.Splits[i].AccountId, t.UserId)
if err != nil {
return nil, err
}
@ -164,16 +160,12 @@ func (t *Transaction) GetImbalancesTx(transaction *gorp.Transaction) (map[int64]
return sums, nil
}
func (t *Transaction) GetImbalances() (map[int64]big.Rat, error) {
return t.GetImbalancesTx(nil)
}
// Returns true if all securities contained in this transaction are balanced,
// false otherwise
func (t *Transaction) Balanced() (bool, error) {
func (t *Transaction) Balanced(transaction *gorp.Transaction) (bool, error) {
var zero big.Rat
sums, err := t.GetImbalances()
sums, err := t.GetImbalancesTx(transaction)
if err != nil {
return false, err
}
@ -186,21 +178,23 @@ func (t *Transaction) Balanced() (bool, error) {
return true, nil
}
func GetTransaction(transactionid int64, userid int64) (*Transaction, error) {
func GetTransaction(db *DB, transactionid int64, userid int64) (*Transaction, error) {
var t Transaction
transaction, err := DB.Begin()
transaction, err := db.Begin()
if err != nil {
return nil, err
}
err = transaction.SelectOne(&t, "SELECT * from transactions where UserId=? AND TransactionId=?", userid, transactionid)
if err != nil {
transaction.Rollback()
return nil, err
}
_, err = transaction.Select(&t.Splits, "SELECT * from splits where TransactionId=?", transactionid)
if err != nil {
transaction.Rollback()
return nil, err
}
@ -213,10 +207,10 @@ func GetTransaction(transactionid int64, userid int64) (*Transaction, error) {
return &t, nil
}
func GetTransactions(userid int64) (*[]Transaction, error) {
func GetTransactions(db *DB, userid int64) (*[]Transaction, error) {
var transactions []Transaction
transaction, err := DB.Begin()
transaction, err := db.Begin()
if err != nil {
return nil, err
}
@ -316,8 +310,8 @@ func InsertTransactionTx(transaction *gorp.Transaction, t *Transaction, user *Us
return nil
}
func InsertTransaction(t *Transaction, user *User) error {
transaction, err := DB.Begin()
func InsertTransaction(db *DB, t *Transaction, user *User) error {
transaction, err := db.Begin()
if err != nil {
return err
}
@ -337,17 +331,11 @@ func InsertTransaction(t *Transaction, user *User) error {
return nil
}
func UpdateTransaction(t *Transaction, user *User) error {
transaction, err := DB.Begin()
if err != nil {
return err
}
func UpdateTransactionTx(transaction *gorp.Transaction, t *Transaction, user *User) error {
var existing_splits []*Split
_, err = transaction.Select(&existing_splits, "SELECT * from splits where TransactionId=?", t.TransactionId)
_, err := transaction.Select(&existing_splits, "SELECT * from splits where TransactionId=?", t.TransactionId)
if err != nil {
transaction.Rollback()
return err
}
@ -367,11 +355,9 @@ func UpdateTransaction(t *Transaction, user *User) error {
if ok {
count, err := transaction.Update(t.Splits[i])
if err != nil {
transaction.Rollback()
return err
}
if count != 1 {
transaction.Rollback()
return errors.New("Updated more than one transaction split")
}
delete(s_map, t.Splits[i].SplitId)
@ -379,7 +365,6 @@ func UpdateTransaction(t *Transaction, user *User) error {
t.Splits[i].SplitId = -1
err := transaction.Insert(t.Splits[i])
if err != nil {
transaction.Rollback()
return err
}
}
@ -397,7 +382,6 @@ func UpdateTransaction(t *Transaction, user *User) error {
if ok {
_, err := transaction.Delete(existing_splits[i])
if err != nil {
transaction.Rollback()
return err
}
}
@ -410,31 +394,22 @@ func UpdateTransaction(t *Transaction, user *User) error {
}
err = incrementAccountVersions(transaction, user, a_ids)
if err != nil {
transaction.Rollback()
return err
}
count, err := transaction.Update(t)
if err != nil {
transaction.Rollback()
return err
}
if count != 1 {
transaction.Rollback()
return errors.New("Updated more than one transaction")
}
err = transaction.Commit()
if err != nil {
transaction.Rollback()
return err
}
return nil
}
func DeleteTransaction(t *Transaction, user *User) error {
transaction, err := DB.Begin()
func DeleteTransaction(db *DB, t *Transaction, user *User) error {
transaction, err := db.Begin()
if err != nil {
return err
}
@ -477,8 +452,8 @@ func DeleteTransaction(t *Transaction, user *User) error {
return nil
}
func TransactionHandler(w http.ResponseWriter, r *http.Request) {
user, err := GetUserFromSession(r)
func TransactionHandler(w http.ResponseWriter, r *http.Request, db *DB) {
user, err := GetUserFromSession(db, r)
if err != nil {
WriteError(w, 1 /*Not Signed In*/)
return
@ -500,27 +475,37 @@ func TransactionHandler(w http.ResponseWriter, r *http.Request) {
transaction.TransactionId = -1
transaction.UserId = user.UserId
balanced, err := transaction.Balanced()
sqltx, err := db.Begin()
if err != nil {
WriteError(w, 999 /*Internal Error*/)
log.Print(err)
return
}
balanced, err := transaction.Balanced(sqltx)
if err != nil {
sqltx.Rollback()
WriteError(w, 999 /*Internal Error*/)
log.Print(err)
return
}
if !transaction.Valid() || !balanced {
sqltx.Rollback()
WriteError(w, 3 /*Invalid Request*/)
return
}
for i := range transaction.Splits {
transaction.Splits[i].SplitId = -1
_, err := GetAccount(transaction.Splits[i].AccountId, user.UserId)
_, err := GetAccountTx(sqltx, transaction.Splits[i].AccountId, user.UserId)
if err != nil {
sqltx.Rollback()
WriteError(w, 3 /*Invalid Request*/)
return
}
}
err = InsertTransaction(&transaction, user)
err = InsertTransactionTx(sqltx, &transaction, user)
if err != nil {
if _, ok := err.(AccountMissingError); ok {
WriteError(w, 3 /*Invalid Request*/)
@ -528,6 +513,15 @@ func TransactionHandler(w http.ResponseWriter, r *http.Request) {
WriteError(w, 999 /*Internal Error*/)
log.Print(err)
}
sqltx.Rollback()
return
}
err = sqltx.Commit()
if err != nil {
sqltx.Rollback()
WriteError(w, 999 /*Internal Error*/)
log.Print(err)
return
}
@ -543,7 +537,7 @@ func TransactionHandler(w http.ResponseWriter, r *http.Request) {
if err != nil {
//Return all Transactions
var al TransactionList
transactions, err := GetTransactions(user.UserId)
transactions, err := GetTransactions(db, user.UserId)
if err != nil {
WriteError(w, 999 /*Internal Error*/)
log.Print(err)
@ -558,7 +552,7 @@ func TransactionHandler(w http.ResponseWriter, r *http.Request) {
}
} else {
//Return Transaction with this Id
transaction, err := GetTransaction(transactionid, user.UserId)
transaction, err := GetTransaction(db, transactionid, user.UserId)
if err != nil {
WriteError(w, 3 /*Invalid Request*/)
return
@ -591,27 +585,46 @@ func TransactionHandler(w http.ResponseWriter, r *http.Request) {
}
transaction.UserId = user.UserId
balanced, err := transaction.Balanced()
sqltx, err := db.Begin()
if err != nil {
WriteError(w, 999 /*Internal Error*/)
log.Print(err)
return
}
balanced, err := transaction.Balanced(sqltx)
if err != nil {
sqltx.Rollback()
WriteError(w, 999 /*Internal Error*/)
log.Print(err)
return
}
if !transaction.Valid() || !balanced {
sqltx.Rollback()
WriteError(w, 3 /*Invalid Request*/)
return
}
for i := range transaction.Splits {
_, err := GetAccount(transaction.Splits[i].AccountId, user.UserId)
_, err := GetAccountTx(sqltx, transaction.Splits[i].AccountId, user.UserId)
if err != nil {
sqltx.Rollback()
WriteError(w, 3 /*Invalid Request*/)
return
}
}
err = UpdateTransaction(&transaction, user)
err = UpdateTransactionTx(sqltx, &transaction, user)
if err != nil {
sqltx.Rollback()
WriteError(w, 999 /*Internal Error*/)
log.Print(err)
return
}
err = sqltx.Commit()
if err != nil {
sqltx.Rollback()
WriteError(w, 999 /*Internal Error*/)
log.Print(err)
return
@ -630,13 +643,13 @@ func TransactionHandler(w http.ResponseWriter, r *http.Request) {
return
}
transaction, err := GetTransaction(transactionid, user.UserId)
transaction, err := GetTransaction(db, transactionid, user.UserId)
if err != nil {
WriteError(w, 3 /*Invalid Request*/)
return
}
err = DeleteTransaction(transaction, user)
err = DeleteTransaction(db, transaction, user)
if err != nil {
WriteError(w, 999 /*Internal Error*/)
log.Print(err)
@ -672,9 +685,9 @@ func TransactionsBalanceDifference(transaction *gorp.Transaction, accountid int6
return &pageDifference, nil
}
func GetAccountBalance(user *User, accountid int64) (*big.Rat, error) {
func GetAccountBalance(db *DB, user *User, accountid int64) (*big.Rat, error) {
var splits []Split
transaction, err := DB.Begin()
transaction, err := db.Begin()
if err != nil {
return nil, err
}
@ -707,9 +720,9 @@ func GetAccountBalance(user *User, accountid int64) (*big.Rat, error) {
}
// Assumes accountid is valid and is owned by the current user
func GetAccountBalanceDate(user *User, accountid int64, date *time.Time) (*big.Rat, error) {
func GetAccountBalanceDate(db *DB, user *User, accountid int64, date *time.Time) (*big.Rat, error) {
var splits []Split
transaction, err := DB.Begin()
transaction, err := db.Begin()
if err != nil {
return nil, err
}
@ -741,9 +754,9 @@ func GetAccountBalanceDate(user *User, accountid int64, date *time.Time) (*big.R
return &balance, nil
}
func GetAccountBalanceDateRange(user *User, accountid int64, begin, end *time.Time) (*big.Rat, error) {
func GetAccountBalanceDateRange(db *DB, user *User, accountid int64, begin, end *time.Time) (*big.Rat, error) {
var splits []Split
transaction, err := DB.Begin()
transaction, err := db.Begin()
if err != nil {
return nil, err
}
@ -775,11 +788,11 @@ func GetAccountBalanceDateRange(user *User, accountid int64, begin, end *time.Ti
return &balance, nil
}
func GetAccountTransactions(user *User, accountid int64, sort string, page uint64, limit uint64) (*AccountTransactionsList, error) {
func GetAccountTransactions(db *DB, user *User, accountid int64, sort string, page uint64, limit uint64) (*AccountTransactionsList, error) {
var transactions []Transaction
var atl AccountTransactionsList
transaction, err := DB.Begin()
transaction, err := db.Begin()
if err != nil {
return nil, err
}
@ -878,7 +891,7 @@ func GetAccountTransactions(user *User, accountid int64, sort string, page uint6
// Return only those transactions which have at least one split pertaining to
// an account
func AccountTransactionsHandler(w http.ResponseWriter, r *http.Request,
func AccountTransactionsHandler(db *DB, w http.ResponseWriter, r *http.Request,
user *User, accountid int64) {
var page uint64 = 0
@ -916,7 +929,7 @@ func AccountTransactionsHandler(w http.ResponseWriter, r *http.Request,
sort = sortstring
}
accountTransactions, err := GetAccountTransactions(user, accountid, sort, page, limit)
accountTransactions, err := GetAccountTransactions(db, user, accountid, sort, page, limit)
if err != nil {
WriteError(w, 999 /*Internal Error*/)
log.Print(err)

View File

@ -47,10 +47,10 @@ func (u *User) HashPassword() {
u.Password = ""
}
func GetUser(userid int64) (*User, error) {
func GetUser(db *DB, userid int64) (*User, error) {
var u User
err := DB.SelectOne(&u, "SELECT * from users where UserId=?", userid)
err := db.SelectOne(&u, "SELECT * from users where UserId=?", userid)
if err != nil {
return nil, err
}
@ -67,18 +67,18 @@ func GetUserTx(transaction *gorp.Transaction, userid int64) (*User, error) {
return &u, nil
}
func GetUserByUsername(username string) (*User, error) {
func GetUserByUsername(db *DB, username string) (*User, error) {
var u User
err := DB.SelectOne(&u, "SELECT * from users where Username=?", username)
err := db.SelectOne(&u, "SELECT * from users where Username=?", username)
if err != nil {
return nil, err
}
return &u, nil
}
func InsertUser(u *User) error {
transaction, err := DB.Begin()
func InsertUser(db *DB, u *User) error {
transaction, err := db.Begin()
if err != nil {
return err
}
@ -136,16 +136,16 @@ func InsertUser(u *User) error {
return nil
}
func GetUserFromSession(r *http.Request) (*User, error) {
s, err := GetSession(r)
func GetUserFromSession(db *DB, r *http.Request) (*User, error) {
s, err := GetSession(db, r)
if err != nil {
return nil, err
}
return GetUser(s.UserId)
return GetUser(db, s.UserId)
}
func UpdateUser(u *User) error {
transaction, err := DB.Begin()
func UpdateUser(db *DB, u *User) error {
transaction, err := db.Begin()
if err != nil {
return err
}
@ -180,7 +180,7 @@ func UpdateUser(u *User) error {
return nil
}
func UserHandler(w http.ResponseWriter, r *http.Request) {
func UserHandler(w http.ResponseWriter, r *http.Request, db *DB) {
if r.Method == "POST" {
user_json := r.PostFormValue("user")
if user_json == "" {
@ -197,7 +197,7 @@ func UserHandler(w http.ResponseWriter, r *http.Request) {
user.UserId = -1
user.HashPassword()
err = InsertUser(&user)
err = InsertUser(db, &user)
if err != nil {
if _, ok := err.(UserExistsError); ok {
WriteError(w, 4 /*User Exists*/)
@ -216,7 +216,7 @@ func UserHandler(w http.ResponseWriter, r *http.Request) {
return
}
} else {
user, err := GetUserFromSession(r)
user, err := GetUserFromSession(db, r)
if err != nil {
WriteError(w, 1 /*Not Signed In*/)
return
@ -264,7 +264,7 @@ func UserHandler(w http.ResponseWriter, r *http.Request) {
user.PasswordHash = old_pwhash
}
err = UpdateUser(user)
err = UpdateUser(db, user)
if err != nil {
WriteError(w, 999 /*Internal Error*/)
log.Print(err)
@ -278,7 +278,7 @@ func UserHandler(w http.ResponseWriter, r *http.Request) {
return
}
} else if r.Method == "DELETE" {
count, err := DB.Delete(&user)
count, err := db.Delete(&user)
if count != 1 || err != nil {
WriteError(w, 999 /*Internal Error*/)
log.Print(err)