diff --git a/db.go b/db.go index eb2cf1b..942e876 100644 --- a/db.go +++ b/db.go @@ -22,6 +22,7 @@ func initDB() *gorp.DbMap { dbmap.AddTableWithName(Security{}, "securities").SetKeys(true, "SecurityId") dbmap.AddTableWithName(Transaction{}, "transactions").SetKeys(true, "TransactionId") dbmap.AddTableWithName(Split{}, "splits").SetKeys(true, "SplitId") + dbmap.AddTableWithName(Price{}, "prices").SetKeys(true, "PriceId") dbmap.AddTableWithName(Report{}, "reports").SetKeys(true, "ReportId") err = dbmap.CreateTablesIfNotExists() diff --git a/docs/lua_reports.md b/docs/lua_reports.md index 4898111..72edf5b 100644 --- a/docs/lua_reports.md +++ b/docs/lua_reports.md @@ -136,9 +136,37 @@ has several fields describing it: * `s.Type` returns an int constant which represents what type of security it is (i.e. stock or currency) +Securities support a ClosestPrice function that allows you to fetch the price of +the current security in a given currency that is closest to the supplied date. +For example, to print the price in the user's default currency for each security +in the user's account: + +``` + default_currency = get_default_currency() + for id, security in pairs(get_securities()) do + price = security.price(default_currency, date.now()) + if price ~= nil then + print(tostring(security) .. ": " security.Symbol .. " " .. price.Value) + else + print("Failed to fetch price for " .. tostring(security)) + end + end +``` + You can also query for an account's default currency using the global `get_default_currency()` function. +### Prices + +Price objects can be queried from Security objects. Price objects contain the +following fields: + +* `p.PriceId` +* `p.Security` returns the security object the price is for +* `p.Currency` returns the currency that the price is in +* `p.Value` returns the price of one unit of 'security' in 'currency', as a + float + ### Dates In order to make it easier to do operations like finding account balances for a diff --git a/gnucash.go b/gnucash.go index 94391cf..1d719c1 100644 --- a/gnucash.go +++ b/gnucash.go @@ -72,6 +72,20 @@ type GnucashDate struct { Date GnucashTime `xml:"http://www.gnucash.org/XML/ts date"` } +type GnucashPrice struct { + Id string `xml:"http://www.gnucash.org/XML/price id"` + Commodity GnucashCommodity `xml:"http://www.gnucash.org/XML/price commodity"` + Currency GnucashCommodity `xml:"http://www.gnucash.org/XML/price currency"` + Date GnucashDate `xml:"http://www.gnucash.org/XML/price time"` + Source string `xml:"http://www.gnucash.org/XML/price source"` + Type string `xml:"http://www.gnucash.org/XML/price type"` + Value string `xml:"http://www.gnucash.org/XML/price value"` +} + +type GnucashPriceDB struct { + Prices []GnucashPrice `xml:"price"` +} + type GnucashAccount struct { Version string `xml:"version,attr"` accountid int64 // Used to map Gnucash guid's to integer ones @@ -105,6 +119,7 @@ type GnucashSplit struct { type GnucashXMLImport struct { XMLName xml.Name `xml:"gnc-v2"` Commodities []GnucashCommodity `xml:"http://www.gnucash.org/XML/gnc book>commodity"` + PriceDB GnucashPriceDB `xml:"http://www.gnucash.org/XML/gnc book>pricedb"` Accounts []GnucashAccount `xml:"http://www.gnucash.org/XML/gnc book>account"` Transactions []GnucashTransaction `xml:"http://www.gnucash.org/XML/gnc book>transaction"` } @@ -113,6 +128,7 @@ type GnucashImport struct { Securities []Security Accounts []Account Transactions []Transaction + Prices []Price } func ImportGnucash(r io.Reader) (*GnucashImport, error) { @@ -141,6 +157,38 @@ func ImportGnucash(r io.Reader) (*GnucashImport, error) { } } + // Create prices, setting security and currency IDs from securityMap + for i := range gncxml.PriceDB.Prices { + price := gncxml.PriceDB.Prices[i] + var p Price + security, ok := securityMap[price.Commodity.Name] + if !ok { + return nil, fmt.Errorf("Unable to find commodity '%s' for price '%s'", price.Commodity.Name, price.Id) + } + currency, ok := securityMap[price.Currency.Name] + if !ok { + return nil, fmt.Errorf("Unable to find currency '%s' for price '%s'", price.Currency.Name, price.Id) + } + if currency.Type != Currency { + return nil, fmt.Errorf("Currency for imported price isn't actually a currency\n") + } + p.PriceId = int64(i + 1) + p.SecurityId = security.SecurityId + p.CurrencyId = currency.SecurityId + p.Date = price.Date.Date.Time + + var r big.Rat + _, ok = r.SetString(price.Value) + if ok { + p.Value = r.FloatString(currency.Precision) + } else { + return nil, fmt.Errorf("Can't set price value: %s", price.Value) + } + + p.RemoteId = "gnucash:" + price.Id + gncimport.Prices = append(gncimport.Prices, p) + } + //find root account, while simultaneously creating map of GUID's to //accounts var rootAccount GnucashAccount @@ -340,6 +388,21 @@ func GnucashImportHandler(w http.ResponseWriter, r *http.Request) { securityMap[securityId] = s.SecurityId } + // Import prices, setting security and currency IDs from securityMap + for _, price := range gnucashImport.Prices { + price.SecurityId = securityMap[price.SecurityId] + price.CurrencyId = securityMap[price.CurrencyId] + price.PriceId = 0 + + err := CreatePriceIfNotExist(sqltransaction, &price) + if err != nil { + sqltransaction.Rollback() + WriteError(w, 6 /*Import Error*/) + log.Print(err) + return + } + } + // Get/create accounts in the database, building a map from Gnucash account // IDs to our internal IDs as we go accountMap := make(map[int64]int64) diff --git a/prices.go b/prices.go new file mode 100644 index 0000000..b067594 --- /dev/null +++ b/prices.go @@ -0,0 +1,122 @@ +package main + +import ( + "fmt" + "github.com/FlashBoys/go-finance" + "gopkg.in/gorp.v1" + "time" +) + +type Price struct { + PriceId int64 + SecurityId int64 + CurrencyId int64 + Date time.Time + Value string // String representation of decimal price of Security in Currency units, suitable for passing to big.Rat.SetString() + RemoteId string // unique ID from source, for detecting duplicates +} + +func InsertPriceTx(transaction *gorp.Transaction, p *Price) error { + err := transaction.Insert(p) + if err != nil { + return err + } + return nil +} + +func CreatePriceIfNotExist(transaction *gorp.Transaction, price *Price) error { + if len(price.RemoteId) == 0 { + // Always create a new price if we can't match on the RemoteId + err := InsertPriceTx(transaction, price) + if err != nil { + return err + } + return nil + } + + var prices []*Price + + _, err := transaction.Select(&prices, "SELECT * from prices where SecurityId=? AND CurrencyId=? AND Date=? AND Value=?", price.SecurityId, price.CurrencyId, price.Date, price.Value) + if err != nil { + return err + } + + if len(prices) > 0 { + return nil // price already exists + } + + err = InsertPriceTx(transaction, price) + if err != nil { + return err + } + return nil +} + +// Return the latest price for security in currency units before date +func GetLatestPrice(transaction *gorp.Transaction, security, currency *Security, date *time.Time) (*Price, error) { + var p Price + err := transaction.SelectOne(&p, "SELECT * from prices where SecurityId=? AND CurrencyId=? AND Date <= ? ORDER BY Date DESC LIMIT 1", security.SecurityId, currency.SecurityId, date) + if err != nil { + return nil, err + } + return &p, nil +} + +// Return the earliest price for security in currency units after date +func GetEarliestPrice(transaction *gorp.Transaction, security, currency *Security, date *time.Time) (*Price, error) { + var p Price + err := transaction.SelectOne(&p, "SELECT * from prices where SecurityId=? AND CurrencyId=? AND Date >= ? ORDER BY Date ASC LIMIT 1", security.SecurityId, currency.SecurityId, date) + if err != nil { + return nil, err + } + return &p, nil +} + +// Return the price for security in currency closest to date +func GetClosestPriceTx(transaction *gorp.Transaction, security, currency *Security, date *time.Time) (*Price, error) { + earliest, _ := GetEarliestPrice(transaction, security, currency, date) + latest, err := GetLatestPrice(transaction, security, currency, date) + + // Return early if either earliest or latest are invalid + if earliest == nil { + return latest, err + } else if err != nil { + return earliest, nil + } + + howlate := earliest.Date.Sub(*date) + howearly := date.Sub(latest.Date) + if howearly < howlate { + return latest, nil + } else { + return earliest, nil + } +} + +func GetClosestPrice(security, currency *Security, date *time.Time) (*Price, error) { + transaction, err := DB.Begin() + if err != nil { + return nil, err + } + + price, err := GetClosestPriceTx(transaction, security, currency, date) + if err != nil { + transaction.Rollback() + return nil, err + } + + err = transaction.Commit() + if err != nil { + transaction.Rollback() + return nil, err + } + + return price, nil +} + +func init() { + q, err := finance.GetQuote("BRK-A") + if err == nil { + fmt.Printf("%+v", q) + } +} diff --git a/prices_lua.go b/prices_lua.go new file mode 100644 index 0000000..9f763c6 --- /dev/null +++ b/prices_lua.go @@ -0,0 +1,91 @@ +package main + +import ( + "github.com/yuin/gopher-lua" +) + +const luaPriceTypeName = "price" + +func luaRegisterPrices(L *lua.LState) { + mt := L.NewTypeMetatable(luaPriceTypeName) + L.SetGlobal("price", mt) + L.SetField(mt, "__index", L.NewFunction(luaPrice__index)) + L.SetField(mt, "__tostring", L.NewFunction(luaPrice__tostring)) + L.SetField(mt, "__metatable", lua.LString("protected")) +} + +func PriceToLua(L *lua.LState, price *Price) *lua.LUserData { + ud := L.NewUserData() + ud.Value = price + L.SetMetatable(ud, L.GetTypeMetatable(luaPriceTypeName)) + return ud +} + +// Checks whether the first lua argument is a *LUserData with *Price and returns this *Price. +func luaCheckPrice(L *lua.LState, n int) *Price { + ud := L.CheckUserData(n) + if price, ok := ud.Value.(*Price); ok { + return price + } + L.ArgError(n, "price expected") + return nil +} + +func luaPrice__index(L *lua.LState) int { + p := luaCheckPrice(L, 1) + field := L.CheckString(2) + + switch field { + case "PriceId", "priceid": + L.Push(lua.LNumber(float64(p.PriceId))) + case "Security", "security": + security_map, err := luaContextGetSecurities(L) + if err != nil { + panic("luaContextGetSecurities couldn't fetch securities") + } + s, ok := security_map[p.SecurityId] + if !ok { + panic("Price's security not found for user") + } + L.Push(SecurityToLua(L, s)) + case "Currency", "currency": + security_map, err := luaContextGetSecurities(L) + if err != nil { + panic("luaContextGetSecurities couldn't fetch securities") + } + c, ok := security_map[p.CurrencyId] + if !ok { + panic("Price's currency not found for user") + } + L.Push(SecurityToLua(L, c)) + case "Value", "value": + amt, err := GetBigAmount(p.Value) + if err != nil { + panic(err) + } + float, _ := amt.Float64() + L.Push(lua.LNumber(float)) + default: + L.ArgError(2, "unexpected price attribute: "+field) + } + + return 1 +} + +func luaPrice__tostring(L *lua.LState) int { + p := luaCheckPrice(L, 1) + + security_map, err := luaContextGetSecurities(L) + if err != nil { + panic("luaContextGetSecurities couldn't fetch securities") + } + s, ok1 := security_map[p.SecurityId] + c, ok2 := security_map[p.CurrencyId] + if !ok1 || !ok2 { + panic("Price's currency or security not found for user") + } + + L.Push(lua.LString(p.Value + " " + c.Symbol + " (" + s.Symbol + ")")) + + return 1 +} diff --git a/reports.go b/reports.go index d84dc9d..328e068 100644 --- a/reports.go +++ b/reports.go @@ -161,6 +161,7 @@ func runReport(user *User, report *Report) (*Tabulation, error) { luaRegisterBalances(L) luaRegisterDates(L) luaRegisterTabulations(L) + luaRegisterPrices(L) err := L.DoString(report.Lua) diff --git a/securities_lua.go b/securities_lua.go index a2e323b..6cd769a 100644 --- a/securities_lua.go +++ b/securities_lua.go @@ -135,6 +135,8 @@ func luaSecurity__index(L *lua.LState) int { L.Push(lua.LNumber(float64(a.Precision))) case "Type", "type": L.Push(lua.LNumber(float64(a.Type))) + case "ClosestPrice", "closestprice": + L.Push(L.NewFunction(luaClosestPrice)) default: L.ArgError(2, "unexpected security attribute: "+field) } @@ -142,6 +144,21 @@ func luaSecurity__index(L *lua.LState) int { return 1 } +func luaClosestPrice(L *lua.LState) int { + s := luaCheckSecurity(L, 1) + c := luaCheckSecurity(L, 2) + date := luaCheckTime(L, 3) + + p, err := GetClosestPrice(s, c, date) + if err != nil { + L.Push(lua.LNil) + } else { + L.Push(PriceToLua(L, p)) + } + + return 1 +} + func luaSecurity__tostring(L *lua.LState) int { s := luaCheckSecurity(L, 1)