mirror of
https://github.com/aclindsa/moneygo.git
synced 2025-06-13 13:39:23 -04:00
First pass at reorganizing go code into sub-packages
This commit is contained in:
78
internal/config/config.go
Normal file
78
internal/config/config.go
Normal file
@ -0,0 +1,78 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"gopkg.in/gcfg.v1"
|
||||
"strings"
|
||||
)
|
||||
|
||||
type DbType uint
|
||||
|
||||
const (
|
||||
SQLite DbType = 1 + iota
|
||||
MySQL
|
||||
Postgres
|
||||
)
|
||||
|
||||
var dbTypes = [...]string{"sqlite3", "mysql", "postgres"}
|
||||
|
||||
func (e DbType) Valid() bool {
|
||||
// This check is mostly out of paranoia, ensuring e != 0 should be
|
||||
// sufficient
|
||||
return e >= SQLite && e <= Postgres
|
||||
}
|
||||
|
||||
func (e DbType) String() string {
|
||||
if e.Valid() {
|
||||
return dbTypes[e-1]
|
||||
}
|
||||
return fmt.Sprintf("invalid DbType (%d)", e)
|
||||
}
|
||||
|
||||
func (e *DbType) FromString(in string) error {
|
||||
value := strings.TrimSpace(in)
|
||||
|
||||
for i, s := range dbTypes {
|
||||
if s == value {
|
||||
*e = DbType(i + 1)
|
||||
return nil
|
||||
}
|
||||
}
|
||||
*e = 0
|
||||
return errors.New("Invalid DbType: \"" + in + "\"")
|
||||
}
|
||||
|
||||
func (e *DbType) UnmarshalText(text []byte) error {
|
||||
return e.FromString(string(text))
|
||||
}
|
||||
|
||||
type MoneyGo struct {
|
||||
Fcgi bool // whether to serve FCGI (HTTP by default if false)
|
||||
Port int // port to serve API/files on
|
||||
Basedir string `gcfg:"base-directory"` // base directory for serving files out of
|
||||
DBType DbType `gcfg:"db-type"` // Whether this is a sqlite/mysql/postgresql database
|
||||
DSN string `gcfg:"db-dsn"` // 'Data Source Name' for database connection
|
||||
}
|
||||
|
||||
type Config struct {
|
||||
MoneyGo MoneyGo
|
||||
}
|
||||
|
||||
func ReadConfig(filename string) (*Config, error) {
|
||||
cfg := Config{
|
||||
MoneyGo: MoneyGo{
|
||||
Fcgi: false,
|
||||
Port: 80,
|
||||
Basedir: "src/github.com/aclindsa/moneygo/",
|
||||
DBType: SQLite,
|
||||
DSN: "file:moneygo.sqlite?cache=shared&mode=rwc",
|
||||
},
|
||||
}
|
||||
|
||||
err := gcfg.ReadFileInto(&cfg, filename)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("Failed to parse config file: %s", err)
|
||||
}
|
||||
return &cfg, nil
|
||||
}
|
45
internal/db/db.go
Normal file
45
internal/db/db.go
Normal file
@ -0,0 +1,45 @@
|
||||
package db
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"github.com/aclindsa/moneygo/internal/config"
|
||||
"github.com/aclindsa/moneygo/internal/handlers"
|
||||
_ "github.com/go-sql-driver/mysql"
|
||||
_ "github.com/lib/pq"
|
||||
_ "github.com/mattn/go-sqlite3"
|
||||
"gopkg.in/gorp.v1"
|
||||
)
|
||||
|
||||
func GetDbMap(db *sql.DB, cfg *config.Config) (*gorp.DbMap, error) {
|
||||
var dialect gorp.Dialect
|
||||
if cfg.MoneyGo.DBType == config.SQLite {
|
||||
dialect = gorp.SqliteDialect{}
|
||||
} else if cfg.MoneyGo.DBType == config.MySQL {
|
||||
dialect = gorp.MySQLDialect{
|
||||
Engine: "InnoDB",
|
||||
Encoding: "UTF8",
|
||||
}
|
||||
} else if cfg.MoneyGo.DBType == config.Postgres {
|
||||
dialect = gorp.PostgresDialect{}
|
||||
} else {
|
||||
return nil, fmt.Errorf("Don't know gorp dialect to go with '%s' DB type", cfg.MoneyGo.DBType.String())
|
||||
}
|
||||
|
||||
dbmap := &gorp.DbMap{Db: db, Dialect: dialect}
|
||||
dbmap.AddTableWithName(handlers.User{}, "users").SetKeys(true, "UserId")
|
||||
dbmap.AddTableWithName(handlers.Session{}, "sessions").SetKeys(true, "SessionId")
|
||||
dbmap.AddTableWithName(handlers.Account{}, "accounts").SetKeys(true, "AccountId")
|
||||
dbmap.AddTableWithName(handlers.Security{}, "securities").SetKeys(true, "SecurityId")
|
||||
dbmap.AddTableWithName(handlers.Transaction{}, "transactions").SetKeys(true, "TransactionId")
|
||||
dbmap.AddTableWithName(handlers.Split{}, "splits").SetKeys(true, "SplitId")
|
||||
dbmap.AddTableWithName(handlers.Price{}, "prices").SetKeys(true, "PriceId")
|
||||
dbmap.AddTableWithName(handlers.Report{}, "reports").SetKeys(true, "ReportId")
|
||||
|
||||
err := dbmap.CreateTablesIfNotExists()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return dbmap, nil
|
||||
}
|
9
internal/handlers/Makefile
Normal file
9
internal/handlers/Makefile
Normal file
@ -0,0 +1,9 @@
|
||||
all: security_templates.go
|
||||
|
||||
security_templates.go: cusip_list.csv scripts/gen_security_list.py
|
||||
./scripts/gen_security_list.py > security_templates.go
|
||||
|
||||
cusip_list.csv:
|
||||
./scripts/gen_cusip_csv.sh > cusip_list.csv
|
||||
|
||||
.PHONY = all
|
561
internal/handlers/accounts.go
Normal file
561
internal/handlers/accounts.go
Normal file
@ -0,0 +1,561 @@
|
||||
package handlers
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"gopkg.in/gorp.v1"
|
||||
"log"
|
||||
"net/http"
|
||||
"regexp"
|
||||
"strings"
|
||||
)
|
||||
|
||||
type AccountType int64
|
||||
|
||||
const (
|
||||
Bank AccountType = 1 // start at 1 so that the default (0) is invalid
|
||||
Cash = 2
|
||||
Asset = 3
|
||||
Liability = 4
|
||||
Investment = 5
|
||||
Income = 6
|
||||
Expense = 7
|
||||
Trading = 8
|
||||
Equity = 9
|
||||
Receivable = 10
|
||||
Payable = 11
|
||||
)
|
||||
|
||||
var AccountTypes = []AccountType{
|
||||
Bank,
|
||||
Cash,
|
||||
Asset,
|
||||
Liability,
|
||||
Investment,
|
||||
Income,
|
||||
Expense,
|
||||
Trading,
|
||||
Equity,
|
||||
Receivable,
|
||||
Payable,
|
||||
}
|
||||
|
||||
func (t AccountType) String() string {
|
||||
switch t {
|
||||
case Bank:
|
||||
return "Bank"
|
||||
case Cash:
|
||||
return "Cash"
|
||||
case Asset:
|
||||
return "Asset"
|
||||
case Liability:
|
||||
return "Liability"
|
||||
case Investment:
|
||||
return "Investment"
|
||||
case Income:
|
||||
return "Income"
|
||||
case Expense:
|
||||
return "Expense"
|
||||
case Trading:
|
||||
return "Trading"
|
||||
case Equity:
|
||||
return "Equity"
|
||||
case Receivable:
|
||||
return "Receivable"
|
||||
case Payable:
|
||||
return "Payable"
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
type Account struct {
|
||||
AccountId int64
|
||||
ExternalAccountId string
|
||||
UserId int64
|
||||
SecurityId int64
|
||||
ParentAccountId int64 // -1 if this account is at the root
|
||||
Type AccountType
|
||||
Name string
|
||||
|
||||
// monotonically-increasing account transaction version number. Used for
|
||||
// allowing a client to ensure they have a consistent version when paging
|
||||
// through transactions.
|
||||
AccountVersion int64 `json:"Version"`
|
||||
|
||||
// Optional fields specifying how to fetch transactions from a bank via OFX
|
||||
OFXURL string
|
||||
OFXORG string
|
||||
OFXFID string
|
||||
OFXUser string
|
||||
OFXBankID string // OFX BankID (BrokerID if AcctType == Investment)
|
||||
OFXAcctID string
|
||||
OFXAcctType string // ofxgo.acctType
|
||||
OFXClientUID string
|
||||
OFXAppID string
|
||||
OFXAppVer string
|
||||
OFXVersion string
|
||||
OFXNoIndent bool
|
||||
}
|
||||
|
||||
type AccountList struct {
|
||||
Accounts *[]Account `json:"accounts"`
|
||||
}
|
||||
|
||||
var accountTransactionsRE *regexp.Regexp
|
||||
var accountImportRE *regexp.Regexp
|
||||
|
||||
func init() {
|
||||
accountTransactionsRE = regexp.MustCompile(`^/account/[0-9]+/transactions/?$`)
|
||||
accountImportRE = regexp.MustCompile(`^/account/[0-9]+/import/[a-z]+/?$`)
|
||||
}
|
||||
|
||||
func (a *Account) Write(w http.ResponseWriter) error {
|
||||
enc := json.NewEncoder(w)
|
||||
return enc.Encode(a)
|
||||
}
|
||||
|
||||
func (a *Account) Read(json_str string) error {
|
||||
dec := json.NewDecoder(strings.NewReader(json_str))
|
||||
return dec.Decode(a)
|
||||
}
|
||||
|
||||
func (al *AccountList) Write(w http.ResponseWriter) error {
|
||||
enc := json.NewEncoder(w)
|
||||
return enc.Encode(al)
|
||||
}
|
||||
|
||||
func GetAccount(db *DB, accountid int64, userid int64) (*Account, error) {
|
||||
var a Account
|
||||
|
||||
err := db.SelectOne(&a, "SELECT * from accounts where UserId=? AND AccountId=?", userid, accountid)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &a, nil
|
||||
}
|
||||
|
||||
func GetAccountTx(transaction *gorp.Transaction, accountid int64, userid int64) (*Account, error) {
|
||||
var a Account
|
||||
|
||||
err := transaction.SelectOne(&a, "SELECT * from accounts where UserId=? AND AccountId=?", userid, accountid)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &a, nil
|
||||
}
|
||||
|
||||
func GetAccounts(db *DB, userid int64) (*[]Account, error) {
|
||||
var accounts []Account
|
||||
|
||||
_, err := db.Select(&accounts, "SELECT * from accounts where UserId=?", userid)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &accounts, nil
|
||||
}
|
||||
|
||||
// Get (and attempt to create if it doesn't exist). Matches on UserId,
|
||||
// SecurityId, Type, Name, and ParentAccountId
|
||||
func GetCreateAccountTx(transaction *gorp.Transaction, a Account) (*Account, error) {
|
||||
var accounts []Account
|
||||
var account Account
|
||||
|
||||
// Try to find the top-level trading account
|
||||
_, err := transaction.Select(&accounts, "SELECT * from accounts where UserId=? AND SecurityId=? AND Type=? AND Name=? AND ParentAccountId=? ORDER BY AccountId ASC LIMIT 1", a.UserId, a.SecurityId, a.Type, a.Name, a.ParentAccountId)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if len(accounts) == 1 {
|
||||
account = accounts[0]
|
||||
} else {
|
||||
account.UserId = a.UserId
|
||||
account.SecurityId = a.SecurityId
|
||||
account.Type = a.Type
|
||||
account.Name = a.Name
|
||||
account.ParentAccountId = a.ParentAccountId
|
||||
|
||||
err = transaction.Insert(&account)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
return &account, nil
|
||||
}
|
||||
|
||||
// Get (and attempt to create if it doesn't exist) the security/currency
|
||||
// trading account for the supplied security/currency
|
||||
func GetTradingAccount(transaction *gorp.Transaction, userid int64, securityid int64) (*Account, error) {
|
||||
var tradingAccount Account
|
||||
var account Account
|
||||
|
||||
user, err := GetUserTx(transaction, userid)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
tradingAccount.UserId = userid
|
||||
tradingAccount.Type = Trading
|
||||
tradingAccount.Name = "Trading"
|
||||
tradingAccount.SecurityId = user.DefaultCurrency
|
||||
tradingAccount.ParentAccountId = -1
|
||||
|
||||
// Find/create the top-level trading account
|
||||
ta, err := GetCreateAccountTx(transaction, tradingAccount)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
security, err := GetSecurityTx(transaction, securityid, userid)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
account.UserId = userid
|
||||
account.Name = security.Name
|
||||
account.ParentAccountId = ta.AccountId
|
||||
account.SecurityId = securityid
|
||||
account.Type = Trading
|
||||
|
||||
a, err := GetCreateAccountTx(transaction, account)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return a, nil
|
||||
}
|
||||
|
||||
// Get (and attempt to create if it doesn't exist) the security/currency
|
||||
// imbalance account for the supplied security/currency
|
||||
func GetImbalanceAccount(transaction *gorp.Transaction, userid int64, securityid int64) (*Account, error) {
|
||||
var imbalanceAccount Account
|
||||
var account Account
|
||||
xxxtemplate := FindSecurityTemplate("XXX", Currency)
|
||||
if xxxtemplate == nil {
|
||||
return nil, errors.New("Couldn't find XXX security template")
|
||||
}
|
||||
xxxsecurity, err := ImportGetCreateSecurity(transaction, userid, xxxtemplate)
|
||||
if err != nil {
|
||||
return nil, errors.New("Couldn't create XXX security")
|
||||
}
|
||||
|
||||
imbalanceAccount.UserId = userid
|
||||
imbalanceAccount.Name = "Imbalances"
|
||||
imbalanceAccount.ParentAccountId = -1
|
||||
imbalanceAccount.SecurityId = xxxsecurity.SecurityId
|
||||
imbalanceAccount.Type = Bank
|
||||
|
||||
// Find/create the top-level trading account
|
||||
ia, err := GetCreateAccountTx(transaction, imbalanceAccount)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
security, err := GetSecurityTx(transaction, securityid, userid)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
account.UserId = userid
|
||||
account.Name = security.Name
|
||||
account.ParentAccountId = ia.AccountId
|
||||
account.SecurityId = securityid
|
||||
account.Type = Bank
|
||||
|
||||
a, err := GetCreateAccountTx(transaction, account)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return a, nil
|
||||
}
|
||||
|
||||
type ParentAccountMissingError struct{}
|
||||
|
||||
func (pame ParentAccountMissingError) Error() string {
|
||||
return "Parent account missing"
|
||||
}
|
||||
|
||||
func insertUpdateAccount(db *DB, a *Account, insert bool) error {
|
||||
transaction, err := db.Begin()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if a.ParentAccountId != -1 {
|
||||
existing, err := transaction.SelectInt("SELECT count(*) from accounts where AccountId=?", a.ParentAccountId)
|
||||
if err != nil {
|
||||
transaction.Rollback()
|
||||
return err
|
||||
}
|
||||
if existing != 1 {
|
||||
transaction.Rollback()
|
||||
return ParentAccountMissingError{}
|
||||
}
|
||||
}
|
||||
|
||||
if insert {
|
||||
err = transaction.Insert(a)
|
||||
if err != nil {
|
||||
transaction.Rollback()
|
||||
return err
|
||||
}
|
||||
} else {
|
||||
oldacct, err := GetAccountTx(transaction, a.AccountId, a.UserId)
|
||||
if err != nil {
|
||||
transaction.Rollback()
|
||||
return err
|
||||
}
|
||||
|
||||
a.AccountVersion = oldacct.AccountVersion + 1
|
||||
|
||||
count, err := transaction.Update(a)
|
||||
if err != nil {
|
||||
transaction.Rollback()
|
||||
return err
|
||||
}
|
||||
if count != 1 {
|
||||
transaction.Rollback()
|
||||
return errors.New("Updated more than one account")
|
||||
}
|
||||
}
|
||||
|
||||
err = transaction.Commit()
|
||||
if err != nil {
|
||||
transaction.Rollback()
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func InsertAccount(db *DB, a *Account) error {
|
||||
return insertUpdateAccount(db, a, true)
|
||||
}
|
||||
|
||||
func UpdateAccount(db *DB, a *Account) error {
|
||||
return insertUpdateAccount(db, a, false)
|
||||
}
|
||||
|
||||
func DeleteAccount(db *DB, a *Account) error {
|
||||
transaction, err := db.Begin()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if a.ParentAccountId != -1 {
|
||||
// Re-parent splits to this account's parent account if this account isn't a root account
|
||||
_, err = transaction.Exec("UPDATE splits SET AccountId=? WHERE AccountId=?", a.ParentAccountId, a.AccountId)
|
||||
if err != nil {
|
||||
transaction.Rollback()
|
||||
return err
|
||||
}
|
||||
} else {
|
||||
// Delete splits if this account is a root account
|
||||
_, err = transaction.Exec("DELETE FROM splits WHERE AccountId=?", a.AccountId)
|
||||
if err != nil {
|
||||
transaction.Rollback()
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
// Re-parent child accounts to this account's parent account
|
||||
_, err = transaction.Exec("UPDATE accounts SET ParentAccountId=? WHERE ParentAccountId=?", a.ParentAccountId, a.AccountId)
|
||||
if err != nil {
|
||||
transaction.Rollback()
|
||||
return err
|
||||
}
|
||||
|
||||
count, err := transaction.Delete(a)
|
||||
if err != nil {
|
||||
transaction.Rollback()
|
||||
return err
|
||||
}
|
||||
if count != 1 {
|
||||
transaction.Rollback()
|
||||
return errors.New("Was going to delete more than one account")
|
||||
}
|
||||
|
||||
err = transaction.Commit()
|
||||
if err != nil {
|
||||
transaction.Rollback()
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func AccountHandler(w http.ResponseWriter, r *http.Request, db *DB) {
|
||||
user, err := GetUserFromSession(db, r)
|
||||
if err != nil {
|
||||
WriteError(w, 1 /*Not Signed In*/)
|
||||
return
|
||||
}
|
||||
|
||||
if r.Method == "POST" {
|
||||
// if URL looks like /account/[0-9]+/import, use the account
|
||||
// import handler
|
||||
if accountImportRE.MatchString(r.URL.Path) {
|
||||
var accountid int64
|
||||
var importtype string
|
||||
n, err := GetURLPieces(r.URL.Path, "/account/%d/import/%s", &accountid, &importtype)
|
||||
|
||||
if err != nil || n != 2 {
|
||||
WriteError(w, 999 /*Internal Error*/)
|
||||
log.Print(err)
|
||||
return
|
||||
}
|
||||
AccountImportHandler(db, w, r, user, accountid, importtype)
|
||||
return
|
||||
}
|
||||
|
||||
account_json := r.PostFormValue("account")
|
||||
if account_json == "" {
|
||||
WriteError(w, 3 /*Invalid Request*/)
|
||||
return
|
||||
}
|
||||
|
||||
var account Account
|
||||
err := account.Read(account_json)
|
||||
if err != nil {
|
||||
WriteError(w, 3 /*Invalid Request*/)
|
||||
return
|
||||
}
|
||||
account.AccountId = -1
|
||||
account.UserId = user.UserId
|
||||
account.AccountVersion = 0
|
||||
|
||||
security, err := GetSecurity(db, account.SecurityId, user.UserId)
|
||||
if err != nil {
|
||||
WriteError(w, 999 /*Internal Error*/)
|
||||
log.Print(err)
|
||||
return
|
||||
}
|
||||
if security == nil {
|
||||
WriteError(w, 3 /*Invalid Request*/)
|
||||
return
|
||||
}
|
||||
|
||||
err = InsertAccount(db, &account)
|
||||
if err != nil {
|
||||
if _, ok := err.(ParentAccountMissingError); ok {
|
||||
WriteError(w, 3 /*Invalid Request*/)
|
||||
} else {
|
||||
WriteError(w, 999 /*Internal Error*/)
|
||||
log.Print(err)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
w.WriteHeader(201 /*Created*/)
|
||||
err = account.Write(w)
|
||||
if err != nil {
|
||||
WriteError(w, 999 /*Internal Error*/)
|
||||
log.Print(err)
|
||||
return
|
||||
}
|
||||
} else if r.Method == "GET" {
|
||||
var accountid int64
|
||||
n, err := GetURLPieces(r.URL.Path, "/account/%d", &accountid)
|
||||
|
||||
if err != nil || n != 1 {
|
||||
//Return all Accounts
|
||||
var al AccountList
|
||||
accounts, err := GetAccounts(db, user.UserId)
|
||||
if err != nil {
|
||||
WriteError(w, 999 /*Internal Error*/)
|
||||
log.Print(err)
|
||||
return
|
||||
}
|
||||
al.Accounts = accounts
|
||||
err = (&al).Write(w)
|
||||
if err != nil {
|
||||
WriteError(w, 999 /*Internal Error*/)
|
||||
log.Print(err)
|
||||
return
|
||||
}
|
||||
} else {
|
||||
// if URL looks like /account/[0-9]+/transactions, use the account
|
||||
// transaction handler
|
||||
if accountTransactionsRE.MatchString(r.URL.Path) {
|
||||
AccountTransactionsHandler(db, w, r, user, accountid)
|
||||
return
|
||||
}
|
||||
|
||||
// Return Account with this Id
|
||||
account, err := GetAccount(db, accountid, user.UserId)
|
||||
if err != nil {
|
||||
WriteError(w, 3 /*Invalid Request*/)
|
||||
return
|
||||
}
|
||||
|
||||
err = account.Write(w)
|
||||
if err != nil {
|
||||
WriteError(w, 999 /*Internal Error*/)
|
||||
log.Print(err)
|
||||
return
|
||||
}
|
||||
}
|
||||
} else {
|
||||
accountid, err := GetURLID(r.URL.Path)
|
||||
if err != nil {
|
||||
WriteError(w, 3 /*Invalid Request*/)
|
||||
return
|
||||
}
|
||||
if r.Method == "PUT" {
|
||||
account_json := r.PostFormValue("account")
|
||||
if account_json == "" {
|
||||
WriteError(w, 3 /*Invalid Request*/)
|
||||
return
|
||||
}
|
||||
|
||||
var account Account
|
||||
err := account.Read(account_json)
|
||||
if err != nil || account.AccountId != accountid {
|
||||
WriteError(w, 3 /*Invalid Request*/)
|
||||
return
|
||||
}
|
||||
account.UserId = user.UserId
|
||||
|
||||
security, err := GetSecurity(db, account.SecurityId, user.UserId)
|
||||
if err != nil {
|
||||
WriteError(w, 999 /*Internal Error*/)
|
||||
log.Print(err)
|
||||
return
|
||||
}
|
||||
if security == nil {
|
||||
WriteError(w, 3 /*Invalid Request*/)
|
||||
return
|
||||
}
|
||||
|
||||
err = UpdateAccount(db, &account)
|
||||
if err != nil {
|
||||
WriteError(w, 999 /*Internal Error*/)
|
||||
log.Print(err)
|
||||
return
|
||||
}
|
||||
|
||||
err = account.Write(w)
|
||||
if err != nil {
|
||||
WriteError(w, 999 /*Internal Error*/)
|
||||
log.Print(err)
|
||||
return
|
||||
}
|
||||
} else if r.Method == "DELETE" {
|
||||
account, err := GetAccount(db, accountid, user.UserId)
|
||||
if err != nil {
|
||||
WriteError(w, 3 /*Invalid Request*/)
|
||||
return
|
||||
}
|
||||
|
||||
err = DeleteAccount(db, account)
|
||||
if err != nil {
|
||||
WriteError(w, 999 /*Internal Error*/)
|
||||
log.Print(err)
|
||||
return
|
||||
}
|
||||
|
||||
WriteSuccess(w)
|
||||
}
|
||||
}
|
||||
}
|
217
internal/handlers/accounts_lua.go
Normal file
217
internal/handlers/accounts_lua.go
Normal file
@ -0,0 +1,217 @@
|
||||
package handlers
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"github.com/yuin/gopher-lua"
|
||||
"math/big"
|
||||
"strings"
|
||||
)
|
||||
|
||||
const luaAccountTypeName = "account"
|
||||
|
||||
func luaContextGetAccounts(L *lua.LState) (map[int64]*Account, error) {
|
||||
var account_map map[int64]*Account
|
||||
|
||||
ctx := L.Context()
|
||||
|
||||
db, ok := ctx.Value(dbContextKey).(*DB)
|
||||
if !ok {
|
||||
return nil, errors.New("Couldn't find DB in lua's Context")
|
||||
}
|
||||
|
||||
account_map, ok = ctx.Value(accountsContextKey).(map[int64]*Account)
|
||||
if !ok {
|
||||
user, ok := ctx.Value(userContextKey).(*User)
|
||||
if !ok {
|
||||
return nil, errors.New("Couldn't find User in lua's Context")
|
||||
}
|
||||
|
||||
accounts, err := GetAccounts(db, user.UserId)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
account_map = make(map[int64]*Account)
|
||||
for i := range *accounts {
|
||||
account_map[(*accounts)[i].AccountId] = &(*accounts)[i]
|
||||
}
|
||||
|
||||
ctx = context.WithValue(ctx, accountsContextKey, account_map)
|
||||
L.SetContext(ctx)
|
||||
}
|
||||
|
||||
return account_map, nil
|
||||
}
|
||||
|
||||
func luaGetAccounts(L *lua.LState) int {
|
||||
account_map, err := luaContextGetAccounts(L)
|
||||
if err != nil {
|
||||
panic("luaGetAccounts couldn't fetch accounts")
|
||||
}
|
||||
|
||||
table := L.NewTable()
|
||||
|
||||
for accountid := range account_map {
|
||||
table.RawSetInt(int(accountid), AccountToLua(L, account_map[accountid]))
|
||||
}
|
||||
|
||||
L.Push(table)
|
||||
return 1
|
||||
}
|
||||
|
||||
func luaRegisterAccounts(L *lua.LState) {
|
||||
mt := L.NewTypeMetatable(luaAccountTypeName)
|
||||
L.SetGlobal("account", mt)
|
||||
L.SetField(mt, "__index", L.NewFunction(luaAccount__index))
|
||||
L.SetField(mt, "__tostring", L.NewFunction(luaAccount__tostring))
|
||||
L.SetField(mt, "__eq", L.NewFunction(luaAccount__eq))
|
||||
L.SetField(mt, "__metatable", lua.LString("protected"))
|
||||
|
||||
for _, accttype := range AccountTypes {
|
||||
L.SetField(mt, accttype.String(), lua.LNumber(float64(accttype)))
|
||||
}
|
||||
|
||||
getAccountsFn := L.NewFunction(luaGetAccounts)
|
||||
L.SetField(mt, "get_all", getAccountsFn)
|
||||
// also register the get_accounts function as a global in its own right
|
||||
L.SetGlobal("get_accounts", getAccountsFn)
|
||||
}
|
||||
|
||||
func AccountToLua(L *lua.LState, account *Account) *lua.LUserData {
|
||||
ud := L.NewUserData()
|
||||
ud.Value = account
|
||||
L.SetMetatable(ud, L.GetTypeMetatable(luaAccountTypeName))
|
||||
return ud
|
||||
}
|
||||
|
||||
// Checks whether the first lua argument is a *LUserData with *Account and returns this *Account.
|
||||
func luaCheckAccount(L *lua.LState, n int) *Account {
|
||||
ud := L.CheckUserData(n)
|
||||
if account, ok := ud.Value.(*Account); ok {
|
||||
return account
|
||||
}
|
||||
L.ArgError(n, "account expected")
|
||||
return nil
|
||||
}
|
||||
|
||||
func luaAccount__index(L *lua.LState) int {
|
||||
a := luaCheckAccount(L, 1)
|
||||
field := L.CheckString(2)
|
||||
|
||||
switch field {
|
||||
case "AccountId", "accountid":
|
||||
L.Push(lua.LNumber(float64(a.AccountId)))
|
||||
case "Security", "security":
|
||||
security_map, err := luaContextGetSecurities(L)
|
||||
if err != nil {
|
||||
panic("account.security couldn't fetch securities")
|
||||
}
|
||||
if security, ok := security_map[a.SecurityId]; ok {
|
||||
L.Push(SecurityToLua(L, security))
|
||||
} else {
|
||||
panic("SecurityId not in lua security_map")
|
||||
}
|
||||
case "SecurityId", "securityid":
|
||||
L.Push(lua.LNumber(float64(a.SecurityId)))
|
||||
case "Parent", "parent", "ParentAccount", "parentaccount":
|
||||
if a.ParentAccountId == -1 {
|
||||
L.Push(lua.LNil)
|
||||
} else {
|
||||
account_map, err := luaContextGetAccounts(L)
|
||||
if err != nil {
|
||||
panic("account.parent couldn't fetch accounts")
|
||||
}
|
||||
if parent, ok := account_map[a.ParentAccountId]; ok {
|
||||
L.Push(AccountToLua(L, parent))
|
||||
} else {
|
||||
panic("ParentAccountId not in lua account_map")
|
||||
}
|
||||
}
|
||||
case "Name", "name":
|
||||
L.Push(lua.LString(a.Name))
|
||||
case "Type", "type":
|
||||
L.Push(lua.LNumber(float64(a.Type)))
|
||||
case "TypeName", "Typename":
|
||||
L.Push(lua.LString(a.Type.String()))
|
||||
case "typename":
|
||||
L.Push(lua.LString(strings.ToLower(a.Type.String())))
|
||||
case "Balance", "balance":
|
||||
L.Push(L.NewFunction(luaAccountBalance))
|
||||
default:
|
||||
L.ArgError(2, "unexpected account attribute: "+field)
|
||||
}
|
||||
|
||||
return 1
|
||||
}
|
||||
|
||||
func luaAccountBalance(L *lua.LState) int {
|
||||
a := luaCheckAccount(L, 1)
|
||||
|
||||
ctx := L.Context()
|
||||
db, ok := ctx.Value(dbContextKey).(*DB)
|
||||
if !ok {
|
||||
panic("Couldn't find DB in lua's Context")
|
||||
}
|
||||
user, ok := ctx.Value(userContextKey).(*User)
|
||||
if !ok {
|
||||
panic("Couldn't find User in lua's Context")
|
||||
}
|
||||
security_map, err := luaContextGetSecurities(L)
|
||||
if err != nil {
|
||||
panic("account.security couldn't fetch securities")
|
||||
}
|
||||
security, ok := security_map[a.SecurityId]
|
||||
if !ok {
|
||||
panic("SecurityId not in lua security_map")
|
||||
}
|
||||
date := luaWeakCheckTime(L, 2)
|
||||
var b Balance
|
||||
var rat *big.Rat
|
||||
if date != nil {
|
||||
end := luaWeakCheckTime(L, 3)
|
||||
if end != nil {
|
||||
rat, err = GetAccountBalanceDateRange(db, user, a.AccountId, date, end)
|
||||
} else {
|
||||
rat, err = GetAccountBalanceDate(db, user, a.AccountId, date)
|
||||
}
|
||||
} else {
|
||||
rat, err = GetAccountBalance(db, user, a.AccountId)
|
||||
}
|
||||
if err != nil {
|
||||
panic("Failed to GetAccountBalance:" + err.Error())
|
||||
}
|
||||
b.Amount = rat
|
||||
b.Security = security
|
||||
L.Push(BalanceToLua(L, &b))
|
||||
|
||||
return 1
|
||||
}
|
||||
|
||||
func luaAccount__tostring(L *lua.LState) int {
|
||||
a := luaCheckAccount(L, 1)
|
||||
|
||||
account_map, err := luaContextGetAccounts(L)
|
||||
if err != nil {
|
||||
panic("luaGetAccounts couldn't fetch accounts")
|
||||
}
|
||||
|
||||
full_name := a.Name
|
||||
for a.ParentAccountId != -1 {
|
||||
a = account_map[a.ParentAccountId]
|
||||
full_name = a.Name + "/" + full_name
|
||||
}
|
||||
|
||||
L.Push(lua.LString(full_name))
|
||||
|
||||
return 1
|
||||
}
|
||||
|
||||
func luaAccount__eq(L *lua.LState) int {
|
||||
a := luaCheckAccount(L, 1)
|
||||
b := luaCheckAccount(L, 2)
|
||||
|
||||
L.Push(lua.LBool(a.AccountId == b.AccountId))
|
||||
|
||||
return 1
|
||||
}
|
224
internal/handlers/balance_lua.go
Normal file
224
internal/handlers/balance_lua.go
Normal file
@ -0,0 +1,224 @@
|
||||
package handlers
|
||||
|
||||
import (
|
||||
"github.com/yuin/gopher-lua"
|
||||
"math/big"
|
||||
)
|
||||
|
||||
type Balance struct {
|
||||
Security *Security
|
||||
Amount *big.Rat
|
||||
}
|
||||
|
||||
const luaBalanceTypeName = "balance"
|
||||
|
||||
func luaRegisterBalances(L *lua.LState) {
|
||||
mt := L.NewTypeMetatable(luaBalanceTypeName)
|
||||
L.SetGlobal("balance", mt)
|
||||
L.SetField(mt, "__index", L.NewFunction(luaBalance__index))
|
||||
L.SetField(mt, "__tostring", L.NewFunction(luaBalance__tostring))
|
||||
L.SetField(mt, "__eq", L.NewFunction(luaBalance__eq))
|
||||
L.SetField(mt, "__lt", L.NewFunction(luaBalance__lt))
|
||||
L.SetField(mt, "__le", L.NewFunction(luaBalance__le))
|
||||
L.SetField(mt, "__add", L.NewFunction(luaBalance__add))
|
||||
L.SetField(mt, "__sub", L.NewFunction(luaBalance__sub))
|
||||
L.SetField(mt, "__mul", L.NewFunction(luaBalance__mul))
|
||||
L.SetField(mt, "__div", L.NewFunction(luaBalance__div))
|
||||
L.SetField(mt, "__unm", L.NewFunction(luaBalance__unm))
|
||||
L.SetField(mt, "__metatable", lua.LString("protected"))
|
||||
}
|
||||
|
||||
func BalanceToLua(L *lua.LState, balance *Balance) *lua.LUserData {
|
||||
ud := L.NewUserData()
|
||||
ud.Value = balance
|
||||
L.SetMetatable(ud, L.GetTypeMetatable(luaBalanceTypeName))
|
||||
return ud
|
||||
}
|
||||
|
||||
// Checks whether the first lua argument is a *LUserData with *Balance and returns this *Balance.
|
||||
func luaCheckBalance(L *lua.LState, n int) *Balance {
|
||||
ud := L.CheckUserData(n)
|
||||
if balance, ok := ud.Value.(*Balance); ok {
|
||||
return balance
|
||||
}
|
||||
L.ArgError(n, "balance expected")
|
||||
return nil
|
||||
}
|
||||
|
||||
func luaWeakCheckBalance(L *lua.LState, n int) *Balance {
|
||||
v := L.Get(n)
|
||||
if ud, ok := v.(*lua.LUserData); ok {
|
||||
if balance, ok := ud.Value.(*Balance); ok {
|
||||
return balance
|
||||
}
|
||||
L.ArgError(n, "balance expected")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func luaGetBalanceOperands(L *lua.LState, n int, m int) (*Balance, *Balance) {
|
||||
bn := luaWeakCheckBalance(L, n)
|
||||
bm := luaWeakCheckBalance(L, m)
|
||||
|
||||
if bn != nil && bm != nil {
|
||||
return bn, bm
|
||||
} else if bn != nil {
|
||||
nm := L.CheckNumber(m)
|
||||
var balance Balance
|
||||
var rat big.Rat
|
||||
balance.Security = bn.Security
|
||||
balance.Amount = rat.SetFloat64(float64(nm))
|
||||
if balance.Amount == nil {
|
||||
L.ArgError(n, "non-finite float invalid for operand to balance arithemetic")
|
||||
return nil, nil
|
||||
}
|
||||
return bn, &balance
|
||||
} else if bm != nil {
|
||||
nn := L.CheckNumber(n)
|
||||
var balance Balance
|
||||
var rat big.Rat
|
||||
balance.Security = bm.Security
|
||||
balance.Amount = rat.SetFloat64(float64(nn))
|
||||
if balance.Amount == nil {
|
||||
L.ArgError(n, "non-finite float invalid for operand to balance arithemetic")
|
||||
return nil, nil
|
||||
}
|
||||
return bm, &balance
|
||||
}
|
||||
L.ArgError(n, "balance expected")
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func luaBalance__index(L *lua.LState) int {
|
||||
a := luaCheckBalance(L, 1)
|
||||
field := L.CheckString(2)
|
||||
|
||||
switch field {
|
||||
case "Security", "security":
|
||||
L.Push(SecurityToLua(L, a.Security))
|
||||
case "Amount", "amount":
|
||||
float, _ := a.Amount.Float64()
|
||||
L.Push(lua.LNumber(float))
|
||||
default:
|
||||
L.ArgError(2, "unexpected balance attribute: "+field)
|
||||
}
|
||||
|
||||
return 1
|
||||
}
|
||||
|
||||
func luaBalance__tostring(L *lua.LState) int {
|
||||
b := luaCheckBalance(L, 1)
|
||||
|
||||
L.Push(lua.LString(b.Security.Symbol + " " + b.Amount.FloatString(b.Security.Precision)))
|
||||
|
||||
return 1
|
||||
}
|
||||
|
||||
func luaBalance__eq(L *lua.LState) int {
|
||||
a := luaCheckBalance(L, 1)
|
||||
b := luaCheckBalance(L, 2)
|
||||
|
||||
L.Push(lua.LBool(a.Security.SecurityId == b.Security.SecurityId && a.Amount.Cmp(b.Amount) == 0))
|
||||
|
||||
return 1
|
||||
}
|
||||
|
||||
func luaBalance__lt(L *lua.LState) int {
|
||||
a := luaCheckBalance(L, 1)
|
||||
b := luaCheckBalance(L, 2)
|
||||
if a.Security.SecurityId != b.Security.SecurityId {
|
||||
L.ArgError(2, "Can't compare balances with different securities")
|
||||
}
|
||||
|
||||
L.Push(lua.LBool(a.Amount.Cmp(b.Amount) < 0))
|
||||
|
||||
return 1
|
||||
}
|
||||
|
||||
func luaBalance__le(L *lua.LState) int {
|
||||
a := luaCheckBalance(L, 1)
|
||||
b := luaCheckBalance(L, 2)
|
||||
if a.Security.SecurityId != b.Security.SecurityId {
|
||||
L.ArgError(2, "Can't compare balances with different securities")
|
||||
}
|
||||
|
||||
L.Push(lua.LBool(a.Amount.Cmp(b.Amount) <= 0))
|
||||
|
||||
return 1
|
||||
}
|
||||
|
||||
func luaBalance__add(L *lua.LState) int {
|
||||
a, b := luaGetBalanceOperands(L, 1, 2)
|
||||
|
||||
if a.Security.SecurityId != b.Security.SecurityId {
|
||||
L.ArgError(2, "Can't add balances with different securities")
|
||||
}
|
||||
|
||||
var balance Balance
|
||||
var rat big.Rat
|
||||
balance.Security = a.Security
|
||||
balance.Amount = rat.Add(a.Amount, b.Amount)
|
||||
L.Push(BalanceToLua(L, &balance))
|
||||
|
||||
return 1
|
||||
}
|
||||
|
||||
func luaBalance__sub(L *lua.LState) int {
|
||||
a, b := luaGetBalanceOperands(L, 1, 2)
|
||||
|
||||
if a.Security.SecurityId != b.Security.SecurityId {
|
||||
L.ArgError(2, "Can't subtract balances with different securities")
|
||||
}
|
||||
|
||||
var balance Balance
|
||||
var rat big.Rat
|
||||
balance.Security = a.Security
|
||||
balance.Amount = rat.Sub(a.Amount, b.Amount)
|
||||
L.Push(BalanceToLua(L, &balance))
|
||||
|
||||
return 1
|
||||
}
|
||||
|
||||
func luaBalance__mul(L *lua.LState) int {
|
||||
a, b := luaGetBalanceOperands(L, 1, 2)
|
||||
|
||||
if a.Security.SecurityId != b.Security.SecurityId {
|
||||
L.ArgError(2, "Can't multiply balances with different securities")
|
||||
}
|
||||
|
||||
var balance Balance
|
||||
var rat big.Rat
|
||||
balance.Security = a.Security
|
||||
balance.Amount = rat.Mul(a.Amount, b.Amount)
|
||||
L.Push(BalanceToLua(L, &balance))
|
||||
|
||||
return 1
|
||||
}
|
||||
|
||||
func luaBalance__div(L *lua.LState) int {
|
||||
a, b := luaGetBalanceOperands(L, 1, 2)
|
||||
|
||||
if a.Security.SecurityId != b.Security.SecurityId {
|
||||
L.ArgError(2, "Can't divide balances with different securities")
|
||||
}
|
||||
|
||||
var balance Balance
|
||||
var rat big.Rat
|
||||
balance.Security = a.Security
|
||||
balance.Amount = rat.Quo(a.Amount, b.Amount)
|
||||
L.Push(BalanceToLua(L, &balance))
|
||||
|
||||
return 1
|
||||
}
|
||||
|
||||
func luaBalance__unm(L *lua.LState) int {
|
||||
b := luaCheckBalance(L, 1)
|
||||
|
||||
var balance Balance
|
||||
var rat big.Rat
|
||||
balance.Security = b.Security
|
||||
balance.Amount = rat.Neg(b.Amount)
|
||||
L.Push(BalanceToLua(L, &balance))
|
||||
|
||||
return 1
|
||||
}
|
169
internal/handlers/date_lua.go
Normal file
169
internal/handlers/date_lua.go
Normal file
@ -0,0 +1,169 @@
|
||||
package handlers
|
||||
|
||||
import (
|
||||
"github.com/yuin/gopher-lua"
|
||||
"time"
|
||||
)
|
||||
|
||||
const luaDateTypeName = "date"
|
||||
const timeFormat = "2006-01-02"
|
||||
|
||||
func luaRegisterDates(L *lua.LState) {
|
||||
mt := L.NewTypeMetatable(luaDateTypeName)
|
||||
L.SetGlobal("date", mt)
|
||||
L.SetField(mt, "new", L.NewFunction(luaDateNew))
|
||||
L.SetField(mt, "now", L.NewFunction(luaDateNow))
|
||||
L.SetField(mt, "__index", L.NewFunction(luaDate__index))
|
||||
L.SetField(mt, "__tostring", L.NewFunction(luaDate__tostring))
|
||||
L.SetField(mt, "__eq", L.NewFunction(luaDate__eq))
|
||||
L.SetField(mt, "__lt", L.NewFunction(luaDate__lt))
|
||||
L.SetField(mt, "__le", L.NewFunction(luaDate__le))
|
||||
L.SetField(mt, "__add", L.NewFunction(luaDate__add))
|
||||
L.SetField(mt, "__sub", L.NewFunction(luaDate__sub))
|
||||
L.SetField(mt, "__metatable", lua.LString("protected"))
|
||||
}
|
||||
|
||||
func TimeToLua(L *lua.LState, date *time.Time) *lua.LUserData {
|
||||
ud := L.NewUserData()
|
||||
ud.Value = date
|
||||
L.SetMetatable(ud, L.GetTypeMetatable(luaDateTypeName))
|
||||
return ud
|
||||
}
|
||||
|
||||
// Checks whether the first lua argument is a *LUserData with *Time and returns this *Time.
|
||||
func luaCheckTime(L *lua.LState, n int) *time.Time {
|
||||
ud := L.CheckUserData(n)
|
||||
if date, ok := ud.Value.(*time.Time); ok {
|
||||
return date
|
||||
}
|
||||
L.ArgError(n, "date expected")
|
||||
return nil
|
||||
}
|
||||
|
||||
func luaWeakCheckTime(L *lua.LState, n int) *time.Time {
|
||||
v := L.Get(n)
|
||||
if ud, ok := v.(*lua.LUserData); ok {
|
||||
if date, ok := ud.Value.(*time.Time); ok {
|
||||
return date
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func luaWeakCheckTableFieldInt(L *lua.LState, T *lua.LTable, n int, name string, def int) int {
|
||||
lv := T.RawGetString(name)
|
||||
if lv == lua.LNil {
|
||||
return def
|
||||
}
|
||||
if i, ok := lv.(lua.LNumber); ok {
|
||||
return int(i)
|
||||
}
|
||||
L.ArgError(n, "table field '"+name+"' expected to be int")
|
||||
return def
|
||||
}
|
||||
|
||||
func luaDateNew(L *lua.LState) int {
|
||||
v := L.Get(1)
|
||||
if s, ok := v.(lua.LString); ok {
|
||||
date, err := time.Parse(timeFormat, s.String())
|
||||
if err != nil {
|
||||
L.ArgError(1, "error parsing date string: "+err.Error())
|
||||
return 0
|
||||
}
|
||||
L.Push(TimeToLua(L, &date))
|
||||
return 1
|
||||
}
|
||||
var year, month, day int
|
||||
if t, ok := v.(*lua.LTable); ok {
|
||||
year = luaWeakCheckTableFieldInt(L, t, 1, "year", 0)
|
||||
month = luaWeakCheckTableFieldInt(L, t, 1, "month", 1)
|
||||
day = luaWeakCheckTableFieldInt(L, t, 1, "day", 1)
|
||||
} else {
|
||||
year = L.CheckInt(1)
|
||||
month = L.CheckInt(2)
|
||||
day = L.CheckInt(3)
|
||||
}
|
||||
date := time.Date(year, time.Month(month), day, 0, 0, 0, 0, time.Local)
|
||||
L.Push(TimeToLua(L, &date))
|
||||
return 1
|
||||
}
|
||||
|
||||
func luaDateNow(L *lua.LState) int {
|
||||
now := time.Now()
|
||||
date := time.Date(now.Year(), now.Month(), now.Day(), 0, 0, 0, 0, time.Local)
|
||||
L.Push(TimeToLua(L, &date))
|
||||
return 1
|
||||
}
|
||||
|
||||
func luaDate__index(L *lua.LState) int {
|
||||
d := luaCheckTime(L, 1)
|
||||
field := L.CheckString(2)
|
||||
|
||||
switch field {
|
||||
case "Year", "year":
|
||||
L.Push(lua.LNumber(d.Year()))
|
||||
case "Month", "month":
|
||||
L.Push(lua.LNumber(float64(d.Month())))
|
||||
case "Day", "day":
|
||||
L.Push(lua.LNumber(float64(d.Day())))
|
||||
default:
|
||||
L.ArgError(2, "unexpected date attribute: "+field)
|
||||
}
|
||||
|
||||
return 1
|
||||
}
|
||||
|
||||
func luaDate__tostring(L *lua.LState) int {
|
||||
a := luaCheckTime(L, 1)
|
||||
|
||||
L.Push(lua.LString(a.Format(timeFormat)))
|
||||
|
||||
return 1
|
||||
}
|
||||
|
||||
func luaDate__eq(L *lua.LState) int {
|
||||
a := luaCheckTime(L, 1)
|
||||
b := luaCheckTime(L, 2)
|
||||
|
||||
L.Push(lua.LBool(a.Equal(*b)))
|
||||
|
||||
return 1
|
||||
}
|
||||
|
||||
func luaDate__lt(L *lua.LState) int {
|
||||
a := luaCheckTime(L, 1)
|
||||
b := luaCheckTime(L, 2)
|
||||
|
||||
L.Push(lua.LBool(a.Before(*b)))
|
||||
|
||||
return 1
|
||||
}
|
||||
|
||||
func luaDate__le(L *lua.LState) int {
|
||||
a := luaCheckTime(L, 1)
|
||||
b := luaCheckTime(L, 2)
|
||||
|
||||
L.Push(lua.LBool(a.Equal(*b) || a.Before(*b)))
|
||||
|
||||
return 1
|
||||
}
|
||||
|
||||
func luaDate__add(L *lua.LState) int {
|
||||
a := luaCheckTime(L, 1)
|
||||
b := luaCheckTime(L, 2)
|
||||
|
||||
date := a.AddDate(b.Year(), int(b.Month()), b.Day())
|
||||
L.Push(TimeToLua(L, &date))
|
||||
|
||||
return 1
|
||||
}
|
||||
|
||||
func luaDate__sub(L *lua.LState) int {
|
||||
a := luaCheckTime(L, 1)
|
||||
b := luaCheckTime(L, 2)
|
||||
|
||||
date := a.AddDate(-b.Year(), -int(b.Month()), -b.Day())
|
||||
L.Push(TimeToLua(L, &date))
|
||||
|
||||
return 1
|
||||
}
|
37
internal/handlers/errors.go
Normal file
37
internal/handlers/errors.go
Normal file
@ -0,0 +1,37 @@
|
||||
package handlers
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"log"
|
||||
"net/http"
|
||||
)
|
||||
|
||||
type Error struct {
|
||||
ErrorId int
|
||||
ErrorString string
|
||||
}
|
||||
|
||||
var error_codes = map[int]string{
|
||||
1: "Not Signed In",
|
||||
2: "Unauthorized Access",
|
||||
3: "Invalid Request",
|
||||
4: "User Exists",
|
||||
// 5: "Connection Failed", //reserved for client-side error
|
||||
6: "Import Error",
|
||||
999: "Internal Error",
|
||||
}
|
||||
|
||||
func WriteError(w http.ResponseWriter, error_code int) {
|
||||
msg, ok := error_codes[error_code]
|
||||
if !ok {
|
||||
log.Printf("Error: WriteError received error code of %d", error_code)
|
||||
msg = error_codes[999]
|
||||
}
|
||||
e := Error{error_code, msg}
|
||||
|
||||
enc := json.NewEncoder(w)
|
||||
err := enc.Encode(e)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
}
|
493
internal/handlers/gnucash.go
Normal file
493
internal/handlers/gnucash.go
Normal file
@ -0,0 +1,493 @@
|
||||
package handlers
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"compress/gzip"
|
||||
"encoding/xml"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
"math"
|
||||
"math/big"
|
||||
"net/http"
|
||||
"time"
|
||||
)
|
||||
|
||||
type GnucashXMLCommodity struct {
|
||||
Name string `xml:"http://www.gnucash.org/XML/cmdty id"`
|
||||
Description string `xml:"http://www.gnucash.org/XML/cmdty name"`
|
||||
Type string `xml:"http://www.gnucash.org/XML/cmdty space"`
|
||||
Fraction int `xml:"http://www.gnucash.org/XML/cmdty fraction"`
|
||||
XCode string `xml:"http://www.gnucash.org/XML/cmdty xcode"`
|
||||
}
|
||||
|
||||
type GnucashCommodity struct{ Security }
|
||||
|
||||
func (gc *GnucashCommodity) UnmarshalXML(d *xml.Decoder, start xml.StartElement) error {
|
||||
var gxc GnucashXMLCommodity
|
||||
if err := d.DecodeElement(&gxc, &start); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
gc.Name = gxc.Name
|
||||
gc.Symbol = gxc.Name
|
||||
gc.Description = gxc.Description
|
||||
gc.AlternateId = gxc.XCode
|
||||
|
||||
gc.Security.Type = Stock // assumed default
|
||||
if gxc.Type == "ISO4217" {
|
||||
gc.Security.Type = Currency
|
||||
// Get the number from our templates for the AlternateId because
|
||||
// Gnucash uses 'id' (our Name) to supply the string ISO4217 code
|
||||
template := FindSecurityTemplate(gxc.Name, Currency)
|
||||
if template == nil {
|
||||
return errors.New("Unable to find security template for Gnucash ISO4217 commodity")
|
||||
}
|
||||
gc.AlternateId = template.AlternateId
|
||||
gc.Precision = template.Precision
|
||||
} else {
|
||||
if gxc.Fraction > 0 {
|
||||
gc.Precision = int(math.Ceil(math.Log10(float64(gxc.Fraction))))
|
||||
} else {
|
||||
gc.Precision = 0
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
type GnucashTime struct{ time.Time }
|
||||
|
||||
func (g *GnucashTime) UnmarshalXML(d *xml.Decoder, start xml.StartElement) error {
|
||||
var s string
|
||||
if err := d.DecodeElement(&s, &start); err != nil {
|
||||
return fmt.Errorf("date should be a string")
|
||||
}
|
||||
t, err := time.Parse("2006-01-02 15:04:05 -0700", s)
|
||||
g.Time = t
|
||||
return err
|
||||
}
|
||||
|
||||
type GnucashDate struct {
|
||||
Date GnucashTime `xml:"http://www.gnucash.org/XML/ts date"`
|
||||
}
|
||||
|
||||
type GnucashPrice struct {
|
||||
Id string `xml:"http://www.gnucash.org/XML/price id"`
|
||||
Commodity GnucashCommodity `xml:"http://www.gnucash.org/XML/price commodity"`
|
||||
Currency GnucashCommodity `xml:"http://www.gnucash.org/XML/price currency"`
|
||||
Date GnucashDate `xml:"http://www.gnucash.org/XML/price time"`
|
||||
Source string `xml:"http://www.gnucash.org/XML/price source"`
|
||||
Type string `xml:"http://www.gnucash.org/XML/price type"`
|
||||
Value string `xml:"http://www.gnucash.org/XML/price value"`
|
||||
}
|
||||
|
||||
type GnucashPriceDB struct {
|
||||
Prices []GnucashPrice `xml:"price"`
|
||||
}
|
||||
|
||||
type GnucashAccount struct {
|
||||
Version string `xml:"version,attr"`
|
||||
accountid int64 // Used to map Gnucash guid's to integer ones
|
||||
AccountId string `xml:"http://www.gnucash.org/XML/act id"`
|
||||
ParentAccountId string `xml:"http://www.gnucash.org/XML/act parent"`
|
||||
Name string `xml:"http://www.gnucash.org/XML/act name"`
|
||||
Description string `xml:"http://www.gnucash.org/XML/act description"`
|
||||
Type string `xml:"http://www.gnucash.org/XML/act type"`
|
||||
Commodity GnucashXMLCommodity `xml:"http://www.gnucash.org/XML/act commodity"`
|
||||
}
|
||||
|
||||
type GnucashTransaction struct {
|
||||
TransactionId string `xml:"http://www.gnucash.org/XML/trn id"`
|
||||
Description string `xml:"http://www.gnucash.org/XML/trn description"`
|
||||
Number string `xml:"http://www.gnucash.org/XML/trn num"`
|
||||
DatePosted GnucashDate `xml:"http://www.gnucash.org/XML/trn date-posted"`
|
||||
DateEntered GnucashDate `xml:"http://www.gnucash.org/XML/trn date-entered"`
|
||||
Commodity GnucashXMLCommodity `xml:"http://www.gnucash.org/XML/trn currency"`
|
||||
Splits []GnucashSplit `xml:"http://www.gnucash.org/XML/trn splits>split"`
|
||||
}
|
||||
|
||||
type GnucashSplit struct {
|
||||
SplitId string `xml:"http://www.gnucash.org/XML/split id"`
|
||||
Status string `xml:"http://www.gnucash.org/XML/split reconciled-state"`
|
||||
AccountId string `xml:"http://www.gnucash.org/XML/split account"`
|
||||
Memo string `xml:"http://www.gnucash.org/XML/split memo"`
|
||||
Amount string `xml:"http://www.gnucash.org/XML/split quantity"`
|
||||
Value string `xml:"http://www.gnucash.org/XML/split value"`
|
||||
}
|
||||
|
||||
type GnucashXMLImport struct {
|
||||
XMLName xml.Name `xml:"gnc-v2"`
|
||||
Commodities []GnucashCommodity `xml:"http://www.gnucash.org/XML/gnc book>commodity"`
|
||||
PriceDB GnucashPriceDB `xml:"http://www.gnucash.org/XML/gnc book>pricedb"`
|
||||
Accounts []GnucashAccount `xml:"http://www.gnucash.org/XML/gnc book>account"`
|
||||
Transactions []GnucashTransaction `xml:"http://www.gnucash.org/XML/gnc book>transaction"`
|
||||
}
|
||||
|
||||
type GnucashImport struct {
|
||||
Securities []Security
|
||||
Accounts []Account
|
||||
Transactions []Transaction
|
||||
Prices []Price
|
||||
}
|
||||
|
||||
func ImportGnucash(r io.Reader) (*GnucashImport, error) {
|
||||
var gncxml GnucashXMLImport
|
||||
var gncimport GnucashImport
|
||||
|
||||
// Perform initial parsing of xml into structs
|
||||
decoder := xml.NewDecoder(r)
|
||||
err := decoder.Decode(&gncxml)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Fixup securities, making a map of them as we go
|
||||
securityMap := make(map[string]Security)
|
||||
for i := range gncxml.Commodities {
|
||||
s := gncxml.Commodities[i].Security
|
||||
s.SecurityId = int64(i + 1)
|
||||
securityMap[s.Name] = s
|
||||
|
||||
// Ignore gnucash's "template" commodity
|
||||
if s.Name != "template" ||
|
||||
s.Description != "template" ||
|
||||
s.AlternateId != "template" {
|
||||
gncimport.Securities = append(gncimport.Securities, s)
|
||||
}
|
||||
}
|
||||
|
||||
// Create prices, setting security and currency IDs from securityMap
|
||||
for i := range gncxml.PriceDB.Prices {
|
||||
price := gncxml.PriceDB.Prices[i]
|
||||
var p Price
|
||||
security, ok := securityMap[price.Commodity.Name]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("Unable to find commodity '%s' for price '%s'", price.Commodity.Name, price.Id)
|
||||
}
|
||||
currency, ok := securityMap[price.Currency.Name]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("Unable to find currency '%s' for price '%s'", price.Currency.Name, price.Id)
|
||||
}
|
||||
if currency.Type != Currency {
|
||||
return nil, fmt.Errorf("Currency for imported price isn't actually a currency\n")
|
||||
}
|
||||
p.PriceId = int64(i + 1)
|
||||
p.SecurityId = security.SecurityId
|
||||
p.CurrencyId = currency.SecurityId
|
||||
p.Date = price.Date.Date.Time
|
||||
|
||||
var r big.Rat
|
||||
_, ok = r.SetString(price.Value)
|
||||
if ok {
|
||||
p.Value = r.FloatString(currency.Precision)
|
||||
} else {
|
||||
return nil, fmt.Errorf("Can't set price value: %s", price.Value)
|
||||
}
|
||||
|
||||
p.RemoteId = "gnucash:" + price.Id
|
||||
gncimport.Prices = append(gncimport.Prices, p)
|
||||
}
|
||||
|
||||
//find root account, while simultaneously creating map of GUID's to
|
||||
//accounts
|
||||
var rootAccount GnucashAccount
|
||||
accountMap := make(map[string]GnucashAccount)
|
||||
for i := range gncxml.Accounts {
|
||||
gncxml.Accounts[i].accountid = int64(i + 1)
|
||||
if gncxml.Accounts[i].Type == "ROOT" {
|
||||
rootAccount = gncxml.Accounts[i]
|
||||
} else {
|
||||
accountMap[gncxml.Accounts[i].AccountId] = gncxml.Accounts[i]
|
||||
}
|
||||
}
|
||||
|
||||
//Translate to our account format, figuring out parent relationships
|
||||
for guid := range accountMap {
|
||||
ga := accountMap[guid]
|
||||
var a Account
|
||||
|
||||
a.AccountId = ga.accountid
|
||||
if ga.ParentAccountId == rootAccount.AccountId {
|
||||
a.ParentAccountId = -1
|
||||
} else {
|
||||
parent, ok := accountMap[ga.ParentAccountId]
|
||||
if ok {
|
||||
a.ParentAccountId = parent.accountid
|
||||
} else {
|
||||
a.ParentAccountId = -1 // Ugly, but assign to top-level if we can't find its parent
|
||||
}
|
||||
}
|
||||
a.Name = ga.Name
|
||||
if security, ok := securityMap[ga.Commodity.Name]; ok {
|
||||
a.SecurityId = security.SecurityId
|
||||
} else {
|
||||
return nil, fmt.Errorf("Unable to find security: %s", ga.Commodity.Name)
|
||||
}
|
||||
|
||||
//TODO find account types
|
||||
switch ga.Type {
|
||||
default:
|
||||
a.Type = Bank
|
||||
case "ASSET":
|
||||
a.Type = Asset
|
||||
case "BANK":
|
||||
a.Type = Bank
|
||||
case "CASH":
|
||||
a.Type = Cash
|
||||
case "CREDIT", "LIABILITY":
|
||||
a.Type = Liability
|
||||
case "EQUITY":
|
||||
a.Type = Equity
|
||||
case "EXPENSE":
|
||||
a.Type = Expense
|
||||
case "INCOME":
|
||||
a.Type = Income
|
||||
case "PAYABLE":
|
||||
a.Type = Payable
|
||||
case "RECEIVABLE":
|
||||
a.Type = Receivable
|
||||
case "MUTUAL", "STOCK":
|
||||
a.Type = Investment
|
||||
case "TRADING":
|
||||
a.Type = Trading
|
||||
}
|
||||
|
||||
gncimport.Accounts = append(gncimport.Accounts, a)
|
||||
}
|
||||
|
||||
//Translate transactions to our format
|
||||
for i := range gncxml.Transactions {
|
||||
gt := gncxml.Transactions[i]
|
||||
|
||||
t := new(Transaction)
|
||||
t.Description = gt.Description
|
||||
t.Date = gt.DatePosted.Date.Time
|
||||
for j := range gt.Splits {
|
||||
gs := gt.Splits[j]
|
||||
s := new(Split)
|
||||
|
||||
switch gs.Status {
|
||||
default: // 'n', or not present
|
||||
s.Status = Imported
|
||||
case "c":
|
||||
s.Status = Cleared
|
||||
case "y":
|
||||
s.Status = Reconciled
|
||||
}
|
||||
|
||||
account, ok := accountMap[gs.AccountId]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("Unable to find account: %s", gs.AccountId)
|
||||
}
|
||||
s.AccountId = account.accountid
|
||||
|
||||
security, ok := securityMap[account.Commodity.Name]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("Unable to find security: %s", account.Commodity.Name)
|
||||
}
|
||||
s.SecurityId = -1
|
||||
|
||||
s.RemoteId = "gnucash:" + gs.SplitId
|
||||
s.Number = gt.Number
|
||||
s.Memo = gs.Memo
|
||||
|
||||
var r big.Rat
|
||||
_, ok = r.SetString(gs.Amount)
|
||||
if ok {
|
||||
s.Amount = r.FloatString(security.Precision)
|
||||
} else {
|
||||
return nil, fmt.Errorf("Can't set split Amount: %s", gs.Amount)
|
||||
}
|
||||
|
||||
t.Splits = append(t.Splits, s)
|
||||
}
|
||||
gncimport.Transactions = append(gncimport.Transactions, *t)
|
||||
}
|
||||
|
||||
return &gncimport, nil
|
||||
}
|
||||
|
||||
func GnucashImportHandler(w http.ResponseWriter, r *http.Request, db *DB) {
|
||||
user, err := GetUserFromSession(db, r)
|
||||
if err != nil {
|
||||
WriteError(w, 1 /*Not Signed In*/)
|
||||
return
|
||||
}
|
||||
|
||||
if r.Method != "POST" {
|
||||
WriteError(w, 3 /*Invalid Request*/)
|
||||
return
|
||||
}
|
||||
|
||||
multipartReader, err := r.MultipartReader()
|
||||
if err != nil {
|
||||
WriteError(w, 3 /*Invalid Request*/)
|
||||
return
|
||||
}
|
||||
|
||||
// Assume there is only one 'part' and it's the one we care about
|
||||
part, err := multipartReader.NextPart()
|
||||
if err != nil {
|
||||
if err == io.EOF {
|
||||
WriteError(w, 3 /*Invalid Request*/)
|
||||
} else {
|
||||
WriteError(w, 999 /*Internal Error*/)
|
||||
log.Print(err)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
bufread := bufio.NewReader(part)
|
||||
gzHeader, err := bufread.Peek(2)
|
||||
if err != nil {
|
||||
WriteError(w, 999 /*Internal Error*/)
|
||||
log.Print(err)
|
||||
return
|
||||
}
|
||||
|
||||
// Does this look like a gzipped file?
|
||||
var gnucashImport *GnucashImport
|
||||
if gzHeader[0] == 0x1f && gzHeader[1] == 0x8b {
|
||||
gzr, err := gzip.NewReader(bufread)
|
||||
if err != nil {
|
||||
WriteError(w, 999 /*Internal Error*/)
|
||||
log.Print(err)
|
||||
return
|
||||
}
|
||||
gnucashImport, err = ImportGnucash(gzr)
|
||||
} else {
|
||||
gnucashImport, err = ImportGnucash(bufread)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
WriteError(w, 3 /*Invalid Request*/)
|
||||
return
|
||||
}
|
||||
|
||||
sqltransaction, err := db.Begin()
|
||||
if err != nil {
|
||||
WriteError(w, 999 /*Internal Error*/)
|
||||
log.Print(err)
|
||||
return
|
||||
}
|
||||
|
||||
// Import securities, building map from Gnucash security IDs to our
|
||||
// internal IDs
|
||||
securityMap := make(map[int64]int64)
|
||||
for _, security := range gnucashImport.Securities {
|
||||
securityId := security.SecurityId // save off because it could be updated
|
||||
s, err := ImportGetCreateSecurity(sqltransaction, user.UserId, &security)
|
||||
if err != nil {
|
||||
sqltransaction.Rollback()
|
||||
WriteError(w, 6 /*Import Error*/)
|
||||
log.Print(err)
|
||||
log.Print(security)
|
||||
return
|
||||
}
|
||||
securityMap[securityId] = s.SecurityId
|
||||
}
|
||||
|
||||
// Import prices, setting security and currency IDs from securityMap
|
||||
for _, price := range gnucashImport.Prices {
|
||||
price.SecurityId = securityMap[price.SecurityId]
|
||||
price.CurrencyId = securityMap[price.CurrencyId]
|
||||
price.PriceId = 0
|
||||
|
||||
err := CreatePriceIfNotExist(sqltransaction, &price)
|
||||
if err != nil {
|
||||
sqltransaction.Rollback()
|
||||
WriteError(w, 6 /*Import Error*/)
|
||||
log.Print(err)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// Get/create accounts in the database, building a map from Gnucash account
|
||||
// IDs to our internal IDs as we go
|
||||
accountMap := make(map[int64]int64)
|
||||
accountsRemaining := len(gnucashImport.Accounts)
|
||||
accountsRemainingLast := accountsRemaining
|
||||
for accountsRemaining > 0 {
|
||||
for _, account := range gnucashImport.Accounts {
|
||||
|
||||
// If the account has already been added to the map, skip it
|
||||
_, ok := accountMap[account.AccountId]
|
||||
if ok {
|
||||
continue
|
||||
}
|
||||
|
||||
// If it hasn't been added, but its parent has, add it to the map
|
||||
_, ok = accountMap[account.ParentAccountId]
|
||||
if ok || account.ParentAccountId == -1 {
|
||||
account.UserId = user.UserId
|
||||
if account.ParentAccountId != -1 {
|
||||
account.ParentAccountId = accountMap[account.ParentAccountId]
|
||||
}
|
||||
account.SecurityId = securityMap[account.SecurityId]
|
||||
a, err := GetCreateAccountTx(sqltransaction, account)
|
||||
if err != nil {
|
||||
sqltransaction.Rollback()
|
||||
WriteError(w, 999 /*Internal Error*/)
|
||||
log.Print(err)
|
||||
return
|
||||
}
|
||||
accountMap[account.AccountId] = a.AccountId
|
||||
accountsRemaining--
|
||||
}
|
||||
}
|
||||
if accountsRemaining == accountsRemainingLast {
|
||||
//We didn't make any progress in importing the next level of accounts, so there must be a circular parent-child relationship, so give up and tell the user they're wrong
|
||||
sqltransaction.Rollback()
|
||||
WriteError(w, 999 /*Internal Error*/)
|
||||
log.Print(fmt.Errorf("Circular account parent-child relationship when importing %s", part.FileName()))
|
||||
return
|
||||
}
|
||||
accountsRemainingLast = accountsRemaining
|
||||
}
|
||||
|
||||
// Insert transactions, fixing up account IDs to match internal ones from
|
||||
// above
|
||||
for _, transaction := range gnucashImport.Transactions {
|
||||
var already_imported bool
|
||||
for _, split := range transaction.Splits {
|
||||
acctId, ok := accountMap[split.AccountId]
|
||||
if !ok {
|
||||
sqltransaction.Rollback()
|
||||
WriteError(w, 999 /*Internal Error*/)
|
||||
log.Print(fmt.Errorf("Error: Split's AccountID Doesn't exist: %d\n", split.AccountId))
|
||||
return
|
||||
}
|
||||
split.AccountId = acctId
|
||||
|
||||
exists, err := split.AlreadyImportedTx(sqltransaction)
|
||||
if err != nil {
|
||||
sqltransaction.Rollback()
|
||||
WriteError(w, 999 /*Internal Error*/)
|
||||
log.Print("Error checking if split was already imported:", err)
|
||||
return
|
||||
} else if exists {
|
||||
already_imported = true
|
||||
}
|
||||
}
|
||||
if !already_imported {
|
||||
err := InsertTransactionTx(sqltransaction, &transaction, user)
|
||||
if err != nil {
|
||||
sqltransaction.Rollback()
|
||||
WriteError(w, 999 /*Internal Error*/)
|
||||
log.Print(err)
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
err = sqltransaction.Commit()
|
||||
if err != nil {
|
||||
sqltransaction.Rollback()
|
||||
WriteError(w, 999 /*Internal Error*/)
|
||||
log.Print(err)
|
||||
return
|
||||
}
|
||||
|
||||
WriteSuccess(w)
|
||||
}
|
31
internal/handlers/handlers.go
Normal file
31
internal/handlers/handlers.go
Normal file
@ -0,0 +1,31 @@
|
||||
package handlers
|
||||
|
||||
import (
|
||||
"gopkg.in/gorp.v1"
|
||||
"net/http"
|
||||
)
|
||||
|
||||
// Create a closure over db, allowing the handlers to look like a
|
||||
// http.HandlerFunc
|
||||
type DB = gorp.DbMap
|
||||
type DBHandler func(http.ResponseWriter, *http.Request, *DB)
|
||||
|
||||
func DBHandlerFunc(h DBHandler, db *DB) http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
h(w, r, db)
|
||||
}
|
||||
}
|
||||
|
||||
func GetHandler(db *DB) *http.ServeMux {
|
||||
servemux := http.NewServeMux()
|
||||
servemux.HandleFunc("/session/", DBHandlerFunc(SessionHandler, db))
|
||||
servemux.HandleFunc("/user/", DBHandlerFunc(UserHandler, db))
|
||||
servemux.HandleFunc("/security/", DBHandlerFunc(SecurityHandler, db))
|
||||
servemux.HandleFunc("/securitytemplate/", SecurityTemplateHandler)
|
||||
servemux.HandleFunc("/account/", DBHandlerFunc(AccountHandler, db))
|
||||
servemux.HandleFunc("/transaction/", DBHandlerFunc(TransactionHandler, db))
|
||||
servemux.HandleFunc("/import/gnucash", DBHandlerFunc(GnucashImportHandler, db))
|
||||
servemux.HandleFunc("/report/", DBHandlerFunc(ReportHandler, db))
|
||||
|
||||
return servemux
|
||||
}
|
409
internal/handlers/imports.go
Normal file
409
internal/handlers/imports.go
Normal file
@ -0,0 +1,409 @@
|
||||
package handlers
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"github.com/aclindsa/ofxgo"
|
||||
"io"
|
||||
"log"
|
||||
"math/big"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
type OFXDownload struct {
|
||||
OFXPassword string
|
||||
StartDate time.Time
|
||||
EndDate time.Time
|
||||
}
|
||||
|
||||
func (od *OFXDownload) Read(json_str string) error {
|
||||
dec := json.NewDecoder(strings.NewReader(json_str))
|
||||
return dec.Decode(od)
|
||||
}
|
||||
|
||||
func ofxImportHelper(db *DB, r io.Reader, w http.ResponseWriter, user *User, accountid int64) {
|
||||
itl, err := ImportOFX(r)
|
||||
|
||||
if err != nil {
|
||||
//TODO is this necessarily an invalid request (what if it was an error on our end)?
|
||||
WriteError(w, 3 /*Invalid Request*/)
|
||||
log.Print(err)
|
||||
return
|
||||
}
|
||||
|
||||
if len(itl.Accounts) != 1 {
|
||||
WriteError(w, 3 /*Invalid Request*/)
|
||||
log.Printf("Found %d accounts when importing OFX, expected 1", len(itl.Accounts))
|
||||
return
|
||||
}
|
||||
|
||||
sqltransaction, err := db.Begin()
|
||||
if err != nil {
|
||||
WriteError(w, 999 /*Internal Error*/)
|
||||
log.Print(err)
|
||||
return
|
||||
}
|
||||
|
||||
// Return Account with this Id
|
||||
account, err := GetAccountTx(sqltransaction, accountid, user.UserId)
|
||||
if err != nil {
|
||||
sqltransaction.Rollback()
|
||||
WriteError(w, 3 /*Invalid Request*/)
|
||||
log.Print(err)
|
||||
return
|
||||
}
|
||||
|
||||
importedAccount := itl.Accounts[0]
|
||||
|
||||
if len(account.ExternalAccountId) > 0 &&
|
||||
account.ExternalAccountId != importedAccount.ExternalAccountId {
|
||||
sqltransaction.Rollback()
|
||||
WriteError(w, 3 /*Invalid Request*/)
|
||||
log.Printf("OFX import has \"%s\" as ExternalAccountId, but the account being imported to has\"%s\"",
|
||||
importedAccount.ExternalAccountId,
|
||||
account.ExternalAccountId)
|
||||
return
|
||||
}
|
||||
|
||||
// Find matching existing securities or create new ones for those
|
||||
// referenced by the OFX import. Also create a map from placeholder import
|
||||
// SecurityIds to the actual SecurityIDs
|
||||
var securitymap = make(map[int64]Security)
|
||||
for _, ofxsecurity := range itl.Securities {
|
||||
// save off since ImportGetCreateSecurity overwrites SecurityId on
|
||||
// ofxsecurity
|
||||
oldsecurityid := ofxsecurity.SecurityId
|
||||
security, err := ImportGetCreateSecurity(sqltransaction, user.UserId, &ofxsecurity)
|
||||
if err != nil {
|
||||
sqltransaction.Rollback()
|
||||
WriteError(w, 999 /*Internal Error*/)
|
||||
log.Print(err)
|
||||
return
|
||||
}
|
||||
securitymap[oldsecurityid] = *security
|
||||
}
|
||||
|
||||
if account.SecurityId != securitymap[importedAccount.SecurityId].SecurityId {
|
||||
sqltransaction.Rollback()
|
||||
WriteError(w, 3 /*Invalid Request*/)
|
||||
log.Printf("OFX import account's SecurityId (%d) does not match this account's (%d)", securitymap[importedAccount.SecurityId].SecurityId, account.SecurityId)
|
||||
return
|
||||
}
|
||||
|
||||
// TODO Ensure all transactions have at least one split in the account
|
||||
// we're importing to?
|
||||
|
||||
var transactions []Transaction
|
||||
for _, transaction := range itl.Transactions {
|
||||
transaction.UserId = user.UserId
|
||||
|
||||
if !transaction.Valid() {
|
||||
sqltransaction.Rollback()
|
||||
WriteError(w, 999 /*Internal Error*/)
|
||||
log.Print("Unexpected invalid transaction from OFX import")
|
||||
return
|
||||
}
|
||||
|
||||
// Ensure that either AccountId or SecurityId is set for this split,
|
||||
// and fixup the SecurityId to be a valid one for this user's actual
|
||||
// securities instead of a placeholder from the import
|
||||
for _, split := range transaction.Splits {
|
||||
split.Status = Imported
|
||||
if split.AccountId != -1 {
|
||||
if split.AccountId != importedAccount.AccountId {
|
||||
sqltransaction.Rollback()
|
||||
WriteError(w, 999 /*Internal Error*/)
|
||||
log.Print("Imported split's AccountId wasn't -1 but also didn't match the account")
|
||||
return
|
||||
}
|
||||
split.AccountId = account.AccountId
|
||||
} else if split.SecurityId != -1 {
|
||||
if sec, ok := securitymap[split.SecurityId]; ok {
|
||||
// TODO try to auto-match splits to existing accounts based on past transactions that look like this one
|
||||
if split.ImportSplitType == TradingAccount {
|
||||
// Find/make trading account if we're that type of split
|
||||
trading_account, err := GetTradingAccount(sqltransaction, user.UserId, sec.SecurityId)
|
||||
if err != nil {
|
||||
sqltransaction.Rollback()
|
||||
WriteError(w, 999 /*Internal Error*/)
|
||||
log.Print("Couldn't find split's SecurityId in map during OFX import")
|
||||
return
|
||||
}
|
||||
split.AccountId = trading_account.AccountId
|
||||
split.SecurityId = -1
|
||||
} else if split.ImportSplitType == SubAccount {
|
||||
subaccount := &Account{
|
||||
UserId: user.UserId,
|
||||
Name: sec.Name,
|
||||
ParentAccountId: account.AccountId,
|
||||
SecurityId: sec.SecurityId,
|
||||
Type: account.Type,
|
||||
}
|
||||
subaccount, err := GetCreateAccountTx(sqltransaction, *subaccount)
|
||||
if err != nil {
|
||||
sqltransaction.Rollback()
|
||||
WriteError(w, 999 /*Internal Error*/)
|
||||
log.Print(err)
|
||||
return
|
||||
}
|
||||
split.AccountId = subaccount.AccountId
|
||||
split.SecurityId = -1
|
||||
} else {
|
||||
split.SecurityId = sec.SecurityId
|
||||
}
|
||||
} else {
|
||||
sqltransaction.Rollback()
|
||||
WriteError(w, 999 /*Internal Error*/)
|
||||
log.Print("Couldn't find split's SecurityId in map during OFX import")
|
||||
return
|
||||
}
|
||||
} else {
|
||||
sqltransaction.Rollback()
|
||||
WriteError(w, 999 /*Internal Error*/)
|
||||
log.Print("Neither Split.AccountId Split.SecurityId was set during OFX import")
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
imbalances, err := transaction.GetImbalancesTx(sqltransaction)
|
||||
if err != nil {
|
||||
sqltransaction.Rollback()
|
||||
WriteError(w, 999 /*Internal Error*/)
|
||||
log.Print(err)
|
||||
return
|
||||
}
|
||||
|
||||
// Fixup any imbalances in transactions
|
||||
var zero big.Rat
|
||||
for imbalanced_security, imbalance := range imbalances {
|
||||
if imbalance.Cmp(&zero) != 0 {
|
||||
imbalanced_account, err := GetImbalanceAccount(sqltransaction, user.UserId, imbalanced_security)
|
||||
if err != nil {
|
||||
sqltransaction.Rollback()
|
||||
WriteError(w, 999 /*Internal Error*/)
|
||||
log.Print(err)
|
||||
return
|
||||
}
|
||||
|
||||
// Add new split to fixup imbalance
|
||||
split := new(Split)
|
||||
r := new(big.Rat)
|
||||
r.Neg(&imbalance)
|
||||
security, err := GetSecurityTx(sqltransaction, imbalanced_security, user.UserId)
|
||||
if err != nil {
|
||||
sqltransaction.Rollback()
|
||||
WriteError(w, 999 /*Internal Error*/)
|
||||
log.Print(err)
|
||||
return
|
||||
}
|
||||
split.Amount = r.FloatString(security.Precision)
|
||||
split.SecurityId = -1
|
||||
split.AccountId = imbalanced_account.AccountId
|
||||
transaction.Splits = append(transaction.Splits, split)
|
||||
}
|
||||
}
|
||||
|
||||
// Move any splits with SecurityId but not AccountId to Imbalances
|
||||
// accounts. In the same loop, check to see if this transaction/split
|
||||
// has been imported before
|
||||
var already_imported bool
|
||||
for _, split := range transaction.Splits {
|
||||
if split.SecurityId != -1 || split.AccountId == -1 {
|
||||
imbalanced_account, err := GetImbalanceAccount(sqltransaction, user.UserId, split.SecurityId)
|
||||
if err != nil {
|
||||
sqltransaction.Rollback()
|
||||
WriteError(w, 999 /*Internal Error*/)
|
||||
log.Print(err)
|
||||
return
|
||||
}
|
||||
|
||||
split.AccountId = imbalanced_account.AccountId
|
||||
split.SecurityId = -1
|
||||
}
|
||||
|
||||
exists, err := split.AlreadyImportedTx(sqltransaction)
|
||||
if err != nil {
|
||||
sqltransaction.Rollback()
|
||||
WriteError(w, 999 /*Internal Error*/)
|
||||
log.Print("Error checking if split was already imported:", err)
|
||||
return
|
||||
} else if exists {
|
||||
already_imported = true
|
||||
}
|
||||
}
|
||||
|
||||
if !already_imported {
|
||||
transactions = append(transactions, transaction)
|
||||
}
|
||||
}
|
||||
|
||||
for _, transaction := range transactions {
|
||||
err := InsertTransactionTx(sqltransaction, &transaction, user)
|
||||
if err != nil {
|
||||
WriteError(w, 999 /*Internal Error*/)
|
||||
log.Print(err)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
err = sqltransaction.Commit()
|
||||
if err != nil {
|
||||
sqltransaction.Rollback()
|
||||
WriteError(w, 999 /*Internal Error*/)
|
||||
log.Print(err)
|
||||
return
|
||||
}
|
||||
|
||||
WriteSuccess(w)
|
||||
}
|
||||
|
||||
func OFXImportHandler(db *DB, w http.ResponseWriter, r *http.Request, user *User, accountid int64) {
|
||||
download_json := r.PostFormValue("ofxdownload")
|
||||
if download_json == "" {
|
||||
log.Print("download_json")
|
||||
WriteError(w, 3 /*Invalid Request*/)
|
||||
return
|
||||
}
|
||||
|
||||
var ofxdownload OFXDownload
|
||||
err := ofxdownload.Read(download_json)
|
||||
if err != nil {
|
||||
log.Print("ofxdownload.Read")
|
||||
WriteError(w, 3 /*Invalid Request*/)
|
||||
return
|
||||
}
|
||||
|
||||
account, err := GetAccount(db, accountid, user.UserId)
|
||||
if err != nil {
|
||||
log.Print("GetAccount")
|
||||
WriteError(w, 3 /*Invalid Request*/)
|
||||
return
|
||||
}
|
||||
|
||||
ofxver := ofxgo.OfxVersion203
|
||||
if len(account.OFXVersion) != 0 {
|
||||
ofxver, err = ofxgo.NewOfxVersion(account.OFXVersion)
|
||||
if err != nil {
|
||||
log.Print("NewOfxVersion")
|
||||
WriteError(w, 3 /*Invalid Request*/)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
var client = ofxgo.Client{
|
||||
AppID: account.OFXAppID,
|
||||
AppVer: account.OFXAppVer,
|
||||
SpecVersion: ofxver,
|
||||
NoIndent: account.OFXNoIndent,
|
||||
}
|
||||
|
||||
var query ofxgo.Request
|
||||
query.URL = account.OFXURL
|
||||
query.Signon.ClientUID = ofxgo.UID(account.OFXClientUID)
|
||||
query.Signon.UserID = ofxgo.String(account.OFXUser)
|
||||
query.Signon.UserPass = ofxgo.String(ofxdownload.OFXPassword)
|
||||
query.Signon.Org = ofxgo.String(account.OFXORG)
|
||||
query.Signon.Fid = ofxgo.String(account.OFXFID)
|
||||
|
||||
transactionuid, err := ofxgo.RandomUID()
|
||||
if err != nil {
|
||||
WriteError(w, 999 /*Internal Error*/)
|
||||
log.Println("Error creating uid for transaction:", err)
|
||||
return
|
||||
}
|
||||
|
||||
if account.Type == Investment {
|
||||
// Investment account
|
||||
statementRequest := ofxgo.InvStatementRequest{
|
||||
TrnUID: *transactionuid,
|
||||
InvAcctFrom: ofxgo.InvAcct{
|
||||
BrokerID: ofxgo.String(account.OFXBankID),
|
||||
AcctID: ofxgo.String(account.OFXAcctID),
|
||||
},
|
||||
Include: true,
|
||||
IncludeOO: true,
|
||||
IncludePos: true,
|
||||
IncludeBalance: true,
|
||||
Include401K: true,
|
||||
Include401KBal: true,
|
||||
}
|
||||
query.InvStmt = append(query.InvStmt, &statementRequest)
|
||||
} else if account.OFXAcctType == "CC" {
|
||||
// Import credit card transactions
|
||||
statementRequest := ofxgo.CCStatementRequest{
|
||||
TrnUID: *transactionuid,
|
||||
CCAcctFrom: ofxgo.CCAcct{
|
||||
AcctID: ofxgo.String(account.OFXAcctID),
|
||||
},
|
||||
Include: true,
|
||||
}
|
||||
query.CreditCard = append(query.CreditCard, &statementRequest)
|
||||
} else {
|
||||
// Import generic bank transactions
|
||||
acctTypeEnum, err := ofxgo.NewAcctType(account.OFXAcctType)
|
||||
if err != nil {
|
||||
WriteError(w, 3 /*Invalid Request*/)
|
||||
return
|
||||
}
|
||||
statementRequest := ofxgo.StatementRequest{
|
||||
TrnUID: *transactionuid,
|
||||
BankAcctFrom: ofxgo.BankAcct{
|
||||
BankID: ofxgo.String(account.OFXBankID),
|
||||
AcctID: ofxgo.String(account.OFXAcctID),
|
||||
AcctType: acctTypeEnum,
|
||||
},
|
||||
Include: true,
|
||||
}
|
||||
query.Bank = append(query.Bank, &statementRequest)
|
||||
}
|
||||
|
||||
response, err := client.RequestNoParse(&query)
|
||||
if err != nil {
|
||||
// TODO this could be an error talking with the OFX server...
|
||||
WriteError(w, 3 /*Invalid Request*/)
|
||||
log.Print(err)
|
||||
return
|
||||
}
|
||||
defer response.Body.Close()
|
||||
|
||||
ofxImportHelper(db, response.Body, w, user, accountid)
|
||||
}
|
||||
|
||||
func OFXFileImportHandler(db *DB, w http.ResponseWriter, r *http.Request, user *User, accountid int64) {
|
||||
multipartReader, err := r.MultipartReader()
|
||||
if err != nil {
|
||||
WriteError(w, 3 /*Invalid Request*/)
|
||||
return
|
||||
}
|
||||
|
||||
// assume there is only one 'part'
|
||||
part, err := multipartReader.NextPart()
|
||||
if err != nil {
|
||||
if err == io.EOF {
|
||||
WriteError(w, 3 /*Invalid Request*/)
|
||||
log.Print("Encountered unexpected EOF")
|
||||
} else {
|
||||
WriteError(w, 999 /*Internal Error*/)
|
||||
log.Print(err)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
ofxImportHelper(db, part, w, user, accountid)
|
||||
}
|
||||
|
||||
/*
|
||||
* Assumes the User is a valid, signed-in user, but accountid has not yet been validated
|
||||
*/
|
||||
func AccountImportHandler(db *DB, w http.ResponseWriter, r *http.Request, user *User, accountid int64, importtype string) {
|
||||
|
||||
switch importtype {
|
||||
case "ofx":
|
||||
OFXImportHandler(db, w, r, user, accountid)
|
||||
case "ofxfile":
|
||||
OFXFileImportHandler(db, w, r, user, accountid)
|
||||
default:
|
||||
WriteError(w, 3 /*Invalid Request*/)
|
||||
}
|
||||
}
|
1003
internal/handlers/ofx.go
Normal file
1003
internal/handlers/ofx.go
Normal file
File diff suppressed because it is too large
Load Diff
113
internal/handlers/prices.go
Normal file
113
internal/handlers/prices.go
Normal file
@ -0,0 +1,113 @@
|
||||
package handlers
|
||||
|
||||
import (
|
||||
"gopkg.in/gorp.v1"
|
||||
"time"
|
||||
)
|
||||
|
||||
type Price struct {
|
||||
PriceId int64
|
||||
SecurityId int64
|
||||
CurrencyId int64
|
||||
Date time.Time
|
||||
Value string // String representation of decimal price of Security in Currency units, suitable for passing to big.Rat.SetString()
|
||||
RemoteId string // unique ID from source, for detecting duplicates
|
||||
}
|
||||
|
||||
func InsertPriceTx(transaction *gorp.Transaction, p *Price) error {
|
||||
err := transaction.Insert(p)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func CreatePriceIfNotExist(transaction *gorp.Transaction, price *Price) error {
|
||||
if len(price.RemoteId) == 0 {
|
||||
// Always create a new price if we can't match on the RemoteId
|
||||
err := InsertPriceTx(transaction, price)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
var prices []*Price
|
||||
|
||||
_, err := transaction.Select(&prices, "SELECT * from prices where SecurityId=? AND CurrencyId=? AND Date=? AND Value=?", price.SecurityId, price.CurrencyId, price.Date, price.Value)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if len(prices) > 0 {
|
||||
return nil // price already exists
|
||||
}
|
||||
|
||||
err = InsertPriceTx(transaction, price)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Return the latest price for security in currency units before date
|
||||
func GetLatestPrice(transaction *gorp.Transaction, security, currency *Security, date *time.Time) (*Price, error) {
|
||||
var p Price
|
||||
err := transaction.SelectOne(&p, "SELECT * from prices where SecurityId=? AND CurrencyId=? AND Date <= ? ORDER BY Date DESC LIMIT 1", security.SecurityId, currency.SecurityId, date)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &p, nil
|
||||
}
|
||||
|
||||
// Return the earliest price for security in currency units after date
|
||||
func GetEarliestPrice(transaction *gorp.Transaction, security, currency *Security, date *time.Time) (*Price, error) {
|
||||
var p Price
|
||||
err := transaction.SelectOne(&p, "SELECT * from prices where SecurityId=? AND CurrencyId=? AND Date >= ? ORDER BY Date ASC LIMIT 1", security.SecurityId, currency.SecurityId, date)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &p, nil
|
||||
}
|
||||
|
||||
// Return the price for security in currency closest to date
|
||||
func GetClosestPriceTx(transaction *gorp.Transaction, security, currency *Security, date *time.Time) (*Price, error) {
|
||||
earliest, _ := GetEarliestPrice(transaction, security, currency, date)
|
||||
latest, err := GetLatestPrice(transaction, security, currency, date)
|
||||
|
||||
// Return early if either earliest or latest are invalid
|
||||
if earliest == nil {
|
||||
return latest, err
|
||||
} else if err != nil {
|
||||
return earliest, nil
|
||||
}
|
||||
|
||||
howlate := earliest.Date.Sub(*date)
|
||||
howearly := date.Sub(latest.Date)
|
||||
if howearly < howlate {
|
||||
return latest, nil
|
||||
} else {
|
||||
return earliest, nil
|
||||
}
|
||||
}
|
||||
|
||||
func GetClosestPrice(db *DB, security, currency *Security, date *time.Time) (*Price, error) {
|
||||
transaction, err := db.Begin()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
price, err := GetClosestPriceTx(transaction, security, currency, date)
|
||||
if err != nil {
|
||||
transaction.Rollback()
|
||||
return nil, err
|
||||
}
|
||||
|
||||
err = transaction.Commit()
|
||||
if err != nil {
|
||||
transaction.Rollback()
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return price, nil
|
||||
}
|
91
internal/handlers/prices_lua.go
Normal file
91
internal/handlers/prices_lua.go
Normal file
@ -0,0 +1,91 @@
|
||||
package handlers
|
||||
|
||||
import (
|
||||
"github.com/yuin/gopher-lua"
|
||||
)
|
||||
|
||||
const luaPriceTypeName = "price"
|
||||
|
||||
func luaRegisterPrices(L *lua.LState) {
|
||||
mt := L.NewTypeMetatable(luaPriceTypeName)
|
||||
L.SetGlobal("price", mt)
|
||||
L.SetField(mt, "__index", L.NewFunction(luaPrice__index))
|
||||
L.SetField(mt, "__tostring", L.NewFunction(luaPrice__tostring))
|
||||
L.SetField(mt, "__metatable", lua.LString("protected"))
|
||||
}
|
||||
|
||||
func PriceToLua(L *lua.LState, price *Price) *lua.LUserData {
|
||||
ud := L.NewUserData()
|
||||
ud.Value = price
|
||||
L.SetMetatable(ud, L.GetTypeMetatable(luaPriceTypeName))
|
||||
return ud
|
||||
}
|
||||
|
||||
// Checks whether the first lua argument is a *LUserData with *Price and returns this *Price.
|
||||
func luaCheckPrice(L *lua.LState, n int) *Price {
|
||||
ud := L.CheckUserData(n)
|
||||
if price, ok := ud.Value.(*Price); ok {
|
||||
return price
|
||||
}
|
||||
L.ArgError(n, "price expected")
|
||||
return nil
|
||||
}
|
||||
|
||||
func luaPrice__index(L *lua.LState) int {
|
||||
p := luaCheckPrice(L, 1)
|
||||
field := L.CheckString(2)
|
||||
|
||||
switch field {
|
||||
case "PriceId", "priceid":
|
||||
L.Push(lua.LNumber(float64(p.PriceId)))
|
||||
case "Security", "security":
|
||||
security_map, err := luaContextGetSecurities(L)
|
||||
if err != nil {
|
||||
panic("luaContextGetSecurities couldn't fetch securities")
|
||||
}
|
||||
s, ok := security_map[p.SecurityId]
|
||||
if !ok {
|
||||
panic("Price's security not found for user")
|
||||
}
|
||||
L.Push(SecurityToLua(L, s))
|
||||
case "Currency", "currency":
|
||||
security_map, err := luaContextGetSecurities(L)
|
||||
if err != nil {
|
||||
panic("luaContextGetSecurities couldn't fetch securities")
|
||||
}
|
||||
c, ok := security_map[p.CurrencyId]
|
||||
if !ok {
|
||||
panic("Price's currency not found for user")
|
||||
}
|
||||
L.Push(SecurityToLua(L, c))
|
||||
case "Value", "value":
|
||||
amt, err := GetBigAmount(p.Value)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
float, _ := amt.Float64()
|
||||
L.Push(lua.LNumber(float))
|
||||
default:
|
||||
L.ArgError(2, "unexpected price attribute: "+field)
|
||||
}
|
||||
|
||||
return 1
|
||||
}
|
||||
|
||||
func luaPrice__tostring(L *lua.LState) int {
|
||||
p := luaCheckPrice(L, 1)
|
||||
|
||||
security_map, err := luaContextGetSecurities(L)
|
||||
if err != nil {
|
||||
panic("luaContextGetSecurities couldn't fetch securities")
|
||||
}
|
||||
s, ok1 := security_map[p.SecurityId]
|
||||
c, ok2 := security_map[p.CurrencyId]
|
||||
if !ok1 || !ok2 {
|
||||
panic("Price's currency or security not found for user")
|
||||
}
|
||||
|
||||
L.Push(lua.LString(p.Value + " " + c.Symbol + " (" + s.Symbol + ")"))
|
||||
|
||||
return 1
|
||||
}
|
354
internal/handlers/reports.go
Normal file
354
internal/handlers/reports.go
Normal file
@ -0,0 +1,354 @@
|
||||
package handlers
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"github.com/yuin/gopher-lua"
|
||||
"log"
|
||||
"net/http"
|
||||
"regexp"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
var reportTabulationRE *regexp.Regexp
|
||||
|
||||
func init() {
|
||||
reportTabulationRE = regexp.MustCompile(`^/report/[0-9]+/tabulation/?$`)
|
||||
}
|
||||
|
||||
//type and value to store user in lua's Context
|
||||
type key int
|
||||
|
||||
const (
|
||||
userContextKey key = iota
|
||||
accountsContextKey
|
||||
securitiesContextKey
|
||||
balanceContextKey
|
||||
dbContextKey
|
||||
)
|
||||
|
||||
const luaTimeoutSeconds time.Duration = 30 // maximum time a lua request can run for
|
||||
|
||||
type Report struct {
|
||||
ReportId int64
|
||||
UserId int64
|
||||
Name string
|
||||
Lua string
|
||||
}
|
||||
|
||||
func (r *Report) Write(w http.ResponseWriter) error {
|
||||
enc := json.NewEncoder(w)
|
||||
return enc.Encode(r)
|
||||
}
|
||||
|
||||
func (r *Report) Read(json_str string) error {
|
||||
dec := json.NewDecoder(strings.NewReader(json_str))
|
||||
return dec.Decode(r)
|
||||
}
|
||||
|
||||
type ReportList struct {
|
||||
Reports *[]Report `json:"reports"`
|
||||
}
|
||||
|
||||
func (rl *ReportList) Write(w http.ResponseWriter) error {
|
||||
enc := json.NewEncoder(w)
|
||||
return enc.Encode(rl)
|
||||
}
|
||||
|
||||
type Series struct {
|
||||
Values []float64
|
||||
Series map[string]*Series
|
||||
}
|
||||
|
||||
type Tabulation struct {
|
||||
ReportId int64
|
||||
Title string
|
||||
Subtitle string
|
||||
Units string
|
||||
Labels []string
|
||||
Series map[string]*Series
|
||||
}
|
||||
|
||||
func (r *Tabulation) Write(w http.ResponseWriter) error {
|
||||
enc := json.NewEncoder(w)
|
||||
return enc.Encode(r)
|
||||
}
|
||||
|
||||
func GetReport(db *DB, reportid int64, userid int64) (*Report, error) {
|
||||
var r Report
|
||||
|
||||
err := db.SelectOne(&r, "SELECT * from reports where UserId=? AND ReportId=?", userid, reportid)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &r, nil
|
||||
}
|
||||
|
||||
func GetReports(db *DB, userid int64) (*[]Report, error) {
|
||||
var reports []Report
|
||||
|
||||
_, err := db.Select(&reports, "SELECT * from reports where UserId=?", userid)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &reports, nil
|
||||
}
|
||||
|
||||
func InsertReport(db *DB, r *Report) error {
|
||||
err := db.Insert(r)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func UpdateReport(db *DB, r *Report) error {
|
||||
count, err := db.Update(r)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if count != 1 {
|
||||
return errors.New("Updated more than one report")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func DeleteReport(db *DB, r *Report) error {
|
||||
count, err := db.Delete(r)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if count != 1 {
|
||||
return errors.New("Deleted more than one report")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func runReport(db *DB, user *User, report *Report) (*Tabulation, error) {
|
||||
// Create a new LState without opening the default libs for security
|
||||
L := lua.NewState(lua.Options{SkipOpenLibs: true})
|
||||
defer L.Close()
|
||||
|
||||
// Create a new context holding the current user with a timeout
|
||||
ctx := context.WithValue(context.Background(), userContextKey, user)
|
||||
ctx = context.WithValue(ctx, dbContextKey, db)
|
||||
ctx, cancel := context.WithTimeout(ctx, luaTimeoutSeconds*time.Second)
|
||||
defer cancel()
|
||||
L.SetContext(ctx)
|
||||
|
||||
for _, pair := range []struct {
|
||||
n string
|
||||
f lua.LGFunction
|
||||
}{
|
||||
{lua.LoadLibName, lua.OpenPackage}, // Must be first
|
||||
{lua.BaseLibName, lua.OpenBase},
|
||||
{lua.TabLibName, lua.OpenTable},
|
||||
{lua.StringLibName, lua.OpenString},
|
||||
{lua.MathLibName, lua.OpenMath},
|
||||
} {
|
||||
if err := L.CallByParam(lua.P{
|
||||
Fn: L.NewFunction(pair.f),
|
||||
NRet: 0,
|
||||
Protect: true,
|
||||
}, lua.LString(pair.n)); err != nil {
|
||||
return nil, errors.New("Error initializing Lua packages")
|
||||
}
|
||||
}
|
||||
|
||||
luaRegisterAccounts(L)
|
||||
luaRegisterSecurities(L)
|
||||
luaRegisterBalances(L)
|
||||
luaRegisterDates(L)
|
||||
luaRegisterTabulations(L)
|
||||
luaRegisterPrices(L)
|
||||
|
||||
err := L.DoString(report.Lua)
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err := L.CallByParam(lua.P{
|
||||
Fn: L.GetGlobal("generate"),
|
||||
NRet: 1,
|
||||
Protect: true,
|
||||
}); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
value := L.Get(-1)
|
||||
if ud, ok := value.(*lua.LUserData); ok {
|
||||
if tabulation, ok := ud.Value.(*Tabulation); ok {
|
||||
return tabulation, nil
|
||||
} else {
|
||||
return nil, fmt.Errorf("generate() for %s (Id: %d) didn't return a tabulation", report.Name, report.ReportId)
|
||||
}
|
||||
} else {
|
||||
return nil, fmt.Errorf("generate() for %s (Id: %d) didn't even return LUserData", report.Name, report.ReportId)
|
||||
}
|
||||
}
|
||||
|
||||
func ReportTabulationHandler(db *DB, w http.ResponseWriter, r *http.Request, user *User, reportid int64) {
|
||||
report, err := GetReport(db, reportid, user.UserId)
|
||||
if err != nil {
|
||||
WriteError(w, 3 /*Invalid Request*/)
|
||||
return
|
||||
}
|
||||
|
||||
tabulation, err := runReport(db, user, report)
|
||||
if err != nil {
|
||||
// TODO handle different failure cases differently
|
||||
log.Print("runReport returned:", err)
|
||||
WriteError(w, 3 /*Invalid Request*/)
|
||||
return
|
||||
}
|
||||
|
||||
tabulation.ReportId = reportid
|
||||
|
||||
err = tabulation.Write(w)
|
||||
if err != nil {
|
||||
WriteError(w, 999 /*Internal Error*/)
|
||||
log.Print(err)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
func ReportHandler(w http.ResponseWriter, r *http.Request, db *DB) {
|
||||
user, err := GetUserFromSession(db, r)
|
||||
if err != nil {
|
||||
WriteError(w, 1 /*Not Signed In*/)
|
||||
return
|
||||
}
|
||||
|
||||
if r.Method == "POST" {
|
||||
report_json := r.PostFormValue("report")
|
||||
if report_json == "" {
|
||||
WriteError(w, 3 /*Invalid Request*/)
|
||||
return
|
||||
}
|
||||
|
||||
var report Report
|
||||
err := report.Read(report_json)
|
||||
if err != nil {
|
||||
WriteError(w, 3 /*Invalid Request*/)
|
||||
return
|
||||
}
|
||||
report.ReportId = -1
|
||||
report.UserId = user.UserId
|
||||
|
||||
err = InsertReport(db, &report)
|
||||
if err != nil {
|
||||
WriteError(w, 999 /*Internal Error*/)
|
||||
log.Print(err)
|
||||
return
|
||||
}
|
||||
|
||||
w.WriteHeader(201 /*Created*/)
|
||||
err = report.Write(w)
|
||||
if err != nil {
|
||||
WriteError(w, 999 /*Internal Error*/)
|
||||
log.Print(err)
|
||||
return
|
||||
}
|
||||
} else if r.Method == "GET" {
|
||||
if reportTabulationRE.MatchString(r.URL.Path) {
|
||||
var reportid int64
|
||||
n, err := GetURLPieces(r.URL.Path, "/report/%d/tabulation", &reportid)
|
||||
if err != nil || n != 1 {
|
||||
WriteError(w, 999 /*InternalError*/)
|
||||
log.Print(err)
|
||||
return
|
||||
}
|
||||
ReportTabulationHandler(db, w, r, user, reportid)
|
||||
return
|
||||
}
|
||||
|
||||
var reportid int64
|
||||
n, err := GetURLPieces(r.URL.Path, "/report/%d", &reportid)
|
||||
if err != nil || n != 1 {
|
||||
//Return all Reports
|
||||
var rl ReportList
|
||||
reports, err := GetReports(db, user.UserId)
|
||||
if err != nil {
|
||||
WriteError(w, 999 /*Internal Error*/)
|
||||
log.Print(err)
|
||||
return
|
||||
}
|
||||
rl.Reports = reports
|
||||
err = (&rl).Write(w)
|
||||
if err != nil {
|
||||
WriteError(w, 999 /*Internal Error*/)
|
||||
log.Print(err)
|
||||
return
|
||||
}
|
||||
} else {
|
||||
// Return Report with this Id
|
||||
report, err := GetReport(db, reportid, user.UserId)
|
||||
if err != nil {
|
||||
WriteError(w, 3 /*Invalid Request*/)
|
||||
return
|
||||
}
|
||||
|
||||
err = report.Write(w)
|
||||
if err != nil {
|
||||
WriteError(w, 999 /*Internal Error*/)
|
||||
log.Print(err)
|
||||
return
|
||||
}
|
||||
}
|
||||
} else {
|
||||
reportid, err := GetURLID(r.URL.Path)
|
||||
if err != nil {
|
||||
WriteError(w, 3 /*Invalid Request*/)
|
||||
return
|
||||
}
|
||||
|
||||
if r.Method == "PUT" {
|
||||
report_json := r.PostFormValue("report")
|
||||
if report_json == "" {
|
||||
WriteError(w, 3 /*Invalid Request*/)
|
||||
return
|
||||
}
|
||||
|
||||
var report Report
|
||||
err := report.Read(report_json)
|
||||
if err != nil || report.ReportId != reportid {
|
||||
WriteError(w, 3 /*Invalid Request*/)
|
||||
return
|
||||
}
|
||||
report.UserId = user.UserId
|
||||
|
||||
err = UpdateReport(db, &report)
|
||||
if err != nil {
|
||||
WriteError(w, 999 /*Internal Error*/)
|
||||
log.Print(err)
|
||||
return
|
||||
}
|
||||
|
||||
err = report.Write(w)
|
||||
if err != nil {
|
||||
WriteError(w, 999 /*Internal Error*/)
|
||||
log.Print(err)
|
||||
return
|
||||
}
|
||||
} else if r.Method == "DELETE" {
|
||||
report, err := GetReport(db, reportid, user.UserId)
|
||||
if err != nil {
|
||||
WriteError(w, 3 /*Invalid Request*/)
|
||||
return
|
||||
}
|
||||
|
||||
err = DeleteReport(db, report)
|
||||
if err != nil {
|
||||
WriteError(w, 999 /*Internal Error*/)
|
||||
log.Print(err)
|
||||
return
|
||||
}
|
||||
|
||||
WriteSuccess(w)
|
||||
}
|
||||
}
|
||||
}
|
42
internal/handlers/reports/asset_allocation.lua
Normal file
42
internal/handlers/reports/asset_allocation.lua
Normal file
@ -0,0 +1,42 @@
|
||||
function generate()
|
||||
accounts = get_accounts()
|
||||
securities = get_securities()
|
||||
default_currency = get_default_currency()
|
||||
series_map = {}
|
||||
totals_map = {}
|
||||
|
||||
t = tabulation.new(1)
|
||||
t:title("Current Asset Allocation")
|
||||
|
||||
t:label(1, "Assets")
|
||||
|
||||
for id, security in pairs(securities) do
|
||||
totals_map[id] = 0
|
||||
series_map[id] = t:series(tostring(security))
|
||||
end
|
||||
|
||||
for id, acct in pairs(accounts) do
|
||||
if acct.type == account.Asset or acct.type == account.Investment or acct.type == account.Bank or acct.type == account.Cash then
|
||||
balance = acct:balance()
|
||||
multiplier = 1
|
||||
if acct.security ~= default_currency and balance.amount ~= 0 then
|
||||
price = acct.security:closestprice(default_currency, date.now())
|
||||
if price == nil then
|
||||
--[[
|
||||
-- This should contain code to warn the user that their report is missing some information
|
||||
--]]
|
||||
multiplier = 0
|
||||
else
|
||||
multiplier = price.value
|
||||
end
|
||||
end
|
||||
totals_map[acct.security.SecurityId] = balance.amount * multiplier + totals_map[acct.security.SecurityId]
|
||||
end
|
||||
end
|
||||
|
||||
for id, series in pairs(series_map) do
|
||||
series:value(1, totals_map[id])
|
||||
end
|
||||
|
||||
return t
|
||||
end
|
26
internal/handlers/reports/monthly_cash_flow.lua
Normal file
26
internal/handlers/reports/monthly_cash_flow.lua
Normal file
@ -0,0 +1,26 @@
|
||||
function generate()
|
||||
year = date.now().year
|
||||
|
||||
accounts = get_accounts()
|
||||
t = tabulation.new(12)
|
||||
t:title(year .. " Monthly Cash Flow")
|
||||
series = t:series("Income minus expenses")
|
||||
|
||||
for month=1,12 do
|
||||
begin_date = date.new(year, month, 1)
|
||||
end_date = date.new(year, month+1, 1)
|
||||
|
||||
t:label(month, tostring(begin_date))
|
||||
cash_flow = 0
|
||||
|
||||
for id, acct in pairs(accounts) do
|
||||
if acct.type == account.Expense or acct.type == account.Income then
|
||||
balance = acct:balance(begin_date, end_date)
|
||||
cash_flow = cash_flow - balance.amount
|
||||
end
|
||||
end
|
||||
series:value(month, cash_flow)
|
||||
end
|
||||
|
||||
return t
|
||||
end
|
49
internal/handlers/reports/monthly_expenses.lua
Normal file
49
internal/handlers/reports/monthly_expenses.lua
Normal file
@ -0,0 +1,49 @@
|
||||
function account_series_map(accounts, tabulation)
|
||||
map = {}
|
||||
|
||||
for i=1,100 do -- we're not messing with accounts more than 100 levels deep
|
||||
all_handled = true
|
||||
for id, acct in pairs(accounts) do
|
||||
if not map[id] then
|
||||
all_handled = false
|
||||
if not acct.parent then
|
||||
map[id] = tabulation:series(acct.name)
|
||||
elseif map[acct.parent.accountid] then
|
||||
map[id] = map[acct.parent.accountid]:series(acct.name)
|
||||
end
|
||||
end
|
||||
end
|
||||
if all_handled then
|
||||
return map
|
||||
end
|
||||
end
|
||||
|
||||
error("Accounts nested (at least) 100 levels deep")
|
||||
end
|
||||
|
||||
function generate()
|
||||
year = date.now().year
|
||||
account_type = account.Expense
|
||||
|
||||
accounts = get_accounts()
|
||||
t = tabulation.new(12)
|
||||
t:title(year .. " Monthly Expenses")
|
||||
series_map = account_series_map(accounts, t)
|
||||
|
||||
for month=1,12 do
|
||||
begin_date = date.new(year, month, 1)
|
||||
end_date = date.new(year, month+1, 1)
|
||||
|
||||
t:label(month, tostring(begin_date))
|
||||
|
||||
for id, acct in pairs(accounts) do
|
||||
series = series_map[id]
|
||||
if acct.type == account_type then
|
||||
balance = acct:balance(begin_date, end_date)
|
||||
series:value(month, balance.amount)
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
return t
|
||||
end
|
60
internal/handlers/reports/monthly_net_worth.lua
Normal file
60
internal/handlers/reports/monthly_net_worth.lua
Normal file
@ -0,0 +1,60 @@
|
||||
function account_series_map(accounts, tabulation)
|
||||
map = {}
|
||||
|
||||
for i=1,100 do -- we're not messing with accounts more than 100 levels deep
|
||||
all_handled = true
|
||||
for id, acct in pairs(accounts) do
|
||||
if not map[id] then
|
||||
all_handled = false
|
||||
if not acct.parent then
|
||||
map[id] = tabulation:series(acct.name)
|
||||
elseif map[acct.parent.accountid] then
|
||||
map[id] = map[acct.parent.accountid]:series(acct.name)
|
||||
end
|
||||
end
|
||||
end
|
||||
if all_handled then
|
||||
return map
|
||||
end
|
||||
end
|
||||
|
||||
error("Accounts nested (at least) 100 levels deep")
|
||||
end
|
||||
|
||||
function generate()
|
||||
year = date.now().year
|
||||
|
||||
accounts = get_accounts()
|
||||
t = tabulation.new(12)
|
||||
t:title(year .. " Monthly Net Worth")
|
||||
series_map = account_series_map(accounts, t)
|
||||
default_currency = get_default_currency()
|
||||
|
||||
for month=1,12 do
|
||||
end_date = date.new(year, month+1, 1)
|
||||
|
||||
t:label(month, tostring(end_date))
|
||||
|
||||
for id, acct in pairs(accounts) do
|
||||
series = series_map[id]
|
||||
if acct.type ~= account.Expense and acct.type ~= account.Income and acct.type ~= account.Trading then
|
||||
balance = acct:balance(end_date)
|
||||
multiplier = 1
|
||||
if acct.security ~= default_currency and balance.amount ~= 0 then
|
||||
price = acct.security:closestprice(default_currency, end_date)
|
||||
if price == nil then
|
||||
--[[
|
||||
-- This should contain code to warn the user that their report is missing some information
|
||||
--]]
|
||||
multiplier = 0
|
||||
else
|
||||
multiplier = price.value
|
||||
end
|
||||
end
|
||||
series:value(month, balance.amount * multiplier)
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
return t
|
||||
end
|
61
internal/handlers/reports/monthly_net_worth_change.lua
Normal file
61
internal/handlers/reports/monthly_net_worth_change.lua
Normal file
@ -0,0 +1,61 @@
|
||||
function account_series_map(accounts, tabulation)
|
||||
map = {}
|
||||
|
||||
for i=1,100 do -- we're not messing with accounts more than 100 levels deep
|
||||
all_handled = true
|
||||
for id, acct in pairs(accounts) do
|
||||
if not map[id] then
|
||||
all_handled = false
|
||||
if not acct.parent then
|
||||
map[id] = tabulation:series(acct.name)
|
||||
elseif map[acct.parent.accountid] then
|
||||
map[id] = map[acct.parent.accountid]:series(acct.name)
|
||||
end
|
||||
end
|
||||
end
|
||||
if all_handled then
|
||||
return map
|
||||
end
|
||||
end
|
||||
|
||||
error("Accounts nested (at least) 100 levels deep")
|
||||
end
|
||||
|
||||
function generate()
|
||||
year = date.now().year
|
||||
|
||||
accounts = get_accounts()
|
||||
t = tabulation.new(12)
|
||||
t:title(year .. " Monthly Net Worth")
|
||||
series_map = account_series_map(accounts, t)
|
||||
default_currency = get_default_currency()
|
||||
|
||||
for month=1,12 do
|
||||
begin_date = date.new(year, month, 1)
|
||||
end_date = date.new(year, month+1, 1)
|
||||
|
||||
t:label(month, tostring(begin_date))
|
||||
|
||||
for id, acct in pairs(accounts) do
|
||||
series = series_map[id]
|
||||
if acct.type ~= account.Expense and acct.type ~= account.Income and acct.type ~= account.Trading then
|
||||
balance = acct:balance(begin_date, end_date)
|
||||
multiplier = 1
|
||||
if acct.security ~= default_currency then
|
||||
price = acct.security:closestprice(default_currency, end_date)
|
||||
if price == nil then
|
||||
--[[
|
||||
-- This should contain code to warn the user that their report is missing some information
|
||||
--]]
|
||||
multiplier = 0
|
||||
else
|
||||
multiplier = price.value
|
||||
end
|
||||
end
|
||||
series:value(month, balance.amount * multiplier)
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
return t
|
||||
end
|
60
internal/handlers/reports/quarterly_net_worth.lua
Normal file
60
internal/handlers/reports/quarterly_net_worth.lua
Normal file
@ -0,0 +1,60 @@
|
||||
function account_series_map(accounts, tabulation)
|
||||
map = {}
|
||||
|
||||
for i=1,100 do -- we're not messing with accounts more than 100 levels deep
|
||||
all_handled = true
|
||||
for id, acct in pairs(accounts) do
|
||||
if not map[id] then
|
||||
all_handled = false
|
||||
if not acct.parent then
|
||||
map[id] = tabulation:series(acct.name)
|
||||
elseif map[acct.parent.accountid] then
|
||||
map[id] = map[acct.parent.accountid]:series(acct.name)
|
||||
end
|
||||
end
|
||||
end
|
||||
if all_handled then
|
||||
return map
|
||||
end
|
||||
end
|
||||
|
||||
error("Accounts nested (at least) 100 levels deep")
|
||||
end
|
||||
|
||||
function generate()
|
||||
year = date.now().year-4
|
||||
|
||||
accounts = get_accounts()
|
||||
t = tabulation.new(20)
|
||||
t:title(year .. "-" .. date.now().year .. " Quarterly Net Worth")
|
||||
series_map = account_series_map(accounts, t:series("Net Worth"))
|
||||
default_currency = get_default_currency()
|
||||
|
||||
for month=1,20 do
|
||||
end_date = date.new(year, month*3-2, 1)
|
||||
|
||||
t:label(month, tostring(end_date))
|
||||
|
||||
for id, acct in pairs(accounts) do
|
||||
series = series_map[id]
|
||||
if acct.type ~= account.Expense and acct.type ~= account.Income and acct.type ~= account.Trading then
|
||||
balance = acct:balance(end_date)
|
||||
multiplier = 1
|
||||
if acct.security ~= default_currency then
|
||||
price = acct.security:closestprice(default_currency, end_date)
|
||||
if price == nil then
|
||||
--[[
|
||||
-- This should contain code to warn the user that their report is missing some information
|
||||
--]]
|
||||
multiplier = 0
|
||||
else
|
||||
multiplier = price.value
|
||||
end
|
||||
end
|
||||
series:value(month, balance.amount * multiplier)
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
return t
|
||||
end
|
47
internal/handlers/reports/years_income.lua
Normal file
47
internal/handlers/reports/years_income.lua
Normal file
@ -0,0 +1,47 @@
|
||||
function account_series_map(accounts, tabulation)
|
||||
map = {}
|
||||
|
||||
for i=1,100 do -- we're not messing with accounts more than 100 levels deep
|
||||
all_handled = true
|
||||
for id, acct in pairs(accounts) do
|
||||
if not map[id] then
|
||||
all_handled = false
|
||||
if not acct.parent then
|
||||
map[id] = tabulation:series(acct.name)
|
||||
elseif map[acct.parent.accountid] then
|
||||
map[id] = map[acct.parent.accountid]:series(acct.name)
|
||||
end
|
||||
end
|
||||
end
|
||||
if all_handled then
|
||||
return map
|
||||
end
|
||||
end
|
||||
|
||||
error("Accounts nested (at least) 100 levels deep")
|
||||
end
|
||||
|
||||
function generate()
|
||||
year = date.now().year
|
||||
account_type = account.Income
|
||||
|
||||
accounts = get_accounts()
|
||||
t = tabulation.new(1)
|
||||
t:title(year .. " Income")
|
||||
series_map = account_series_map(accounts, t)
|
||||
|
||||
begin_date = date.new(year, 1, 1)
|
||||
end_date = date.new(year+1, 1, 1)
|
||||
|
||||
t:label(1, year .. " Income")
|
||||
|
||||
for id, acct in pairs(accounts) do
|
||||
series = series_map[id]
|
||||
if acct.type == account_type then
|
||||
balance = acct:balance(begin_date, end_date)
|
||||
series:value(1, balance.amount)
|
||||
end
|
||||
end
|
||||
|
||||
return t
|
||||
end
|
187
internal/handlers/reports_lua.go
Normal file
187
internal/handlers/reports_lua.go
Normal file
@ -0,0 +1,187 @@
|
||||
package handlers
|
||||
|
||||
import (
|
||||
"github.com/yuin/gopher-lua"
|
||||
)
|
||||
|
||||
const luaTabulationTypeName = "tabulation"
|
||||
const luaSeriesTypeName = "series"
|
||||
|
||||
func luaRegisterTabulations(L *lua.LState) {
|
||||
mtr := L.NewTypeMetatable(luaTabulationTypeName)
|
||||
L.SetGlobal("tabulation", mtr)
|
||||
L.SetField(mtr, "new", L.NewFunction(luaTabulationNew))
|
||||
L.SetField(mtr, "__index", L.NewFunction(luaTabulation__index))
|
||||
L.SetField(mtr, "__metatable", lua.LString("protected"))
|
||||
|
||||
mts := L.NewTypeMetatable(luaSeriesTypeName)
|
||||
L.SetGlobal("series", mts)
|
||||
L.SetField(mts, "__index", L.NewFunction(luaSeries__index))
|
||||
L.SetField(mts, "__metatable", lua.LString("protected"))
|
||||
}
|
||||
|
||||
// Checks whether the first lua argument is a *LUserData with *Tabulation and returns *Tabulation
|
||||
func luaCheckTabulation(L *lua.LState, n int) *Tabulation {
|
||||
ud := L.CheckUserData(n)
|
||||
if tabulation, ok := ud.Value.(*Tabulation); ok {
|
||||
return tabulation
|
||||
}
|
||||
L.ArgError(n, "tabulation expected")
|
||||
return nil
|
||||
}
|
||||
|
||||
// Checks whether the first lua argument is a *LUserData with *Series and returns *Series
|
||||
func luaCheckSeries(L *lua.LState, n int) *Series {
|
||||
ud := L.CheckUserData(n)
|
||||
if series, ok := ud.Value.(*Series); ok {
|
||||
return series
|
||||
}
|
||||
L.ArgError(n, "series expected")
|
||||
return nil
|
||||
}
|
||||
|
||||
func luaTabulationNew(L *lua.LState) int {
|
||||
numvalues := L.CheckInt(1)
|
||||
ud := L.NewUserData()
|
||||
ud.Value = &Tabulation{
|
||||
Labels: make([]string, numvalues),
|
||||
Series: make(map[string]*Series),
|
||||
}
|
||||
L.SetMetatable(ud, L.GetTypeMetatable(luaTabulationTypeName))
|
||||
L.Push(ud)
|
||||
return 1
|
||||
}
|
||||
|
||||
func luaTabulation__index(L *lua.LState) int {
|
||||
field := L.CheckString(2)
|
||||
|
||||
switch field {
|
||||
case "Label", "label":
|
||||
L.Push(L.NewFunction(luaTabulationLabel))
|
||||
case "Series", "series":
|
||||
L.Push(L.NewFunction(luaTabulationSeries))
|
||||
case "Title", "title":
|
||||
L.Push(L.NewFunction(luaTabulationTitle))
|
||||
case "Subtitle", "subtitle":
|
||||
L.Push(L.NewFunction(luaTabulationSubtitle))
|
||||
case "Units", "units":
|
||||
L.Push(L.NewFunction(luaTabulationUnits))
|
||||
default:
|
||||
L.ArgError(2, "unexpected tabulation attribute: "+field)
|
||||
}
|
||||
|
||||
return 1
|
||||
}
|
||||
|
||||
func luaTabulationLabel(L *lua.LState) int {
|
||||
tabulation := luaCheckTabulation(L, 1)
|
||||
labelnumber := L.CheckInt(2)
|
||||
label := L.CheckString(3)
|
||||
|
||||
if labelnumber > cap(tabulation.Labels) || labelnumber < 1 {
|
||||
L.ArgError(2, "Label index must be between 1 and the number of data points, inclusive")
|
||||
}
|
||||
tabulation.Labels[labelnumber-1] = label
|
||||
return 0
|
||||
}
|
||||
|
||||
func luaTabulationSeries(L *lua.LState) int {
|
||||
tabulation := luaCheckTabulation(L, 1)
|
||||
name := L.CheckString(2)
|
||||
ud := L.NewUserData()
|
||||
|
||||
s, ok := tabulation.Series[name]
|
||||
if ok {
|
||||
ud.Value = s
|
||||
} else {
|
||||
tabulation.Series[name] = &Series{
|
||||
Series: make(map[string]*Series),
|
||||
Values: make([]float64, cap(tabulation.Labels)),
|
||||
}
|
||||
ud.Value = tabulation.Series[name]
|
||||
}
|
||||
L.SetMetatable(ud, L.GetTypeMetatable(luaSeriesTypeName))
|
||||
L.Push(ud)
|
||||
return 1
|
||||
}
|
||||
|
||||
func luaTabulationTitle(L *lua.LState) int {
|
||||
tabulation := luaCheckTabulation(L, 1)
|
||||
|
||||
if L.GetTop() == 2 {
|
||||
tabulation.Title = L.CheckString(2)
|
||||
return 0
|
||||
}
|
||||
L.Push(lua.LString(tabulation.Title))
|
||||
return 1
|
||||
}
|
||||
|
||||
func luaTabulationSubtitle(L *lua.LState) int {
|
||||
tabulation := luaCheckTabulation(L, 1)
|
||||
|
||||
if L.GetTop() == 2 {
|
||||
tabulation.Subtitle = L.CheckString(2)
|
||||
return 0
|
||||
}
|
||||
L.Push(lua.LString(tabulation.Subtitle))
|
||||
return 1
|
||||
}
|
||||
|
||||
func luaTabulationUnits(L *lua.LState) int {
|
||||
tabulation := luaCheckTabulation(L, 1)
|
||||
|
||||
if L.GetTop() == 2 {
|
||||
tabulation.Units = L.CheckString(2)
|
||||
return 0
|
||||
}
|
||||
L.Push(lua.LString(tabulation.Units))
|
||||
return 1
|
||||
}
|
||||
|
||||
func luaSeries__index(L *lua.LState) int {
|
||||
field := L.CheckString(2)
|
||||
|
||||
switch field {
|
||||
case "Value", "value":
|
||||
L.Push(L.NewFunction(luaSeriesValue))
|
||||
case "Series", "series":
|
||||
L.Push(L.NewFunction(luaSeriesSeries))
|
||||
default:
|
||||
L.ArgError(2, "unexpected series attribute: "+field)
|
||||
}
|
||||
|
||||
return 1
|
||||
}
|
||||
|
||||
func luaSeriesValue(L *lua.LState) int {
|
||||
series := luaCheckSeries(L, 1)
|
||||
valuenumber := L.CheckInt(2)
|
||||
value := float64(L.CheckNumber(3))
|
||||
|
||||
if valuenumber > cap(series.Values) || valuenumber < 1 {
|
||||
L.ArgError(2, "value index must be between 1 and the number of data points, inclusive")
|
||||
}
|
||||
series.Values[valuenumber-1] = value
|
||||
|
||||
return 0
|
||||
}
|
||||
|
||||
func luaSeriesSeries(L *lua.LState) int {
|
||||
parent := luaCheckSeries(L, 1)
|
||||
name := L.CheckString(2)
|
||||
ud := L.NewUserData()
|
||||
|
||||
s, ok := parent.Series[name]
|
||||
if ok {
|
||||
ud.Value = s
|
||||
} else {
|
||||
parent.Series[name] = &Series{
|
||||
Series: make(map[string]*Series),
|
||||
Values: make([]float64, cap(parent.Values)),
|
||||
}
|
||||
ud.Value = parent.Series[name]
|
||||
}
|
||||
L.SetMetatable(ud, L.GetTypeMetatable(luaSeriesTypeName))
|
||||
L.Push(ud)
|
||||
return 1
|
||||
}
|
38
internal/handlers/scripts/gen_cusip_csv.sh
Executable file
38
internal/handlers/scripts/gen_cusip_csv.sh
Executable file
@ -0,0 +1,38 @@
|
||||
#!/bin/bash
|
||||
QUARTER=2017q1
|
||||
|
||||
function get_ticker() {
|
||||
local cusip=$1
|
||||
|
||||
local tmpfile=$tmpdir/curl_tmpfile
|
||||
curl -s -d "sopt=cusip&tickersymbol=${cusip}" http://quantumonline.com/search.cfm > $tmpfile
|
||||
local quantum_name=$(sed -rn 's@<font size="\+1"><center><b>(.+)</b><br></center></font>\s*$@\1@p' $tmpfile | head -n1)
|
||||
local quantum_ticker=$(sed -rn 's@^.*Ticker Symbol: ([A-Z\.0-9\-]+) CUSIP.*$@\1@p' $tmpfile | head -n1)
|
||||
|
||||
if [[ -z $quantum_ticker ]] || [[ -z $quantum_name ]]; then
|
||||
curl -s -d "reqforlookup=REQUESTFORLOOKUP&productid=mmnet&isLoggedIn=mmnet&rows=50&for=stock&by=cusip&criteria=${cusip}&submit=Search" http://quotes.fidelity.com/mmnet/SymLookup.phtml > $tmpfile
|
||||
fidelity_name=$(sed -rn 's@<tr><td height="20" nowrap><font class="smallfont">(.+)</font></td>\s*@\1@p' $tmpfile | sed -r 's/\&/\&/')
|
||||
fidelity_ticker=$(sed -rn 's@\s+<td align="center" width="20%"><font><a href="/webxpress/get_quote\?QUOTE_TYPE=\&SID_VALUE_ID=(.+)">(.+)</a></td>\s*@\1@p' $tmpfile | head -n1)
|
||||
if [[ -z $fidelity_ticker ]] || [[ -z $fidelity_name ]]; then
|
||||
echo $cusip >> $tmpdir/${QUARTER}_bad_cusips.csv
|
||||
else
|
||||
echo "$cusip,$fidelity_ticker,$fidelity_name"
|
||||
fi
|
||||
else
|
||||
echo "$cusip,$quantum_ticker,$quantum_name"
|
||||
fi
|
||||
}
|
||||
|
||||
tmpdir=$(mktemp -d -p $PWD)
|
||||
|
||||
# Get the list of CUSIPs from the SEC and generate a nicer format of it
|
||||
wget -q http://www.sec.gov/divisions/investment/13f/13flist${QUARTER}.pdf -O $tmpdir/13flist${QUARTER}.pdf
|
||||
pdftotext -layout $tmpdir/13flist${QUARTER}.pdf - > $tmpdir/13flist${QUARTER}.txt
|
||||
sed -rn 's/^([A-Z0-9]{6}) ([A-Z0-9]{2}) ([A-Z0-9]) .*$/\1\2\3/p' $tmpdir/13flist${QUARTER}.txt > $tmpdir/${QUARTER}_cusips
|
||||
|
||||
# Find tickers and names for all the CUSIPs we can and print them out
|
||||
for cusip in $(cat $tmpdir/${QUARTER}_cusips); do
|
||||
get_ticker $cusip
|
||||
done
|
||||
|
||||
rm -rf $tmpdir
|
114
internal/handlers/scripts/gen_security_list.py
Executable file
114
internal/handlers/scripts/gen_security_list.py
Executable file
@ -0,0 +1,114 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
import csv
|
||||
from xml.dom import minidom
|
||||
import sys
|
||||
|
||||
if sys.version_info[0] < 3:
|
||||
from urllib2 import urlopen
|
||||
|
||||
# Allow writing utf-8 to stdout
|
||||
import codecs
|
||||
UTF8Writer = codecs.getwriter('utf8')
|
||||
sys.stdout = UTF8Writer(sys.stdout)
|
||||
else:
|
||||
from urllib.request import urlopen
|
||||
|
||||
# This is absent, but also unneeded in python3, so just return the string
|
||||
def unicode(s, encoding):
|
||||
return s
|
||||
|
||||
class Security(object):
|
||||
def __init__(self, name, description, number, _type, precision):
|
||||
self.name = name
|
||||
self.description = description
|
||||
self.number = number
|
||||
self.type = _type
|
||||
self.precision = precision
|
||||
def unicode(self):
|
||||
s = """\tSecurity{
|
||||
\t\tName: \"%s\",
|
||||
\t\tDescription: \"%s\",
|
||||
\t\tSymbol: \"%s\",
|
||||
\t\tPrecision: %d,
|
||||
\t\tType: %s,
|
||||
\t\tAlternateId: \"%s\"},\n""" % (self.name, self.description, self.name, self.precision, self.type, str(self.number))
|
||||
try:
|
||||
return unicode(s, 'utf_8')
|
||||
except TypeError:
|
||||
return s
|
||||
|
||||
class SecurityList(object):
|
||||
def __init__(self, comment):
|
||||
self.comment = comment
|
||||
self.currencies = {}
|
||||
def add(self, currency):
|
||||
self.currencies[currency.number] = currency
|
||||
def unicode(self):
|
||||
string = "\t// "+self.comment+"\n"
|
||||
for key in sorted(self.currencies.keys()):
|
||||
string += self.currencies[key].unicode()
|
||||
return string
|
||||
|
||||
def process_ccyntry(currency_list, node):
|
||||
name = ""
|
||||
nameSet = False
|
||||
number = 0
|
||||
numberSet = False
|
||||
description = ""
|
||||
precision = 0
|
||||
for n in node.childNodes:
|
||||
if n.nodeName == "Ccy":
|
||||
name = n.firstChild.nodeValue
|
||||
nameSet = True
|
||||
elif n.nodeName == "CcyNm":
|
||||
description = n.firstChild.nodeValue
|
||||
elif n.nodeName == "CcyNbr":
|
||||
number = int(n.firstChild.nodeValue)
|
||||
numberSet = True
|
||||
elif n.nodeName == "CcyMnrUnts":
|
||||
if n.firstChild.nodeValue == "N.A.":
|
||||
precision = 0
|
||||
else:
|
||||
precision = int(n.firstChild.nodeValue)
|
||||
if nameSet and numberSet:
|
||||
currency_list.add(Security(name, description, number, "Currency", precision))
|
||||
|
||||
def get_currency_list():
|
||||
currency_list = SecurityList("ISO 4217, from http://www.currency-iso.org/en/home/tables/table-a1.html")
|
||||
|
||||
f = urlopen('http://www.currency-iso.org/dam/downloads/lists/list_one.xml')
|
||||
xmldoc = minidom.parse(f)
|
||||
for isonode in xmldoc.childNodes:
|
||||
if isonode.nodeName == "ISO_4217":
|
||||
for ccytblnode in isonode.childNodes:
|
||||
if ccytblnode.nodeName == "CcyTbl":
|
||||
for ccyntrynode in ccytblnode.childNodes:
|
||||
if ccyntrynode.nodeName == "CcyNtry":
|
||||
process_ccyntry(currency_list, ccyntrynode)
|
||||
f.close()
|
||||
return currency_list
|
||||
|
||||
def get_cusip_list(filename):
|
||||
cusip_list = SecurityList("")
|
||||
with open(filename) as csvfile:
|
||||
csvreader = csv.reader(csvfile, delimiter=',')
|
||||
for row in csvreader:
|
||||
cusip = row[0]
|
||||
name = row[1]
|
||||
description = ",".join(row[2:])
|
||||
cusip_list.add(Security(name, description, cusip, "Stock", 5))
|
||||
return cusip_list
|
||||
|
||||
def main():
|
||||
currency_list = get_currency_list()
|
||||
cusip_list = get_cusip_list('cusip_list.csv')
|
||||
|
||||
print("package handlers\n")
|
||||
print("var SecurityTemplates = []Security{")
|
||||
print(currency_list.unicode())
|
||||
print(cusip_list.unicode())
|
||||
print("}")
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
442
internal/handlers/securities.go
Normal file
442
internal/handlers/securities.go
Normal file
@ -0,0 +1,442 @@
|
||||
package handlers
|
||||
|
||||
//go:generate make
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"gopkg.in/gorp.v1"
|
||||
"log"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strconv"
|
||||
"strings"
|
||||
)
|
||||
|
||||
const (
|
||||
Currency int64 = 1
|
||||
Stock = 2
|
||||
)
|
||||
|
||||
func GetSecurityType(typestring string) int64 {
|
||||
if strings.EqualFold(typestring, "currency") {
|
||||
return Currency
|
||||
} else if strings.EqualFold(typestring, "stock") {
|
||||
return Stock
|
||||
} else {
|
||||
return 0
|
||||
}
|
||||
}
|
||||
|
||||
type Security struct {
|
||||
SecurityId int64
|
||||
UserId int64
|
||||
Name string
|
||||
Description string
|
||||
Symbol string
|
||||
// Number of decimal digits (to the right of the decimal point) this
|
||||
// security is precise to
|
||||
Precision int
|
||||
Type int64
|
||||
// AlternateId is CUSIP for Type=Stock, ISO4217 for Type=Currency
|
||||
AlternateId string
|
||||
}
|
||||
|
||||
type SecurityList struct {
|
||||
Securities *[]*Security `json:"securities"`
|
||||
}
|
||||
|
||||
func (s *Security) Read(json_str string) error {
|
||||
dec := json.NewDecoder(strings.NewReader(json_str))
|
||||
return dec.Decode(s)
|
||||
}
|
||||
|
||||
func (s *Security) Write(w http.ResponseWriter) error {
|
||||
enc := json.NewEncoder(w)
|
||||
return enc.Encode(s)
|
||||
}
|
||||
|
||||
func (sl *SecurityList) Write(w http.ResponseWriter) error {
|
||||
enc := json.NewEncoder(w)
|
||||
return enc.Encode(sl)
|
||||
}
|
||||
|
||||
func SearchSecurityTemplates(search string, _type int64, limit int64) []*Security {
|
||||
upperSearch := strings.ToUpper(search)
|
||||
var results []*Security
|
||||
for i, security := range SecurityTemplates {
|
||||
if strings.Contains(strings.ToUpper(security.Name), upperSearch) ||
|
||||
strings.Contains(strings.ToUpper(security.Description), upperSearch) ||
|
||||
strings.Contains(strings.ToUpper(security.Symbol), upperSearch) {
|
||||
if _type == 0 || _type == security.Type {
|
||||
results = append(results, &SecurityTemplates[i])
|
||||
if limit != -1 && int64(len(results)) >= limit {
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return results
|
||||
}
|
||||
|
||||
func FindSecurityTemplate(name string, _type int64) *Security {
|
||||
for _, security := range SecurityTemplates {
|
||||
if name == security.Name && _type == security.Type {
|
||||
return &security
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func FindCurrencyTemplate(iso4217 int64) *Security {
|
||||
iso4217string := strconv.FormatInt(iso4217, 10)
|
||||
for _, security := range SecurityTemplates {
|
||||
if security.Type == Currency && security.AlternateId == iso4217string {
|
||||
return &security
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func GetSecurity(db *DB, securityid int64, userid int64) (*Security, error) {
|
||||
var s Security
|
||||
|
||||
err := db.SelectOne(&s, "SELECT * from securities where UserId=? AND SecurityId=?", userid, securityid)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &s, nil
|
||||
}
|
||||
|
||||
func GetSecurityTx(transaction *gorp.Transaction, securityid int64, userid int64) (*Security, error) {
|
||||
var s Security
|
||||
|
||||
err := transaction.SelectOne(&s, "SELECT * from securities where UserId=? AND SecurityId=?", userid, securityid)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &s, nil
|
||||
}
|
||||
|
||||
func GetSecurities(db *DB, userid int64) (*[]*Security, error) {
|
||||
var securities []*Security
|
||||
|
||||
_, err := db.Select(&securities, "SELECT * from securities where UserId=?", userid)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &securities, nil
|
||||
}
|
||||
|
||||
func InsertSecurity(db *DB, s *Security) error {
|
||||
err := db.Insert(s)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func InsertSecurityTx(transaction *gorp.Transaction, s *Security) error {
|
||||
err := transaction.Insert(s)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func UpdateSecurity(db *DB, s *Security) error {
|
||||
transaction, err := db.Begin()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
user, err := GetUserTx(transaction, s.UserId)
|
||||
if err != nil {
|
||||
transaction.Rollback()
|
||||
return err
|
||||
} else if user.DefaultCurrency == s.SecurityId && s.Type != Currency {
|
||||
transaction.Rollback()
|
||||
return errors.New("Cannot change security which is user's default currency to be non-currency")
|
||||
}
|
||||
|
||||
count, err := transaction.Update(s)
|
||||
if err != nil {
|
||||
transaction.Rollback()
|
||||
return err
|
||||
}
|
||||
if count != 1 {
|
||||
transaction.Rollback()
|
||||
return errors.New("Updated more than one security")
|
||||
}
|
||||
|
||||
err = transaction.Commit()
|
||||
if err != nil {
|
||||
transaction.Rollback()
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func DeleteSecurity(db *DB, s *Security) error {
|
||||
transaction, err := db.Begin()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// First, ensure no accounts are using this security
|
||||
accounts, err := transaction.SelectInt("SELECT count(*) from accounts where UserId=? and SecurityId=?", s.UserId, s.SecurityId)
|
||||
|
||||
if accounts != 0 {
|
||||
transaction.Rollback()
|
||||
return errors.New("One or more accounts still use this security")
|
||||
}
|
||||
|
||||
user, err := GetUserTx(transaction, s.UserId)
|
||||
if err != nil {
|
||||
transaction.Rollback()
|
||||
return err
|
||||
} else if user.DefaultCurrency == s.SecurityId {
|
||||
transaction.Rollback()
|
||||
return errors.New("Cannot delete security which is user's default currency")
|
||||
}
|
||||
|
||||
// Remove all prices involving this security (either of this security, or
|
||||
// using it as a currency)
|
||||
_, err = transaction.Exec("DELETE * FROM prices WHERE SecurityId=? OR CurrencyId=?", s.SecurityId, s.SecurityId)
|
||||
if err != nil {
|
||||
transaction.Rollback()
|
||||
return err
|
||||
}
|
||||
|
||||
count, err := transaction.Delete(s)
|
||||
if err != nil {
|
||||
transaction.Rollback()
|
||||
return err
|
||||
}
|
||||
if count != 1 {
|
||||
transaction.Rollback()
|
||||
return errors.New("Deleted more than one security")
|
||||
}
|
||||
|
||||
err = transaction.Commit()
|
||||
if err != nil {
|
||||
transaction.Rollback()
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func ImportGetCreateSecurity(transaction *gorp.Transaction, userid int64, security *Security) (*Security, error) {
|
||||
security.UserId = userid
|
||||
if len(security.AlternateId) == 0 {
|
||||
// Always create a new local security if we can't match on the AlternateId
|
||||
err := InsertSecurityTx(transaction, security)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return security, nil
|
||||
}
|
||||
|
||||
var securities []*Security
|
||||
|
||||
_, err := transaction.Select(&securities, "SELECT * from securities where UserId=? AND Type=? AND AlternateId=? AND Precision=?", userid, security.Type, security.AlternateId, security.Precision)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// First try to find a case insensitive match on the name or symbol
|
||||
upperName := strings.ToUpper(security.Name)
|
||||
upperSymbol := strings.ToUpper(security.Symbol)
|
||||
for _, s := range securities {
|
||||
if (len(s.Name) > 0 && strings.ToUpper(s.Name) == upperName) ||
|
||||
(len(s.Symbol) > 0 && strings.ToUpper(s.Symbol) == upperSymbol) {
|
||||
return s, nil
|
||||
}
|
||||
}
|
||||
// if strings.Contains(strings.ToUpper(security.Name), upperSearch) ||
|
||||
|
||||
// Try to find a partial string match on the name or symbol
|
||||
for _, s := range securities {
|
||||
sUpperName := strings.ToUpper(s.Name)
|
||||
sUpperSymbol := strings.ToUpper(s.Symbol)
|
||||
if (len(upperName) > 0 && len(s.Name) > 0 && (strings.Contains(upperName, sUpperName) || strings.Contains(sUpperName, upperName))) ||
|
||||
(len(upperSymbol) > 0 && len(s.Symbol) > 0 && (strings.Contains(upperSymbol, sUpperSymbol) || strings.Contains(sUpperSymbol, upperSymbol))) {
|
||||
return s, nil
|
||||
}
|
||||
}
|
||||
|
||||
// Give up and return the first security in the list
|
||||
if len(securities) > 0 {
|
||||
return securities[0], nil
|
||||
}
|
||||
|
||||
// If there wasn't even one security in the list, make a new one
|
||||
err = InsertSecurityTx(transaction, security)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return security, nil
|
||||
}
|
||||
|
||||
func SecurityHandler(w http.ResponseWriter, r *http.Request, db *DB) {
|
||||
user, err := GetUserFromSession(db, r)
|
||||
if err != nil {
|
||||
WriteError(w, 1 /*Not Signed In*/)
|
||||
return
|
||||
}
|
||||
|
||||
if r.Method == "POST" {
|
||||
security_json := r.PostFormValue("security")
|
||||
if security_json == "" {
|
||||
WriteError(w, 3 /*Invalid Request*/)
|
||||
return
|
||||
}
|
||||
|
||||
var security Security
|
||||
err := security.Read(security_json)
|
||||
if err != nil {
|
||||
WriteError(w, 3 /*Invalid Request*/)
|
||||
return
|
||||
}
|
||||
security.SecurityId = -1
|
||||
security.UserId = user.UserId
|
||||
|
||||
err = InsertSecurity(db, &security)
|
||||
if err != nil {
|
||||
WriteError(w, 999 /*Internal Error*/)
|
||||
log.Print(err)
|
||||
return
|
||||
}
|
||||
|
||||
w.WriteHeader(201 /*Created*/)
|
||||
err = security.Write(w)
|
||||
if err != nil {
|
||||
WriteError(w, 999 /*Internal Error*/)
|
||||
log.Print(err)
|
||||
return
|
||||
}
|
||||
} else if r.Method == "GET" {
|
||||
var securityid int64
|
||||
n, err := GetURLPieces(r.URL.Path, "/security/%d", &securityid)
|
||||
|
||||
if err != nil || n != 1 {
|
||||
//Return all securities
|
||||
var sl SecurityList
|
||||
|
||||
securities, err := GetSecurities(db, user.UserId)
|
||||
if err != nil {
|
||||
WriteError(w, 999 /*Internal Error*/)
|
||||
log.Print(err)
|
||||
return
|
||||
}
|
||||
|
||||
sl.Securities = securities
|
||||
err = (&sl).Write(w)
|
||||
if err != nil {
|
||||
WriteError(w, 999 /*Internal Error*/)
|
||||
log.Print(err)
|
||||
return
|
||||
}
|
||||
} else {
|
||||
security, err := GetSecurity(db, securityid, user.UserId)
|
||||
if err != nil {
|
||||
WriteError(w, 3 /*Invalid Request*/)
|
||||
return
|
||||
}
|
||||
|
||||
err = security.Write(w)
|
||||
if err != nil {
|
||||
WriteError(w, 999 /*Internal Error*/)
|
||||
log.Print(err)
|
||||
return
|
||||
}
|
||||
}
|
||||
} else {
|
||||
securityid, err := GetURLID(r.URL.Path)
|
||||
if err != nil {
|
||||
WriteError(w, 3 /*Invalid Request*/)
|
||||
return
|
||||
}
|
||||
if r.Method == "PUT" {
|
||||
security_json := r.PostFormValue("security")
|
||||
if security_json == "" {
|
||||
WriteError(w, 3 /*Invalid Request*/)
|
||||
return
|
||||
}
|
||||
|
||||
var security Security
|
||||
err := security.Read(security_json)
|
||||
if err != nil || security.SecurityId != securityid {
|
||||
WriteError(w, 3 /*Invalid Request*/)
|
||||
return
|
||||
}
|
||||
security.UserId = user.UserId
|
||||
|
||||
err = UpdateSecurity(db, &security)
|
||||
if err != nil {
|
||||
WriteError(w, 999 /*Internal Error*/)
|
||||
log.Print(err)
|
||||
return
|
||||
}
|
||||
|
||||
err = security.Write(w)
|
||||
if err != nil {
|
||||
WriteError(w, 999 /*Internal Error*/)
|
||||
log.Print(err)
|
||||
return
|
||||
}
|
||||
} else if r.Method == "DELETE" {
|
||||
security, err := GetSecurity(db, securityid, user.UserId)
|
||||
if err != nil {
|
||||
WriteError(w, 3 /*Invalid Request*/)
|
||||
return
|
||||
}
|
||||
|
||||
err = DeleteSecurity(db, security)
|
||||
if err != nil {
|
||||
WriteError(w, 999 /*Internal Error*/)
|
||||
log.Print(err)
|
||||
return
|
||||
}
|
||||
|
||||
WriteSuccess(w)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func SecurityTemplateHandler(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method == "GET" {
|
||||
var sl SecurityList
|
||||
|
||||
query, _ := url.ParseQuery(r.URL.RawQuery)
|
||||
|
||||
var limit int64 = -1
|
||||
search := query.Get("search")
|
||||
_type := GetSecurityType(query.Get("type"))
|
||||
|
||||
limitstring := query.Get("limit")
|
||||
if limitstring != "" {
|
||||
limitint, err := strconv.ParseInt(limitstring, 10, 0)
|
||||
if err != nil {
|
||||
WriteError(w, 3 /*Invalid Request*/)
|
||||
return
|
||||
}
|
||||
limit = limitint
|
||||
}
|
||||
|
||||
securities := SearchSecurityTemplates(search, _type, limit)
|
||||
|
||||
sl.Securities = &securities
|
||||
err := (&sl).Write(w)
|
||||
if err != nil {
|
||||
WriteError(w, 999 /*Internal Error*/)
|
||||
log.Print(err)
|
||||
return
|
||||
}
|
||||
} else {
|
||||
WriteError(w, 3 /*Invalid Request*/)
|
||||
}
|
||||
}
|
188
internal/handlers/securities_lua.go
Normal file
188
internal/handlers/securities_lua.go
Normal file
@ -0,0 +1,188 @@
|
||||
package handlers
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"github.com/yuin/gopher-lua"
|
||||
)
|
||||
|
||||
const luaSecurityTypeName = "security"
|
||||
|
||||
func luaContextGetSecurities(L *lua.LState) (map[int64]*Security, error) {
|
||||
var security_map map[int64]*Security
|
||||
|
||||
ctx := L.Context()
|
||||
|
||||
db, ok := ctx.Value(dbContextKey).(*DB)
|
||||
if !ok {
|
||||
return nil, errors.New("Couldn't find DB in lua's Context")
|
||||
}
|
||||
|
||||
security_map, ok = ctx.Value(securitiesContextKey).(map[int64]*Security)
|
||||
if !ok {
|
||||
user, ok := ctx.Value(userContextKey).(*User)
|
||||
if !ok {
|
||||
return nil, errors.New("Couldn't find User in lua's Context")
|
||||
}
|
||||
|
||||
securities, err := GetSecurities(db, user.UserId)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
security_map = make(map[int64]*Security)
|
||||
for i := range *securities {
|
||||
security_map[(*securities)[i].SecurityId] = (*securities)[i]
|
||||
}
|
||||
|
||||
ctx = context.WithValue(ctx, securitiesContextKey, security_map)
|
||||
L.SetContext(ctx)
|
||||
}
|
||||
|
||||
return security_map, nil
|
||||
}
|
||||
|
||||
func luaContextGetDefaultCurrency(L *lua.LState) (*Security, error) {
|
||||
security_map, err := luaContextGetSecurities(L)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
ctx := L.Context()
|
||||
|
||||
user, ok := ctx.Value(userContextKey).(*User)
|
||||
if !ok {
|
||||
return nil, errors.New("Couldn't find User in lua's Context")
|
||||
}
|
||||
|
||||
if security, ok := security_map[user.DefaultCurrency]; ok {
|
||||
return security, nil
|
||||
} else {
|
||||
return nil, errors.New("DefaultCurrency not in lua security_map")
|
||||
}
|
||||
}
|
||||
|
||||
func luaGetDefaultCurrency(L *lua.LState) int {
|
||||
defcurrency, err := luaContextGetDefaultCurrency(L)
|
||||
if err != nil {
|
||||
panic("luaGetDefaultCurrency couldn't fetch default currency")
|
||||
}
|
||||
|
||||
L.Push(SecurityToLua(L, defcurrency))
|
||||
return 1
|
||||
}
|
||||
|
||||
func luaGetSecurities(L *lua.LState) int {
|
||||
security_map, err := luaContextGetSecurities(L)
|
||||
if err != nil {
|
||||
panic("luaGetSecurities couldn't fetch securities")
|
||||
}
|
||||
|
||||
table := L.NewTable()
|
||||
|
||||
for securityid := range security_map {
|
||||
table.RawSetInt(int(securityid), SecurityToLua(L, security_map[securityid]))
|
||||
}
|
||||
|
||||
L.Push(table)
|
||||
return 1
|
||||
}
|
||||
|
||||
func luaRegisterSecurities(L *lua.LState) {
|
||||
mt := L.NewTypeMetatable(luaSecurityTypeName)
|
||||
L.SetGlobal("security", mt)
|
||||
L.SetField(mt, "__index", L.NewFunction(luaSecurity__index))
|
||||
L.SetField(mt, "__tostring", L.NewFunction(luaSecurity__tostring))
|
||||
L.SetField(mt, "__eq", L.NewFunction(luaSecurity__eq))
|
||||
L.SetField(mt, "__metatable", lua.LString("protected"))
|
||||
getSecuritiesFn := L.NewFunction(luaGetSecurities)
|
||||
L.SetField(mt, "get_all", getSecuritiesFn)
|
||||
getDefaultCurrencyFn := L.NewFunction(luaGetDefaultCurrency)
|
||||
L.SetField(mt, "get_default", getDefaultCurrencyFn)
|
||||
|
||||
// also register the get_securities and get_default functions as globals in
|
||||
// their own right
|
||||
L.SetGlobal("get_securities", getSecuritiesFn)
|
||||
L.SetGlobal("get_default_currency", getDefaultCurrencyFn)
|
||||
}
|
||||
|
||||
func SecurityToLua(L *lua.LState, security *Security) *lua.LUserData {
|
||||
ud := L.NewUserData()
|
||||
ud.Value = security
|
||||
L.SetMetatable(ud, L.GetTypeMetatable(luaSecurityTypeName))
|
||||
return ud
|
||||
}
|
||||
|
||||
// Checks whether the first lua argument is a *LUserData with *Security and returns this *Security.
|
||||
func luaCheckSecurity(L *lua.LState, n int) *Security {
|
||||
ud := L.CheckUserData(n)
|
||||
if security, ok := ud.Value.(*Security); ok {
|
||||
return security
|
||||
}
|
||||
L.ArgError(n, "security expected")
|
||||
return nil
|
||||
}
|
||||
|
||||
func luaSecurity__index(L *lua.LState) int {
|
||||
a := luaCheckSecurity(L, 1)
|
||||
field := L.CheckString(2)
|
||||
|
||||
switch field {
|
||||
case "SecurityId", "securityid":
|
||||
L.Push(lua.LNumber(float64(a.SecurityId)))
|
||||
case "Name", "name":
|
||||
L.Push(lua.LString(a.Name))
|
||||
case "Description", "description":
|
||||
L.Push(lua.LString(a.Description))
|
||||
case "Symbol", "symbol":
|
||||
L.Push(lua.LString(a.Symbol))
|
||||
case "Precision", "precision":
|
||||
L.Push(lua.LNumber(float64(a.Precision)))
|
||||
case "Type", "type":
|
||||
L.Push(lua.LNumber(float64(a.Type)))
|
||||
case "ClosestPrice", "closestprice":
|
||||
L.Push(L.NewFunction(luaClosestPrice))
|
||||
default:
|
||||
L.ArgError(2, "unexpected security attribute: "+field)
|
||||
}
|
||||
|
||||
return 1
|
||||
}
|
||||
|
||||
func luaClosestPrice(L *lua.LState) int {
|
||||
s := luaCheckSecurity(L, 1)
|
||||
c := luaCheckSecurity(L, 2)
|
||||
date := luaCheckTime(L, 3)
|
||||
|
||||
ctx := L.Context()
|
||||
db, ok := ctx.Value(dbContextKey).(*DB)
|
||||
if !ok {
|
||||
panic("Couldn't find DB in lua's Context")
|
||||
}
|
||||
|
||||
p, err := GetClosestPrice(db, s, c, date)
|
||||
if err != nil {
|
||||
L.Push(lua.LNil)
|
||||
} else {
|
||||
L.Push(PriceToLua(L, p))
|
||||
}
|
||||
|
||||
return 1
|
||||
}
|
||||
|
||||
func luaSecurity__tostring(L *lua.LState) int {
|
||||
s := luaCheckSecurity(L, 1)
|
||||
|
||||
L.Push(lua.LString(s.Name + " - " + s.Description + " (" + s.Symbol + ")"))
|
||||
|
||||
return 1
|
||||
}
|
||||
|
||||
func luaSecurity__eq(L *lua.LState) int {
|
||||
a := luaCheckSecurity(L, 1)
|
||||
b := luaCheckSecurity(L, 2)
|
||||
|
||||
L.Push(lua.LBool(a.SecurityId == b.SecurityId))
|
||||
|
||||
return 1
|
||||
}
|
139
internal/handlers/sessions.go
Normal file
139
internal/handlers/sessions.go
Normal file
@ -0,0 +1,139 @@
|
||||
package handlers
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
"net/http"
|
||||
"time"
|
||||
)
|
||||
|
||||
type Session struct {
|
||||
SessionId int64
|
||||
SessionSecret string `json:"-"`
|
||||
UserId int64
|
||||
}
|
||||
|
||||
func (s *Session) Write(w http.ResponseWriter) error {
|
||||
enc := json.NewEncoder(w)
|
||||
return enc.Encode(s)
|
||||
}
|
||||
|
||||
func GetSession(db *DB, r *http.Request) (*Session, error) {
|
||||
var s Session
|
||||
|
||||
cookie, err := r.Cookie("moneygo-session")
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("moneygo-session cookie not set")
|
||||
}
|
||||
s.SessionSecret = cookie.Value
|
||||
|
||||
err = db.SelectOne(&s, "SELECT * from sessions where SessionSecret=?", s.SessionSecret)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &s, nil
|
||||
}
|
||||
|
||||
func DeleteSessionIfExists(db *DB, r *http.Request) {
|
||||
// TODO do this in one transaction
|
||||
session, err := GetSession(db, r)
|
||||
if err == nil {
|
||||
db.Delete(session)
|
||||
}
|
||||
}
|
||||
|
||||
func NewSessionCookie() (string, error) {
|
||||
bits := make([]byte, 128)
|
||||
if _, err := io.ReadFull(rand.Reader, bits); err != nil {
|
||||
return "", err
|
||||
}
|
||||
return base64.StdEncoding.EncodeToString(bits), nil
|
||||
}
|
||||
|
||||
func NewSession(db *DB, w http.ResponseWriter, r *http.Request, userid int64) (*Session, error) {
|
||||
s := Session{}
|
||||
|
||||
session_secret, err := NewSessionCookie()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
cookie := http.Cookie{
|
||||
Name: "moneygo-session",
|
||||
Value: session_secret,
|
||||
Path: "/",
|
||||
Domain: r.URL.Host,
|
||||
Expires: time.Now().AddDate(0, 1, 0), // a month from now
|
||||
Secure: true,
|
||||
HttpOnly: true,
|
||||
}
|
||||
http.SetCookie(w, &cookie)
|
||||
|
||||
s.SessionSecret = session_secret
|
||||
s.UserId = userid
|
||||
|
||||
err = db.Insert(&s)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &s, nil
|
||||
}
|
||||
|
||||
func SessionHandler(w http.ResponseWriter, r *http.Request, db *DB) {
|
||||
if r.Method == "POST" || r.Method == "PUT" {
|
||||
user_json := r.PostFormValue("user")
|
||||
if user_json == "" {
|
||||
WriteError(w, 3 /*Invalid Request*/)
|
||||
return
|
||||
}
|
||||
|
||||
user := User{}
|
||||
err := user.Read(user_json)
|
||||
if err != nil {
|
||||
WriteError(w, 3 /*Invalid Request*/)
|
||||
return
|
||||
}
|
||||
|
||||
dbuser, err := GetUserByUsername(db, user.Username)
|
||||
if err != nil {
|
||||
WriteError(w, 2 /*Unauthorized Access*/)
|
||||
return
|
||||
}
|
||||
|
||||
user.HashPassword()
|
||||
if user.PasswordHash != dbuser.PasswordHash {
|
||||
WriteError(w, 2 /*Unauthorized Access*/)
|
||||
return
|
||||
}
|
||||
|
||||
DeleteSessionIfExists(db, r)
|
||||
|
||||
session, err := NewSession(db, w, r, dbuser.UserId)
|
||||
if err != nil {
|
||||
WriteError(w, 999 /*Internal Error*/)
|
||||
return
|
||||
}
|
||||
|
||||
err = session.Write(w)
|
||||
if err != nil {
|
||||
WriteError(w, 999 /*Internal Error*/)
|
||||
log.Print(err)
|
||||
return
|
||||
}
|
||||
} else if r.Method == "GET" {
|
||||
s, err := GetSession(db, r)
|
||||
if err != nil {
|
||||
WriteError(w, 1 /*Not Signed In*/)
|
||||
return
|
||||
}
|
||||
|
||||
s.Write(w)
|
||||
} else if r.Method == "DELETE" {
|
||||
DeleteSessionIfExists(db, r)
|
||||
WriteSuccess(w)
|
||||
}
|
||||
}
|
945
internal/handlers/transactions.go
Normal file
945
internal/handlers/transactions.go
Normal file
@ -0,0 +1,945 @@
|
||||
package handlers
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"gopkg.in/gorp.v1"
|
||||
"log"
|
||||
"math/big"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
// Split.Status
|
||||
const (
|
||||
Imported int64 = 1
|
||||
Entered = 2
|
||||
Cleared = 3
|
||||
Reconciled = 4
|
||||
Voided = 5
|
||||
)
|
||||
|
||||
// Split.ImportSplitType
|
||||
const (
|
||||
Default int64 = 0
|
||||
ImportAccount = 1 // This split belongs to the main account being imported
|
||||
SubAccount = 2 // This split belongs to a sub-account of that being imported
|
||||
ExternalAccount = 3
|
||||
TradingAccount = 4
|
||||
Commission = 5
|
||||
Taxes = 6
|
||||
Fees = 7
|
||||
Load = 8
|
||||
IncomeAccount = 9
|
||||
ExpenseAccount = 10
|
||||
)
|
||||
|
||||
type Split struct {
|
||||
SplitId int64
|
||||
TransactionId int64
|
||||
Status int64
|
||||
ImportSplitType int64
|
||||
|
||||
// One of AccountId and SecurityId must be -1
|
||||
// In normal splits, AccountId will be valid and SecurityId will be -1. The
|
||||
// only case where this is reversed is for transactions that have been
|
||||
// imported and not yet associated with an account.
|
||||
AccountId int64
|
||||
SecurityId int64
|
||||
|
||||
RemoteId string // unique ID from server, for detecting duplicates
|
||||
Number string // Check or reference number
|
||||
Memo string
|
||||
Amount string // String representation of decimal, suitable for passing to big.Rat.SetString()
|
||||
}
|
||||
|
||||
func GetBigAmount(amt string) (*big.Rat, error) {
|
||||
var r big.Rat
|
||||
_, success := r.SetString(amt)
|
||||
if !success {
|
||||
return nil, errors.New("Couldn't convert string amount to big.Rat via SetString()")
|
||||
}
|
||||
return &r, nil
|
||||
}
|
||||
|
||||
func (s *Split) GetAmount() (*big.Rat, error) {
|
||||
return GetBigAmount(s.Amount)
|
||||
}
|
||||
|
||||
func (s *Split) Valid() bool {
|
||||
if (s.AccountId == -1) == (s.SecurityId == -1) {
|
||||
return false
|
||||
}
|
||||
_, err := s.GetAmount()
|
||||
return err == nil
|
||||
}
|
||||
|
||||
func (s *Split) AlreadyImportedTx(transaction *gorp.Transaction) (bool, error) {
|
||||
count, err := transaction.SelectInt("SELECT COUNT(*) from splits where RemoteId=? and AccountId=?", s.RemoteId, s.AccountId)
|
||||
return count == 1, err
|
||||
}
|
||||
|
||||
type Transaction struct {
|
||||
TransactionId int64
|
||||
UserId int64
|
||||
Description string
|
||||
Date time.Time
|
||||
Splits []*Split `db:"-"`
|
||||
}
|
||||
|
||||
type TransactionList struct {
|
||||
Transactions *[]Transaction `json:"transactions"`
|
||||
}
|
||||
|
||||
type AccountTransactionsList struct {
|
||||
Account *Account
|
||||
Transactions *[]Transaction
|
||||
TotalTransactions int64
|
||||
BeginningBalance string
|
||||
EndingBalance string
|
||||
}
|
||||
|
||||
func (t *Transaction) Write(w http.ResponseWriter) error {
|
||||
enc := json.NewEncoder(w)
|
||||
return enc.Encode(t)
|
||||
}
|
||||
|
||||
func (t *Transaction) Read(json_str string) error {
|
||||
dec := json.NewDecoder(strings.NewReader(json_str))
|
||||
return dec.Decode(t)
|
||||
}
|
||||
|
||||
func (tl *TransactionList) Write(w http.ResponseWriter) error {
|
||||
enc := json.NewEncoder(w)
|
||||
return enc.Encode(tl)
|
||||
}
|
||||
|
||||
func (atl *AccountTransactionsList) Write(w http.ResponseWriter) error {
|
||||
enc := json.NewEncoder(w)
|
||||
return enc.Encode(atl)
|
||||
}
|
||||
|
||||
func (t *Transaction) Valid() bool {
|
||||
for i := range t.Splits {
|
||||
if !t.Splits[i].Valid() {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// Return a map of security ID's to big.Rat's containing the amount that
|
||||
// security is imbalanced by
|
||||
func (t *Transaction) GetImbalancesTx(transaction *gorp.Transaction) (map[int64]big.Rat, error) {
|
||||
sums := make(map[int64]big.Rat)
|
||||
|
||||
if !t.Valid() {
|
||||
return nil, errors.New("Transaction invalid")
|
||||
}
|
||||
|
||||
for i := range t.Splits {
|
||||
securityid := t.Splits[i].SecurityId
|
||||
if t.Splits[i].AccountId != -1 {
|
||||
var err error
|
||||
var account *Account
|
||||
account, err = GetAccountTx(transaction, t.Splits[i].AccountId, t.UserId)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
securityid = account.SecurityId
|
||||
}
|
||||
amount, _ := t.Splits[i].GetAmount()
|
||||
sum := sums[securityid]
|
||||
(&sum).Add(&sum, amount)
|
||||
sums[securityid] = sum
|
||||
}
|
||||
return sums, nil
|
||||
}
|
||||
|
||||
// Returns true if all securities contained in this transaction are balanced,
|
||||
// false otherwise
|
||||
func (t *Transaction) Balanced(transaction *gorp.Transaction) (bool, error) {
|
||||
var zero big.Rat
|
||||
|
||||
sums, err := t.GetImbalancesTx(transaction)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
for _, security_sum := range sums {
|
||||
if security_sum.Cmp(&zero) != 0 {
|
||||
return false, nil
|
||||
}
|
||||
}
|
||||
return true, nil
|
||||
}
|
||||
|
||||
func GetTransaction(db *DB, transactionid int64, userid int64) (*Transaction, error) {
|
||||
var t Transaction
|
||||
|
||||
transaction, err := db.Begin()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
err = transaction.SelectOne(&t, "SELECT * from transactions where UserId=? AND TransactionId=?", userid, transactionid)
|
||||
if err != nil {
|
||||
transaction.Rollback()
|
||||
return nil, err
|
||||
}
|
||||
|
||||
_, err = transaction.Select(&t.Splits, "SELECT * from splits where TransactionId=?", transactionid)
|
||||
if err != nil {
|
||||
transaction.Rollback()
|
||||
return nil, err
|
||||
}
|
||||
|
||||
err = transaction.Commit()
|
||||
if err != nil {
|
||||
transaction.Rollback()
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &t, nil
|
||||
}
|
||||
|
||||
func GetTransactions(db *DB, userid int64) (*[]Transaction, error) {
|
||||
var transactions []Transaction
|
||||
|
||||
transaction, err := db.Begin()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
_, err = transaction.Select(&transactions, "SELECT * from transactions where UserId=?", userid)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
for i := range transactions {
|
||||
_, err := transaction.Select(&transactions[i].Splits, "SELECT * from splits where TransactionId=?", transactions[i].TransactionId)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
err = transaction.Commit()
|
||||
if err != nil {
|
||||
transaction.Rollback()
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &transactions, nil
|
||||
}
|
||||
|
||||
func incrementAccountVersions(transaction *gorp.Transaction, user *User, accountids []int64) error {
|
||||
for i := range accountids {
|
||||
account, err := GetAccountTx(transaction, accountids[i], user.UserId)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
account.AccountVersion++
|
||||
count, err := transaction.Update(account)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if count != 1 {
|
||||
return errors.New("Updated more than one account")
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
type AccountMissingError struct{}
|
||||
|
||||
func (ame AccountMissingError) Error() string {
|
||||
return "Account missing"
|
||||
}
|
||||
|
||||
func InsertTransactionTx(transaction *gorp.Transaction, t *Transaction, user *User) error {
|
||||
// Map of any accounts with transaction splits being added
|
||||
a_map := make(map[int64]bool)
|
||||
for i := range t.Splits {
|
||||
if t.Splits[i].AccountId != -1 {
|
||||
existing, err := transaction.SelectInt("SELECT count(*) from accounts where AccountId=?", t.Splits[i].AccountId)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if existing != 1 {
|
||||
return AccountMissingError{}
|
||||
}
|
||||
a_map[t.Splits[i].AccountId] = true
|
||||
} else if t.Splits[i].SecurityId == -1 {
|
||||
return AccountMissingError{}
|
||||
}
|
||||
}
|
||||
|
||||
//increment versions for all accounts
|
||||
var a_ids []int64
|
||||
for id := range a_map {
|
||||
a_ids = append(a_ids, id)
|
||||
}
|
||||
// ensure at least one of the splits is associated with an actual account
|
||||
if len(a_ids) < 1 {
|
||||
return AccountMissingError{}
|
||||
}
|
||||
err := incrementAccountVersions(transaction, user, a_ids)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
t.UserId = user.UserId
|
||||
err = transaction.Insert(t)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for i := range t.Splits {
|
||||
t.Splits[i].TransactionId = t.TransactionId
|
||||
t.Splits[i].SplitId = -1
|
||||
err = transaction.Insert(t.Splits[i])
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func InsertTransaction(db *DB, t *Transaction, user *User) error {
|
||||
transaction, err := db.Begin()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err = InsertTransactionTx(transaction, t, user)
|
||||
if err != nil {
|
||||
transaction.Rollback()
|
||||
return err
|
||||
}
|
||||
|
||||
err = transaction.Commit()
|
||||
if err != nil {
|
||||
transaction.Rollback()
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func UpdateTransactionTx(transaction *gorp.Transaction, t *Transaction, user *User) error {
|
||||
var existing_splits []*Split
|
||||
|
||||
_, err := transaction.Select(&existing_splits, "SELECT * from splits where TransactionId=?", t.TransactionId)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Map of any accounts with transaction splits being added
|
||||
a_map := make(map[int64]bool)
|
||||
|
||||
// Make a map with any existing splits for this transaction
|
||||
s_map := make(map[int64]bool)
|
||||
for i := range existing_splits {
|
||||
s_map[existing_splits[i].SplitId] = true
|
||||
}
|
||||
|
||||
// Insert splits, updating any pre-existing ones
|
||||
for i := range t.Splits {
|
||||
t.Splits[i].TransactionId = t.TransactionId
|
||||
_, ok := s_map[t.Splits[i].SplitId]
|
||||
if ok {
|
||||
count, err := transaction.Update(t.Splits[i])
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if count != 1 {
|
||||
return errors.New("Updated more than one transaction split")
|
||||
}
|
||||
delete(s_map, t.Splits[i].SplitId)
|
||||
} else {
|
||||
t.Splits[i].SplitId = -1
|
||||
err := transaction.Insert(t.Splits[i])
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
if t.Splits[i].AccountId != -1 {
|
||||
a_map[t.Splits[i].AccountId] = true
|
||||
}
|
||||
}
|
||||
|
||||
// Delete any remaining pre-existing splits
|
||||
for i := range existing_splits {
|
||||
_, ok := s_map[existing_splits[i].SplitId]
|
||||
if existing_splits[i].AccountId != -1 {
|
||||
a_map[existing_splits[i].AccountId] = true
|
||||
}
|
||||
if ok {
|
||||
_, err := transaction.Delete(existing_splits[i])
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Increment versions for all accounts with modified splits
|
||||
var a_ids []int64
|
||||
for id := range a_map {
|
||||
a_ids = append(a_ids, id)
|
||||
}
|
||||
err = incrementAccountVersions(transaction, user, a_ids)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
count, err := transaction.Update(t)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if count != 1 {
|
||||
return errors.New("Updated more than one transaction")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func DeleteTransaction(db *DB, t *Transaction, user *User) error {
|
||||
transaction, err := db.Begin()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
var accountids []int64
|
||||
_, err = transaction.Select(&accountids, "SELECT DISTINCT AccountId FROM splits WHERE TransactionId=? AND AccountId != -1", t.TransactionId)
|
||||
if err != nil {
|
||||
transaction.Rollback()
|
||||
return err
|
||||
}
|
||||
|
||||
_, err = transaction.Exec("DELETE FROM splits WHERE TransactionId=?", t.TransactionId)
|
||||
if err != nil {
|
||||
transaction.Rollback()
|
||||
return err
|
||||
}
|
||||
|
||||
count, err := transaction.Delete(t)
|
||||
if err != nil {
|
||||
transaction.Rollback()
|
||||
return err
|
||||
}
|
||||
if count != 1 {
|
||||
transaction.Rollback()
|
||||
return errors.New("Deleted more than one transaction")
|
||||
}
|
||||
|
||||
err = incrementAccountVersions(transaction, user, accountids)
|
||||
if err != nil {
|
||||
transaction.Rollback()
|
||||
return err
|
||||
}
|
||||
|
||||
err = transaction.Commit()
|
||||
if err != nil {
|
||||
transaction.Rollback()
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func TransactionHandler(w http.ResponseWriter, r *http.Request, db *DB) {
|
||||
user, err := GetUserFromSession(db, r)
|
||||
if err != nil {
|
||||
WriteError(w, 1 /*Not Signed In*/)
|
||||
return
|
||||
}
|
||||
|
||||
if r.Method == "POST" {
|
||||
transaction_json := r.PostFormValue("transaction")
|
||||
if transaction_json == "" {
|
||||
WriteError(w, 3 /*Invalid Request*/)
|
||||
return
|
||||
}
|
||||
|
||||
var transaction Transaction
|
||||
err := transaction.Read(transaction_json)
|
||||
if err != nil {
|
||||
WriteError(w, 3 /*Invalid Request*/)
|
||||
return
|
||||
}
|
||||
transaction.TransactionId = -1
|
||||
transaction.UserId = user.UserId
|
||||
|
||||
sqltx, err := db.Begin()
|
||||
if err != nil {
|
||||
WriteError(w, 999 /*Internal Error*/)
|
||||
log.Print(err)
|
||||
return
|
||||
}
|
||||
|
||||
balanced, err := transaction.Balanced(sqltx)
|
||||
if err != nil {
|
||||
sqltx.Rollback()
|
||||
WriteError(w, 999 /*Internal Error*/)
|
||||
log.Print(err)
|
||||
return
|
||||
}
|
||||
if !transaction.Valid() || !balanced {
|
||||
sqltx.Rollback()
|
||||
WriteError(w, 3 /*Invalid Request*/)
|
||||
return
|
||||
}
|
||||
|
||||
for i := range transaction.Splits {
|
||||
transaction.Splits[i].SplitId = -1
|
||||
_, err := GetAccountTx(sqltx, transaction.Splits[i].AccountId, user.UserId)
|
||||
if err != nil {
|
||||
sqltx.Rollback()
|
||||
WriteError(w, 3 /*Invalid Request*/)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
err = InsertTransactionTx(sqltx, &transaction, user)
|
||||
if err != nil {
|
||||
if _, ok := err.(AccountMissingError); ok {
|
||||
WriteError(w, 3 /*Invalid Request*/)
|
||||
} else {
|
||||
WriteError(w, 999 /*Internal Error*/)
|
||||
log.Print(err)
|
||||
}
|
||||
sqltx.Rollback()
|
||||
return
|
||||
}
|
||||
|
||||
err = sqltx.Commit()
|
||||
if err != nil {
|
||||
sqltx.Rollback()
|
||||
WriteError(w, 999 /*Internal Error*/)
|
||||
log.Print(err)
|
||||
return
|
||||
}
|
||||
|
||||
err = transaction.Write(w)
|
||||
if err != nil {
|
||||
WriteError(w, 999 /*Internal Error*/)
|
||||
log.Print(err)
|
||||
return
|
||||
}
|
||||
} else if r.Method == "GET" {
|
||||
transactionid, err := GetURLID(r.URL.Path)
|
||||
|
||||
if err != nil {
|
||||
//Return all Transactions
|
||||
var al TransactionList
|
||||
transactions, err := GetTransactions(db, user.UserId)
|
||||
if err != nil {
|
||||
WriteError(w, 999 /*Internal Error*/)
|
||||
log.Print(err)
|
||||
return
|
||||
}
|
||||
al.Transactions = transactions
|
||||
err = (&al).Write(w)
|
||||
if err != nil {
|
||||
WriteError(w, 999 /*Internal Error*/)
|
||||
log.Print(err)
|
||||
return
|
||||
}
|
||||
} else {
|
||||
//Return Transaction with this Id
|
||||
transaction, err := GetTransaction(db, transactionid, user.UserId)
|
||||
if err != nil {
|
||||
WriteError(w, 3 /*Invalid Request*/)
|
||||
return
|
||||
}
|
||||
err = transaction.Write(w)
|
||||
if err != nil {
|
||||
WriteError(w, 999 /*Internal Error*/)
|
||||
log.Print(err)
|
||||
return
|
||||
}
|
||||
}
|
||||
} else {
|
||||
transactionid, err := GetURLID(r.URL.Path)
|
||||
if err != nil {
|
||||
WriteError(w, 3 /*Invalid Request*/)
|
||||
return
|
||||
}
|
||||
if r.Method == "PUT" {
|
||||
transaction_json := r.PostFormValue("transaction")
|
||||
if transaction_json == "" {
|
||||
WriteError(w, 3 /*Invalid Request*/)
|
||||
return
|
||||
}
|
||||
|
||||
var transaction Transaction
|
||||
err := transaction.Read(transaction_json)
|
||||
if err != nil || transaction.TransactionId != transactionid {
|
||||
WriteError(w, 3 /*Invalid Request*/)
|
||||
return
|
||||
}
|
||||
transaction.UserId = user.UserId
|
||||
|
||||
sqltx, err := db.Begin()
|
||||
if err != nil {
|
||||
WriteError(w, 999 /*Internal Error*/)
|
||||
log.Print(err)
|
||||
return
|
||||
}
|
||||
|
||||
balanced, err := transaction.Balanced(sqltx)
|
||||
if err != nil {
|
||||
sqltx.Rollback()
|
||||
WriteError(w, 999 /*Internal Error*/)
|
||||
log.Print(err)
|
||||
return
|
||||
}
|
||||
if !transaction.Valid() || !balanced {
|
||||
sqltx.Rollback()
|
||||
WriteError(w, 3 /*Invalid Request*/)
|
||||
return
|
||||
}
|
||||
|
||||
for i := range transaction.Splits {
|
||||
_, err := GetAccountTx(sqltx, transaction.Splits[i].AccountId, user.UserId)
|
||||
if err != nil {
|
||||
sqltx.Rollback()
|
||||
WriteError(w, 3 /*Invalid Request*/)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
err = UpdateTransactionTx(sqltx, &transaction, user)
|
||||
if err != nil {
|
||||
sqltx.Rollback()
|
||||
WriteError(w, 999 /*Internal Error*/)
|
||||
log.Print(err)
|
||||
return
|
||||
}
|
||||
|
||||
err = sqltx.Commit()
|
||||
if err != nil {
|
||||
sqltx.Rollback()
|
||||
WriteError(w, 999 /*Internal Error*/)
|
||||
log.Print(err)
|
||||
return
|
||||
}
|
||||
|
||||
err = transaction.Write(w)
|
||||
if err != nil {
|
||||
WriteError(w, 999 /*Internal Error*/)
|
||||
log.Print(err)
|
||||
return
|
||||
}
|
||||
} else if r.Method == "DELETE" {
|
||||
transactionid, err := GetURLID(r.URL.Path)
|
||||
if err != nil {
|
||||
WriteError(w, 3 /*Invalid Request*/)
|
||||
return
|
||||
}
|
||||
|
||||
transaction, err := GetTransaction(db, transactionid, user.UserId)
|
||||
if err != nil {
|
||||
WriteError(w, 3 /*Invalid Request*/)
|
||||
return
|
||||
}
|
||||
|
||||
err = DeleteTransaction(db, transaction, user)
|
||||
if err != nil {
|
||||
WriteError(w, 999 /*Internal Error*/)
|
||||
log.Print(err)
|
||||
return
|
||||
}
|
||||
|
||||
WriteSuccess(w)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TransactionsBalanceDifference(transaction *gorp.Transaction, accountid int64, transactions []Transaction) (*big.Rat, error) {
|
||||
var pageDifference, tmp big.Rat
|
||||
for i := range transactions {
|
||||
_, err := transaction.Select(&transactions[i].Splits, "SELECT * FROM splits where TransactionId=?", transactions[i].TransactionId)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Sum up the amounts from the splits we're returning so we can return
|
||||
// an ending balance
|
||||
for j := range transactions[i].Splits {
|
||||
if transactions[i].Splits[j].AccountId == accountid {
|
||||
rat_amount, err := GetBigAmount(transactions[i].Splits[j].Amount)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
tmp.Add(&pageDifference, rat_amount)
|
||||
pageDifference.Set(&tmp)
|
||||
}
|
||||
}
|
||||
}
|
||||
return &pageDifference, nil
|
||||
}
|
||||
|
||||
func GetAccountBalance(db *DB, user *User, accountid int64) (*big.Rat, error) {
|
||||
var splits []Split
|
||||
transaction, err := db.Begin()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
sql := "SELECT DISTINCT splits.* FROM splits INNER JOIN transactions ON transactions.TransactionId = splits.TransactionId WHERE splits.AccountId=? AND transactions.UserId=?"
|
||||
_, err = transaction.Select(&splits, sql, accountid, user.UserId)
|
||||
if err != nil {
|
||||
transaction.Rollback()
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var balance, tmp big.Rat
|
||||
for _, s := range splits {
|
||||
rat_amount, err := GetBigAmount(s.Amount)
|
||||
if err != nil {
|
||||
transaction.Rollback()
|
||||
return nil, err
|
||||
}
|
||||
tmp.Add(&balance, rat_amount)
|
||||
balance.Set(&tmp)
|
||||
}
|
||||
|
||||
err = transaction.Commit()
|
||||
if err != nil {
|
||||
transaction.Rollback()
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &balance, nil
|
||||
}
|
||||
|
||||
// Assumes accountid is valid and is owned by the current user
|
||||
func GetAccountBalanceDate(db *DB, user *User, accountid int64, date *time.Time) (*big.Rat, error) {
|
||||
var splits []Split
|
||||
transaction, err := db.Begin()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
sql := "SELECT DISTINCT splits.* FROM splits INNER JOIN transactions ON transactions.TransactionId = splits.TransactionId WHERE splits.AccountId=? AND transactions.UserId=? AND transactions.Date < ?"
|
||||
_, err = transaction.Select(&splits, sql, accountid, user.UserId, date)
|
||||
if err != nil {
|
||||
transaction.Rollback()
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var balance, tmp big.Rat
|
||||
for _, s := range splits {
|
||||
rat_amount, err := GetBigAmount(s.Amount)
|
||||
if err != nil {
|
||||
transaction.Rollback()
|
||||
return nil, err
|
||||
}
|
||||
tmp.Add(&balance, rat_amount)
|
||||
balance.Set(&tmp)
|
||||
}
|
||||
|
||||
err = transaction.Commit()
|
||||
if err != nil {
|
||||
transaction.Rollback()
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &balance, nil
|
||||
}
|
||||
|
||||
func GetAccountBalanceDateRange(db *DB, user *User, accountid int64, begin, end *time.Time) (*big.Rat, error) {
|
||||
var splits []Split
|
||||
transaction, err := db.Begin()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
sql := "SELECT DISTINCT splits.* FROM splits INNER JOIN transactions ON transactions.TransactionId = splits.TransactionId WHERE splits.AccountId=? AND transactions.UserId=? AND transactions.Date >= ? AND transactions.Date < ?"
|
||||
_, err = transaction.Select(&splits, sql, accountid, user.UserId, begin, end)
|
||||
if err != nil {
|
||||
transaction.Rollback()
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var balance, tmp big.Rat
|
||||
for _, s := range splits {
|
||||
rat_amount, err := GetBigAmount(s.Amount)
|
||||
if err != nil {
|
||||
transaction.Rollback()
|
||||
return nil, err
|
||||
}
|
||||
tmp.Add(&balance, rat_amount)
|
||||
balance.Set(&tmp)
|
||||
}
|
||||
|
||||
err = transaction.Commit()
|
||||
if err != nil {
|
||||
transaction.Rollback()
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &balance, nil
|
||||
}
|
||||
|
||||
func GetAccountTransactions(db *DB, user *User, accountid int64, sort string, page uint64, limit uint64) (*AccountTransactionsList, error) {
|
||||
var transactions []Transaction
|
||||
var atl AccountTransactionsList
|
||||
|
||||
transaction, err := db.Begin()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var sqlsort, balanceLimitOffset string
|
||||
var balanceLimitOffsetArg uint64
|
||||
if sort == "date-asc" {
|
||||
sqlsort = " ORDER BY transactions.Date ASC"
|
||||
balanceLimitOffset = " LIMIT ?"
|
||||
balanceLimitOffsetArg = page * limit
|
||||
} else if sort == "date-desc" {
|
||||
numSplits, err := transaction.SelectInt("SELECT count(*) FROM splits")
|
||||
if err != nil {
|
||||
transaction.Rollback()
|
||||
return nil, err
|
||||
}
|
||||
sqlsort = " ORDER BY transactions.Date DESC"
|
||||
balanceLimitOffset = fmt.Sprintf(" LIMIT %d OFFSET ?", numSplits)
|
||||
balanceLimitOffsetArg = (page + 1) * limit
|
||||
}
|
||||
|
||||
var sqloffset string
|
||||
if page > 0 {
|
||||
sqloffset = fmt.Sprintf(" OFFSET %d", page*limit)
|
||||
}
|
||||
|
||||
account, err := GetAccountTx(transaction, accountid, user.UserId)
|
||||
if err != nil {
|
||||
transaction.Rollback()
|
||||
return nil, err
|
||||
}
|
||||
atl.Account = account
|
||||
|
||||
sql := "SELECT DISTINCT transactions.* FROM transactions INNER JOIN splits ON transactions.TransactionId = splits.TransactionId WHERE transactions.UserId=? AND splits.AccountId=?" + sqlsort + " LIMIT ?" + sqloffset
|
||||
_, err = transaction.Select(&transactions, sql, user.UserId, accountid, limit)
|
||||
if err != nil {
|
||||
transaction.Rollback()
|
||||
return nil, err
|
||||
}
|
||||
atl.Transactions = &transactions
|
||||
|
||||
pageDifference, err := TransactionsBalanceDifference(transaction, accountid, transactions)
|
||||
if err != nil {
|
||||
transaction.Rollback()
|
||||
return nil, err
|
||||
}
|
||||
|
||||
count, err := transaction.SelectInt("SELECT count(DISTINCT transactions.TransactionId) FROM transactions INNER JOIN splits ON transactions.TransactionId = splits.TransactionId WHERE transactions.UserId=? AND splits.AccountId=?", user.UserId, accountid)
|
||||
if err != nil {
|
||||
transaction.Rollback()
|
||||
return nil, err
|
||||
}
|
||||
atl.TotalTransactions = count
|
||||
|
||||
security, err := GetSecurityTx(transaction, atl.Account.SecurityId, user.UserId)
|
||||
if err != nil {
|
||||
transaction.Rollback()
|
||||
return nil, err
|
||||
}
|
||||
if security == nil {
|
||||
transaction.Rollback()
|
||||
return nil, errors.New("Security not found")
|
||||
}
|
||||
|
||||
// Sum all the splits for all transaction splits for this account that
|
||||
// occurred before the page we're returning
|
||||
var amounts []string
|
||||
sql = "SELECT splits.Amount FROM splits WHERE splits.AccountId=? AND splits.TransactionId IN (SELECT DISTINCT transactions.TransactionId FROM transactions INNER JOIN splits ON transactions.TransactionId = splits.TransactionId WHERE transactions.UserId=? AND splits.AccountId=?" + sqlsort + balanceLimitOffset + ")"
|
||||
_, err = transaction.Select(&amounts, sql, accountid, user.UserId, accountid, balanceLimitOffsetArg)
|
||||
if err != nil {
|
||||
transaction.Rollback()
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var tmp, balance big.Rat
|
||||
for _, amount := range amounts {
|
||||
rat_amount, err := GetBigAmount(amount)
|
||||
if err != nil {
|
||||
transaction.Rollback()
|
||||
return nil, err
|
||||
}
|
||||
tmp.Add(&balance, rat_amount)
|
||||
balance.Set(&tmp)
|
||||
}
|
||||
atl.BeginningBalance = balance.FloatString(security.Precision)
|
||||
atl.EndingBalance = tmp.Add(&balance, pageDifference).FloatString(security.Precision)
|
||||
|
||||
err = transaction.Commit()
|
||||
if err != nil {
|
||||
transaction.Rollback()
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &atl, nil
|
||||
}
|
||||
|
||||
// Return only those transactions which have at least one split pertaining to
|
||||
// an account
|
||||
func AccountTransactionsHandler(db *DB, w http.ResponseWriter, r *http.Request,
|
||||
user *User, accountid int64) {
|
||||
|
||||
var page uint64 = 0
|
||||
var limit uint64 = 50
|
||||
var sort string = "date-desc"
|
||||
|
||||
query, _ := url.ParseQuery(r.URL.RawQuery)
|
||||
|
||||
pagestring := query.Get("page")
|
||||
if pagestring != "" {
|
||||
p, err := strconv.ParseUint(pagestring, 10, 0)
|
||||
if err != nil {
|
||||
WriteError(w, 3 /*Invalid Request*/)
|
||||
return
|
||||
}
|
||||
page = p
|
||||
}
|
||||
|
||||
limitstring := query.Get("limit")
|
||||
if limitstring != "" {
|
||||
l, err := strconv.ParseUint(limitstring, 10, 0)
|
||||
if err != nil || l > 100 {
|
||||
WriteError(w, 3 /*Invalid Request*/)
|
||||
return
|
||||
}
|
||||
limit = l
|
||||
}
|
||||
|
||||
sortstring := query.Get("sort")
|
||||
if sortstring != "" {
|
||||
if sortstring != "date-asc" && sortstring != "date-desc" {
|
||||
WriteError(w, 3 /*Invalid Request*/)
|
||||
return
|
||||
}
|
||||
sort = sortstring
|
||||
}
|
||||
|
||||
accountTransactions, err := GetAccountTransactions(db, user, accountid, sort, page, limit)
|
||||
if err != nil {
|
||||
WriteError(w, 999 /*Internal Error*/)
|
||||
log.Print(err)
|
||||
return
|
||||
}
|
||||
|
||||
err = accountTransactions.Write(w)
|
||||
if err != nil {
|
||||
WriteError(w, 999 /*Internal Error*/)
|
||||
log.Print(err)
|
||||
return
|
||||
}
|
||||
}
|
291
internal/handlers/users.go
Normal file
291
internal/handlers/users.go
Normal file
@ -0,0 +1,291 @@
|
||||
package handlers
|
||||
|
||||
import (
|
||||
"crypto/sha256"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"gopkg.in/gorp.v1"
|
||||
"io"
|
||||
"log"
|
||||
"net/http"
|
||||
"strings"
|
||||
)
|
||||
|
||||
type User struct {
|
||||
UserId int64
|
||||
DefaultCurrency int64 // SecurityId of default currency, or ISO4217 code for it if creating new user
|
||||
Name string
|
||||
Username string
|
||||
Password string `db:"-"`
|
||||
PasswordHash string `json:"-"`
|
||||
Email string
|
||||
}
|
||||
|
||||
const BogusPassword = "password"
|
||||
|
||||
type UserExistsError struct{}
|
||||
|
||||
func (ueu UserExistsError) Error() string {
|
||||
return "User exists"
|
||||
}
|
||||
|
||||
func (u *User) Write(w http.ResponseWriter) error {
|
||||
enc := json.NewEncoder(w)
|
||||
return enc.Encode(u)
|
||||
}
|
||||
|
||||
func (u *User) Read(json_str string) error {
|
||||
dec := json.NewDecoder(strings.NewReader(json_str))
|
||||
return dec.Decode(u)
|
||||
}
|
||||
|
||||
func (u *User) HashPassword() {
|
||||
password_hasher := sha256.New()
|
||||
io.WriteString(password_hasher, u.Password)
|
||||
u.PasswordHash = fmt.Sprintf("%x", password_hasher.Sum(nil))
|
||||
u.Password = ""
|
||||
}
|
||||
|
||||
func GetUser(db *DB, userid int64) (*User, error) {
|
||||
var u User
|
||||
|
||||
err := db.SelectOne(&u, "SELECT * from users where UserId=?", userid)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &u, nil
|
||||
}
|
||||
|
||||
func GetUserTx(transaction *gorp.Transaction, userid int64) (*User, error) {
|
||||
var u User
|
||||
|
||||
err := transaction.SelectOne(&u, "SELECT * from users where UserId=?", userid)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &u, nil
|
||||
}
|
||||
|
||||
func GetUserByUsername(db *DB, username string) (*User, error) {
|
||||
var u User
|
||||
|
||||
err := db.SelectOne(&u, "SELECT * from users where Username=?", username)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &u, nil
|
||||
}
|
||||
|
||||
func InsertUser(db *DB, u *User) error {
|
||||
transaction, err := db.Begin()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
security_template := FindCurrencyTemplate(u.DefaultCurrency)
|
||||
if security_template == nil {
|
||||
transaction.Rollback()
|
||||
return errors.New("Invalid ISO4217 Default Currency")
|
||||
}
|
||||
|
||||
existing, err := transaction.SelectInt("SELECT count(*) from users where Username=?", u.Username)
|
||||
if err != nil {
|
||||
transaction.Rollback()
|
||||
return err
|
||||
}
|
||||
if existing > 0 {
|
||||
transaction.Rollback()
|
||||
return UserExistsError{}
|
||||
}
|
||||
|
||||
err = transaction.Insert(u)
|
||||
if err != nil {
|
||||
transaction.Rollback()
|
||||
return err
|
||||
}
|
||||
|
||||
// Copy the security template and give it our new UserId
|
||||
var security Security
|
||||
security = *security_template
|
||||
security.UserId = u.UserId
|
||||
|
||||
err = InsertSecurityTx(transaction, &security)
|
||||
if err != nil {
|
||||
transaction.Rollback()
|
||||
return err
|
||||
}
|
||||
|
||||
// Update the user's DefaultCurrency to our new SecurityId
|
||||
u.DefaultCurrency = security.SecurityId
|
||||
count, err := transaction.Update(u)
|
||||
if err != nil {
|
||||
transaction.Rollback()
|
||||
return err
|
||||
} else if count != 1 {
|
||||
transaction.Rollback()
|
||||
return errors.New("Would have updated more than one user")
|
||||
}
|
||||
|
||||
err = transaction.Commit()
|
||||
if err != nil {
|
||||
transaction.Rollback()
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func GetUserFromSession(db *DB, r *http.Request) (*User, error) {
|
||||
s, err := GetSession(db, r)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return GetUser(db, s.UserId)
|
||||
}
|
||||
|
||||
func UpdateUser(db *DB, u *User) error {
|
||||
transaction, err := db.Begin()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
security, err := GetSecurityTx(transaction, u.DefaultCurrency, u.UserId)
|
||||
if err != nil {
|
||||
transaction.Rollback()
|
||||
return err
|
||||
} else if security.UserId != u.UserId || security.SecurityId != u.DefaultCurrency {
|
||||
transaction.Rollback()
|
||||
return errors.New("UserId and DefaultCurrency don't match the fetched security")
|
||||
} else if security.Type != Currency {
|
||||
transaction.Rollback()
|
||||
return errors.New("New DefaultCurrency security is not a currency")
|
||||
}
|
||||
|
||||
count, err := transaction.Update(u)
|
||||
if err != nil {
|
||||
transaction.Rollback()
|
||||
return err
|
||||
} else if count != 1 {
|
||||
transaction.Rollback()
|
||||
return errors.New("Would have updated more than one user")
|
||||
}
|
||||
|
||||
err = transaction.Commit()
|
||||
if err != nil {
|
||||
transaction.Rollback()
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func UserHandler(w http.ResponseWriter, r *http.Request, db *DB) {
|
||||
if r.Method == "POST" {
|
||||
user_json := r.PostFormValue("user")
|
||||
if user_json == "" {
|
||||
WriteError(w, 3 /*Invalid Request*/)
|
||||
return
|
||||
}
|
||||
|
||||
var user User
|
||||
err := user.Read(user_json)
|
||||
if err != nil {
|
||||
WriteError(w, 3 /*Invalid Request*/)
|
||||
return
|
||||
}
|
||||
user.UserId = -1
|
||||
user.HashPassword()
|
||||
|
||||
err = InsertUser(db, &user)
|
||||
if err != nil {
|
||||
if _, ok := err.(UserExistsError); ok {
|
||||
WriteError(w, 4 /*User Exists*/)
|
||||
} else {
|
||||
WriteError(w, 999 /*Internal Error*/)
|
||||
log.Print(err)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
w.WriteHeader(201 /*Created*/)
|
||||
err = user.Write(w)
|
||||
if err != nil {
|
||||
WriteError(w, 999 /*Internal Error*/)
|
||||
log.Print(err)
|
||||
return
|
||||
}
|
||||
} else {
|
||||
user, err := GetUserFromSession(db, r)
|
||||
if err != nil {
|
||||
WriteError(w, 1 /*Not Signed In*/)
|
||||
return
|
||||
}
|
||||
|
||||
userid, err := GetURLID(r.URL.Path)
|
||||
if err != nil {
|
||||
WriteError(w, 3 /*Invalid Request*/)
|
||||
return
|
||||
}
|
||||
|
||||
if userid != user.UserId {
|
||||
WriteError(w, 2 /*Unauthorized Access*/)
|
||||
return
|
||||
}
|
||||
|
||||
if r.Method == "GET" {
|
||||
err = user.Write(w)
|
||||
if err != nil {
|
||||
WriteError(w, 999 /*Internal Error*/)
|
||||
log.Print(err)
|
||||
return
|
||||
}
|
||||
} else if r.Method == "PUT" {
|
||||
user_json := r.PostFormValue("user")
|
||||
if user_json == "" {
|
||||
WriteError(w, 3 /*Invalid Request*/)
|
||||
return
|
||||
}
|
||||
|
||||
// Save old PWHash in case the new password is bogus
|
||||
old_pwhash := user.PasswordHash
|
||||
|
||||
err = user.Read(user_json)
|
||||
if err != nil || user.UserId != userid {
|
||||
WriteError(w, 3 /*Invalid Request*/)
|
||||
return
|
||||
}
|
||||
|
||||
// If the user didn't create a new password, keep their old one
|
||||
if user.Password != BogusPassword {
|
||||
user.HashPassword()
|
||||
} else {
|
||||
user.Password = ""
|
||||
user.PasswordHash = old_pwhash
|
||||
}
|
||||
|
||||
err = UpdateUser(db, user)
|
||||
if err != nil {
|
||||
WriteError(w, 999 /*Internal Error*/)
|
||||
log.Print(err)
|
||||
return
|
||||
}
|
||||
|
||||
err = user.Write(w)
|
||||
if err != nil {
|
||||
WriteError(w, 999 /*Internal Error*/)
|
||||
log.Print(err)
|
||||
return
|
||||
}
|
||||
} else if r.Method == "DELETE" {
|
||||
count, err := db.Delete(&user)
|
||||
if count != 1 || err != nil {
|
||||
WriteError(w, 999 /*Internal Error*/)
|
||||
log.Print(err)
|
||||
return
|
||||
}
|
||||
|
||||
WriteSuccess(w)
|
||||
}
|
||||
}
|
||||
}
|
23
internal/handlers/util.go
Normal file
23
internal/handlers/util.go
Normal file
@ -0,0 +1,23 @@
|
||||
package handlers
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"strings"
|
||||
)
|
||||
|
||||
func GetURLID(url string) (int64, error) {
|
||||
pieces := strings.Split(strings.Trim(url, "/"), "/")
|
||||
return strconv.ParseInt(pieces[len(pieces)-1], 10, 0)
|
||||
}
|
||||
|
||||
func GetURLPieces(url string, format string, a ...interface{}) (int, error) {
|
||||
url = strings.Replace(url, "/", " ", -1)
|
||||
format = strings.Replace(format, "/", " ", -1)
|
||||
return fmt.Sscanf(url, format, a...)
|
||||
}
|
||||
|
||||
func WriteSuccess(w http.ResponseWriter) {
|
||||
fmt.Fprint(w, "{}")
|
||||
}
|
Reference in New Issue
Block a user