1
0
mirror of https://github.com/aclindsa/moneygo.git synced 2024-10-31 16:00:05 -04:00

Merge pull request #36 from aclindsa/fix_amounts

Store currency/security values/prices using big.Rat natively
This commit is contained in:
Aaron Lindsay 2017-12-12 21:16:19 -05:00 committed by GitHub
commit d35893504b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
23 changed files with 704 additions and 232 deletions

View File

@ -59,9 +59,11 @@ script:
- touch $GOPATH/src/github.com/aclindsa/moneygo/internal/handlers/cusip_list.csv - touch $GOPATH/src/github.com/aclindsa/moneygo/internal/handlers/cusip_list.csv
# Build and test MoneyGo # Build and test MoneyGo
- go generate -v github.com/aclindsa/moneygo/internal/handlers - go generate -v github.com/aclindsa/moneygo/internal/handlers
- go test -v -covermode=count -coverpkg github.com/aclindsa/moneygo/internal/config,github.com/aclindsa/moneygo/internal/handlers,github.com/aclindsa/moneygo/internal/models,github.com/aclindsa/moneygo/internal/store,github.com/aclindsa/moneygo/internal/store/db,github.com/aclindsa/moneygo/internal/reports -coverprofile=integration_coverage.out github.com/aclindsa/moneygo/internal/integration - export COVER_PACKAGES="github.com/aclindsa/moneygo/internal/config,github.com/aclindsa/moneygo/internal/handlers,github.com/aclindsa/moneygo/internal/models,github.com/aclindsa/moneygo/internal/store,github.com/aclindsa/moneygo/internal/store/db,github.com/aclindsa/moneygo/internal/reports"
- go test -v -covermode=count -coverpkg github.com/aclindsa/moneygo/internal/config,github.com/aclindsa/moneygo/internal/handlers,github.com/aclindsa/moneygo/internal/models,github.com/aclindsa/moneygo/internal/store,github.com/aclindsa/moneygo/internal/store/db,github.com/aclindsa/moneygo/internal/reports -coverprofile=config_coverage.out github.com/aclindsa/moneygo/internal/config - go test -v -covermode=count -coverpkg $COVER_PACKAGES -coverprofile=integration_coverage.out github.com/aclindsa/moneygo/internal/integration
- go test -v -covermode=count -coverpkg $COVER_PACKAGES -coverprofile=config_coverage.out github.com/aclindsa/moneygo/internal/config
- go test -v -covermode=count -coverpkg $COVER_PACKAGES -coverprofile=models_coverage.out github.com/aclindsa/moneygo/internal/models
# Report the test coverage # Report the test coverage
after_script: after_script:
- $GOPATH/bin/goveralls -coverprofile=integration_coverage.out,config_coverage.out -service=travis-ci -repotoken $COVERALLS_TOKEN - $GOPATH/bin/goveralls -coverprofile=integration_coverage.out,config_coverage.out,models_coverage.out -service=travis-ci -repotoken $COVERALLS_TOKEN

View File

@ -10,7 +10,6 @@ import (
"io" "io"
"log" "log"
"math" "math"
"math/big"
"net/http" "net/http"
"time" "time"
) )
@ -49,7 +48,7 @@ func (gc *GnucashCommodity) UnmarshalXML(d *xml.Decoder, start xml.StartElement)
gc.Precision = template.Precision gc.Precision = template.Precision
} else { } else {
if gxc.Fraction > 0 { if gxc.Fraction > 0 {
gc.Precision = int(math.Ceil(math.Log10(float64(gxc.Fraction)))) gc.Precision = uint64(math.Ceil(math.Log10(float64(gxc.Fraction))))
} else { } else {
gc.Precision = 0 gc.Precision = 0
} }
@ -178,13 +177,14 @@ func ImportGnucash(r io.Reader) (*GnucashImport, error) {
p.CurrencyId = currency.SecurityId p.CurrencyId = currency.SecurityId
p.Date = price.Date.Date.Time p.Date = price.Date.Date.Time
var r big.Rat _, ok = p.Value.SetString(price.Value)
_, ok = r.SetString(price.Value) if !ok {
if ok {
p.Value = r.FloatString(currency.Precision)
} else {
return nil, fmt.Errorf("Can't set price value: %s", price.Value) return nil, fmt.Errorf("Can't set price value: %s", price.Value)
} }
if p.Value.Precision() > currency.Precision {
// TODO we're possibly losing data here... but do we care?
p.Value.Round(currency.Precision)
}
p.RemoteId = "gnucash:" + price.Id p.RemoteId = "gnucash:" + price.Id
gncimport.Prices = append(gncimport.Prices, p) gncimport.Prices = append(gncimport.Prices, p)
@ -293,13 +293,13 @@ func ImportGnucash(r io.Reader) (*GnucashImport, error) {
s.Number = gt.Number s.Number = gt.Number
s.Memo = gs.Memo s.Memo = gs.Memo
var r big.Rat _, ok = s.Amount.SetString(gs.Amount)
_, ok = r.SetString(gs.Amount) if !ok {
if ok {
s.Amount = r.FloatString(security.Precision)
} else {
return nil, fmt.Errorf("Can't set split Amount: %s", gs.Amount) return nil, fmt.Errorf("Can't set split Amount: %s", gs.Amount)
} }
if s.Amount.Precision() > security.Precision {
return nil, fmt.Errorf("Imported price's precision (%d) is greater than the security's (%s)\n", s.Amount.Precision(), security)
}
t.Splits = append(t.Splits, s) t.Splits = append(t.Splits, s)
} }
@ -356,6 +356,7 @@ func GnucashImportHandler(r *http.Request, context *Context) ResponseWriterWrite
} }
if err != nil { if err != nil {
log.Print(err)
return NewError(3 /*Invalid Request*/) return NewError(3 /*Invalid Request*/)
} }

View File

@ -164,7 +164,11 @@ func ofxImportHelper(tx store.Tx, r io.Reader, user *models.User, accountid int6
log.Print(err) log.Print(err)
return NewError(999 /*Internal Error*/) return NewError(999 /*Internal Error*/)
} }
split.Amount = r.FloatString(security.Precision) split.Amount.Rat = *r
if split.Amount.Precision() > security.Precision {
log.Printf("Precision on created imbalance-correction split (%d) greater than the underlying security (%s) allows (%d)", split.Amount.Precision(), security, security.Precision)
return NewError(999 /*Internal Error*/)
}
split.SecurityId = -1 split.SecurityId = -1
split.AccountId = imbalanced_account.AccountId split.AccountId = imbalanced_account.AccountId
transaction.Splits = append(transaction.Splits, split) transaction.Splits = append(transaction.Splits, split)

View File

@ -97,9 +97,12 @@ func (i *OFXImport) AddTransaction(tran *ofxgo.Transaction, account *models.Acco
s1.ImportSplitType = models.ImportAccount s1.ImportSplitType = models.ImportAccount
s2.ImportSplitType = models.ExternalAccount s2.ImportSplitType = models.ExternalAccount
s1.Amount.Rat = *amt
s2.Amount.Rat = *amt.Neg(amt)
security := i.Securities[account.SecurityId-1] security := i.Securities[account.SecurityId-1]
s1.Amount = amt.FloatString(security.Precision) if s1.Amount.Precision() > security.Precision {
s2.Amount = amt.Neg(amt).FloatString(security.Precision) return errors.New("Imported transaction amount is too precise for security")
}
s1.Status = models.Imported s1.Status = models.Imported
s2.Status = models.Imported s2.Status = models.Imported
@ -262,7 +265,7 @@ func (i *OFXImport) GetInvBuyTran(buy *ofxgo.InvBuy, curdef *models.Security, ac
SecurityId: curdef.SecurityId, SecurityId: curdef.SecurityId,
RemoteId: "ofx:" + buy.InvTran.FiTID.String(), RemoteId: "ofx:" + buy.InvTran.FiTID.String(),
Memo: memo + "(commission)", Memo: memo + "(commission)",
Amount: commission.FloatString(curdef.Precision), Amount: models.Amount{commission},
}) })
} }
if num := taxes.Num(); !num.IsInt64() || num.Int64() != 0 { if num := taxes.Num(); !num.IsInt64() || num.Int64() != 0 {
@ -274,7 +277,7 @@ func (i *OFXImport) GetInvBuyTran(buy *ofxgo.InvBuy, curdef *models.Security, ac
SecurityId: curdef.SecurityId, SecurityId: curdef.SecurityId,
RemoteId: "ofx:" + buy.InvTran.FiTID.String(), RemoteId: "ofx:" + buy.InvTran.FiTID.String(),
Memo: memo + "(taxes)", Memo: memo + "(taxes)",
Amount: taxes.FloatString(curdef.Precision), Amount: models.Amount{taxes},
}) })
} }
if num := fees.Num(); !num.IsInt64() || num.Int64() != 0 { if num := fees.Num(); !num.IsInt64() || num.Int64() != 0 {
@ -286,7 +289,7 @@ func (i *OFXImport) GetInvBuyTran(buy *ofxgo.InvBuy, curdef *models.Security, ac
SecurityId: curdef.SecurityId, SecurityId: curdef.SecurityId,
RemoteId: "ofx:" + buy.InvTran.FiTID.String(), RemoteId: "ofx:" + buy.InvTran.FiTID.String(),
Memo: memo + "(fees)", Memo: memo + "(fees)",
Amount: fees.FloatString(curdef.Precision), Amount: models.Amount{fees},
}) })
} }
if num := load.Num(); !num.IsInt64() || num.Int64() != 0 { if num := load.Num(); !num.IsInt64() || num.Int64() != 0 {
@ -298,7 +301,7 @@ func (i *OFXImport) GetInvBuyTran(buy *ofxgo.InvBuy, curdef *models.Security, ac
SecurityId: curdef.SecurityId, SecurityId: curdef.SecurityId,
RemoteId: "ofx:" + buy.InvTran.FiTID.String(), RemoteId: "ofx:" + buy.InvTran.FiTID.String(),
Memo: memo + "(load)", Memo: memo + "(load)",
Amount: load.FloatString(curdef.Precision), Amount: models.Amount{load},
}) })
} }
t.Splits = append(t.Splits, &models.Split{ t.Splits = append(t.Splits, &models.Split{
@ -309,7 +312,7 @@ func (i *OFXImport) GetInvBuyTran(buy *ofxgo.InvBuy, curdef *models.Security, ac
SecurityId: -1, SecurityId: -1,
RemoteId: "ofx:" + buy.InvTran.FiTID.String(), RemoteId: "ofx:" + buy.InvTran.FiTID.String(),
Memo: memo, Memo: memo,
Amount: total.FloatString(curdef.Precision), Amount: models.Amount{total},
}) })
t.Splits = append(t.Splits, &models.Split{ t.Splits = append(t.Splits, &models.Split{
// TODO ReversalFiTID? // TODO ReversalFiTID?
@ -319,7 +322,7 @@ func (i *OFXImport) GetInvBuyTran(buy *ofxgo.InvBuy, curdef *models.Security, ac
SecurityId: curdef.SecurityId, SecurityId: curdef.SecurityId,
RemoteId: "ofx:" + buy.InvTran.FiTID.String(), RemoteId: "ofx:" + buy.InvTran.FiTID.String(),
Memo: memo, Memo: memo,
Amount: tradingTotal.FloatString(curdef.Precision), Amount: models.Amount{tradingTotal},
}) })
var units big.Rat var units big.Rat
@ -332,7 +335,7 @@ func (i *OFXImport) GetInvBuyTran(buy *ofxgo.InvBuy, curdef *models.Security, ac
SecurityId: security.SecurityId, SecurityId: security.SecurityId,
RemoteId: "ofx:" + buy.InvTran.FiTID.String(), RemoteId: "ofx:" + buy.InvTran.FiTID.String(),
Memo: memo, Memo: memo,
Amount: units.FloatString(security.Precision), Amount: models.Amount{units},
}) })
units.Neg(&units) units.Neg(&units)
t.Splits = append(t.Splits, &models.Split{ t.Splits = append(t.Splits, &models.Split{
@ -343,7 +346,7 @@ func (i *OFXImport) GetInvBuyTran(buy *ofxgo.InvBuy, curdef *models.Security, ac
SecurityId: security.SecurityId, SecurityId: security.SecurityId,
RemoteId: "ofx:" + buy.InvTran.FiTID.String(), RemoteId: "ofx:" + buy.InvTran.FiTID.String(),
Memo: memo, Memo: memo,
Amount: units.FloatString(security.Precision), Amount: models.Amount{units},
}) })
return &t, nil return &t, nil
@ -378,7 +381,7 @@ func (i *OFXImport) GetIncomeTran(income *ofxgo.Income, curdef *models.Security,
SecurityId: -1, SecurityId: -1,
RemoteId: "ofx:" + income.InvTran.FiTID.String(), RemoteId: "ofx:" + income.InvTran.FiTID.String(),
Memo: memo, Memo: memo,
Amount: total.FloatString(curdef.Precision), Amount: models.Amount{total},
}) })
total.Neg(&total) total.Neg(&total)
t.Splits = append(t.Splits, &models.Split{ t.Splits = append(t.Splits, &models.Split{
@ -389,7 +392,7 @@ func (i *OFXImport) GetIncomeTran(income *ofxgo.Income, curdef *models.Security,
SecurityId: curdef.SecurityId, SecurityId: curdef.SecurityId,
RemoteId: "ofx:" + income.InvTran.FiTID.String(), RemoteId: "ofx:" + income.InvTran.FiTID.String(),
Memo: memo, Memo: memo,
Amount: total.FloatString(curdef.Precision), Amount: models.Amount{total},
}) })
return &t, nil return &t, nil
@ -423,7 +426,7 @@ func (i *OFXImport) GetInvExpenseTran(expense *ofxgo.InvExpense, curdef *models.
SecurityId: -1, SecurityId: -1,
RemoteId: "ofx:" + expense.InvTran.FiTID.String(), RemoteId: "ofx:" + expense.InvTran.FiTID.String(),
Memo: memo, Memo: memo,
Amount: total.FloatString(curdef.Precision), Amount: models.Amount{total},
}) })
total.Neg(&total) total.Neg(&total)
t.Splits = append(t.Splits, &models.Split{ t.Splits = append(t.Splits, &models.Split{
@ -434,7 +437,7 @@ func (i *OFXImport) GetInvExpenseTran(expense *ofxgo.InvExpense, curdef *models.
SecurityId: curdef.SecurityId, SecurityId: curdef.SecurityId,
RemoteId: "ofx:" + expense.InvTran.FiTID.String(), RemoteId: "ofx:" + expense.InvTran.FiTID.String(),
Memo: memo, Memo: memo,
Amount: total.FloatString(curdef.Precision), Amount: models.Amount{total},
}) })
return &t, nil return &t, nil
@ -462,7 +465,7 @@ func (i *OFXImport) GetMarginInterestTran(marginint *ofxgo.MarginInterest, curde
SecurityId: -1, SecurityId: -1,
RemoteId: "ofx:" + marginint.InvTran.FiTID.String(), RemoteId: "ofx:" + marginint.InvTran.FiTID.String(),
Memo: memo, Memo: memo,
Amount: total.FloatString(curdef.Precision), Amount: models.Amount{total},
}) })
total.Neg(&total) total.Neg(&total)
t.Splits = append(t.Splits, &models.Split{ t.Splits = append(t.Splits, &models.Split{
@ -473,7 +476,7 @@ func (i *OFXImport) GetMarginInterestTran(marginint *ofxgo.MarginInterest, curde
SecurityId: curdef.SecurityId, SecurityId: curdef.SecurityId,
RemoteId: "ofx:" + marginint.InvTran.FiTID.String(), RemoteId: "ofx:" + marginint.InvTran.FiTID.String(),
Memo: memo, Memo: memo,
Amount: total.FloatString(curdef.Precision), Amount: models.Amount{total},
}) })
return &t, nil return &t, nil
@ -526,7 +529,7 @@ func (i *OFXImport) GetReinvestTran(reinvest *ofxgo.Reinvest, curdef *models.Sec
SecurityId: curdef.SecurityId, SecurityId: curdef.SecurityId,
RemoteId: "ofx:" + reinvest.InvTran.FiTID.String(), RemoteId: "ofx:" + reinvest.InvTran.FiTID.String(),
Memo: memo + "(commission)", Memo: memo + "(commission)",
Amount: commission.FloatString(curdef.Precision), Amount: models.Amount{commission},
}) })
} }
if num := taxes.Num(); !num.IsInt64() || num.Int64() != 0 { if num := taxes.Num(); !num.IsInt64() || num.Int64() != 0 {
@ -538,7 +541,7 @@ func (i *OFXImport) GetReinvestTran(reinvest *ofxgo.Reinvest, curdef *models.Sec
SecurityId: curdef.SecurityId, SecurityId: curdef.SecurityId,
RemoteId: "ofx:" + reinvest.InvTran.FiTID.String(), RemoteId: "ofx:" + reinvest.InvTran.FiTID.String(),
Memo: memo + "(taxes)", Memo: memo + "(taxes)",
Amount: taxes.FloatString(curdef.Precision), Amount: models.Amount{taxes},
}) })
} }
if num := fees.Num(); !num.IsInt64() || num.Int64() != 0 { if num := fees.Num(); !num.IsInt64() || num.Int64() != 0 {
@ -550,7 +553,7 @@ func (i *OFXImport) GetReinvestTran(reinvest *ofxgo.Reinvest, curdef *models.Sec
SecurityId: curdef.SecurityId, SecurityId: curdef.SecurityId,
RemoteId: "ofx:" + reinvest.InvTran.FiTID.String(), RemoteId: "ofx:" + reinvest.InvTran.FiTID.String(),
Memo: memo + "(fees)", Memo: memo + "(fees)",
Amount: fees.FloatString(curdef.Precision), Amount: models.Amount{fees},
}) })
} }
if num := load.Num(); !num.IsInt64() || num.Int64() != 0 { if num := load.Num(); !num.IsInt64() || num.Int64() != 0 {
@ -562,7 +565,7 @@ func (i *OFXImport) GetReinvestTran(reinvest *ofxgo.Reinvest, curdef *models.Sec
SecurityId: curdef.SecurityId, SecurityId: curdef.SecurityId,
RemoteId: "ofx:" + reinvest.InvTran.FiTID.String(), RemoteId: "ofx:" + reinvest.InvTran.FiTID.String(),
Memo: memo + "(load)", Memo: memo + "(load)",
Amount: load.FloatString(curdef.Precision), Amount: models.Amount{load},
}) })
} }
t.Splits = append(t.Splits, &models.Split{ t.Splits = append(t.Splits, &models.Split{
@ -573,7 +576,7 @@ func (i *OFXImport) GetReinvestTran(reinvest *ofxgo.Reinvest, curdef *models.Sec
SecurityId: -1, SecurityId: -1,
RemoteId: "ofx:" + reinvest.InvTran.FiTID.String(), RemoteId: "ofx:" + reinvest.InvTran.FiTID.String(),
Memo: memo, Memo: memo,
Amount: total.FloatString(curdef.Precision), Amount: models.Amount{total},
}) })
t.Splits = append(t.Splits, &models.Split{ t.Splits = append(t.Splits, &models.Split{
@ -584,7 +587,7 @@ func (i *OFXImport) GetReinvestTran(reinvest *ofxgo.Reinvest, curdef *models.Sec
SecurityId: curdef.SecurityId, SecurityId: curdef.SecurityId,
RemoteId: "ofx:" + reinvest.InvTran.FiTID.String(), RemoteId: "ofx:" + reinvest.InvTran.FiTID.String(),
Memo: memo, Memo: memo,
Amount: total.FloatString(curdef.Precision), Amount: models.Amount{total},
}) })
total.Neg(&total) total.Neg(&total)
t.Splits = append(t.Splits, &models.Split{ t.Splits = append(t.Splits, &models.Split{
@ -595,7 +598,7 @@ func (i *OFXImport) GetReinvestTran(reinvest *ofxgo.Reinvest, curdef *models.Sec
SecurityId: -1, SecurityId: -1,
RemoteId: "ofx:" + reinvest.InvTran.FiTID.String(), RemoteId: "ofx:" + reinvest.InvTran.FiTID.String(),
Memo: memo, Memo: memo,
Amount: total.FloatString(curdef.Precision), Amount: models.Amount{total},
}) })
t.Splits = append(t.Splits, &models.Split{ t.Splits = append(t.Splits, &models.Split{
// TODO ReversalFiTID? // TODO ReversalFiTID?
@ -605,7 +608,7 @@ func (i *OFXImport) GetReinvestTran(reinvest *ofxgo.Reinvest, curdef *models.Sec
SecurityId: curdef.SecurityId, SecurityId: curdef.SecurityId,
RemoteId: "ofx:" + reinvest.InvTran.FiTID.String(), RemoteId: "ofx:" + reinvest.InvTran.FiTID.String(),
Memo: memo, Memo: memo,
Amount: tradingTotal.FloatString(curdef.Precision), Amount: models.Amount{tradingTotal},
}) })
var units big.Rat var units big.Rat
@ -618,7 +621,7 @@ func (i *OFXImport) GetReinvestTran(reinvest *ofxgo.Reinvest, curdef *models.Sec
SecurityId: security.SecurityId, SecurityId: security.SecurityId,
RemoteId: "ofx:" + reinvest.InvTran.FiTID.String(), RemoteId: "ofx:" + reinvest.InvTran.FiTID.String(),
Memo: memo, Memo: memo,
Amount: units.FloatString(security.Precision), Amount: models.Amount{units},
}) })
units.Neg(&units) units.Neg(&units)
t.Splits = append(t.Splits, &models.Split{ t.Splits = append(t.Splits, &models.Split{
@ -629,7 +632,7 @@ func (i *OFXImport) GetReinvestTran(reinvest *ofxgo.Reinvest, curdef *models.Sec
SecurityId: security.SecurityId, SecurityId: security.SecurityId,
RemoteId: "ofx:" + reinvest.InvTran.FiTID.String(), RemoteId: "ofx:" + reinvest.InvTran.FiTID.String(),
Memo: memo, Memo: memo,
Amount: units.FloatString(security.Precision), Amount: models.Amount{units},
}) })
return &t, nil return &t, nil
@ -663,7 +666,7 @@ func (i *OFXImport) GetRetOfCapTran(retofcap *ofxgo.RetOfCap, curdef *models.Sec
SecurityId: -1, SecurityId: -1,
RemoteId: "ofx:" + retofcap.InvTran.FiTID.String(), RemoteId: "ofx:" + retofcap.InvTran.FiTID.String(),
Memo: memo, Memo: memo,
Amount: total.FloatString(curdef.Precision), Amount: models.Amount{total},
}) })
total.Neg(&total) total.Neg(&total)
t.Splits = append(t.Splits, &models.Split{ t.Splits = append(t.Splits, &models.Split{
@ -674,7 +677,7 @@ func (i *OFXImport) GetRetOfCapTran(retofcap *ofxgo.RetOfCap, curdef *models.Sec
SecurityId: curdef.SecurityId, SecurityId: curdef.SecurityId,
RemoteId: "ofx:" + retofcap.InvTran.FiTID.String(), RemoteId: "ofx:" + retofcap.InvTran.FiTID.String(),
Memo: memo, Memo: memo,
Amount: total.FloatString(curdef.Precision), Amount: models.Amount{total},
}) })
return &t, nil return &t, nil
@ -730,7 +733,7 @@ func (i *OFXImport) GetInvSellTran(sell *ofxgo.InvSell, curdef *models.Security,
SecurityId: curdef.SecurityId, SecurityId: curdef.SecurityId,
RemoteId: "ofx:" + sell.InvTran.FiTID.String(), RemoteId: "ofx:" + sell.InvTran.FiTID.String(),
Memo: memo + "(commission)", Memo: memo + "(commission)",
Amount: commission.FloatString(curdef.Precision), Amount: models.Amount{commission},
}) })
} }
if num := taxes.Num(); !num.IsInt64() || num.Int64() != 0 { if num := taxes.Num(); !num.IsInt64() || num.Int64() != 0 {
@ -742,7 +745,7 @@ func (i *OFXImport) GetInvSellTran(sell *ofxgo.InvSell, curdef *models.Security,
SecurityId: curdef.SecurityId, SecurityId: curdef.SecurityId,
RemoteId: "ofx:" + sell.InvTran.FiTID.String(), RemoteId: "ofx:" + sell.InvTran.FiTID.String(),
Memo: memo + "(taxes)", Memo: memo + "(taxes)",
Amount: taxes.FloatString(curdef.Precision), Amount: models.Amount{taxes},
}) })
} }
if num := fees.Num(); !num.IsInt64() || num.Int64() != 0 { if num := fees.Num(); !num.IsInt64() || num.Int64() != 0 {
@ -754,7 +757,7 @@ func (i *OFXImport) GetInvSellTran(sell *ofxgo.InvSell, curdef *models.Security,
SecurityId: curdef.SecurityId, SecurityId: curdef.SecurityId,
RemoteId: "ofx:" + sell.InvTran.FiTID.String(), RemoteId: "ofx:" + sell.InvTran.FiTID.String(),
Memo: memo + "(fees)", Memo: memo + "(fees)",
Amount: fees.FloatString(curdef.Precision), Amount: models.Amount{fees},
}) })
} }
if num := load.Num(); !num.IsInt64() || num.Int64() != 0 { if num := load.Num(); !num.IsInt64() || num.Int64() != 0 {
@ -766,7 +769,7 @@ func (i *OFXImport) GetInvSellTran(sell *ofxgo.InvSell, curdef *models.Security,
SecurityId: curdef.SecurityId, SecurityId: curdef.SecurityId,
RemoteId: "ofx:" + sell.InvTran.FiTID.String(), RemoteId: "ofx:" + sell.InvTran.FiTID.String(),
Memo: memo + "(load)", Memo: memo + "(load)",
Amount: load.FloatString(curdef.Precision), Amount: models.Amount{load},
}) })
} }
t.Splits = append(t.Splits, &models.Split{ t.Splits = append(t.Splits, &models.Split{
@ -777,7 +780,7 @@ func (i *OFXImport) GetInvSellTran(sell *ofxgo.InvSell, curdef *models.Security,
SecurityId: -1, SecurityId: -1,
RemoteId: "ofx:" + sell.InvTran.FiTID.String(), RemoteId: "ofx:" + sell.InvTran.FiTID.String(),
Memo: memo, Memo: memo,
Amount: total.FloatString(curdef.Precision), Amount: models.Amount{total},
}) })
t.Splits = append(t.Splits, &models.Split{ t.Splits = append(t.Splits, &models.Split{
// TODO ReversalFiTID? // TODO ReversalFiTID?
@ -787,7 +790,7 @@ func (i *OFXImport) GetInvSellTran(sell *ofxgo.InvSell, curdef *models.Security,
SecurityId: curdef.SecurityId, SecurityId: curdef.SecurityId,
RemoteId: "ofx:" + sell.InvTran.FiTID.String(), RemoteId: "ofx:" + sell.InvTran.FiTID.String(),
Memo: memo, Memo: memo,
Amount: tradingTotal.FloatString(curdef.Precision), Amount: models.Amount{tradingTotal},
}) })
var units big.Rat var units big.Rat
@ -800,7 +803,7 @@ func (i *OFXImport) GetInvSellTran(sell *ofxgo.InvSell, curdef *models.Security,
SecurityId: security.SecurityId, SecurityId: security.SecurityId,
RemoteId: "ofx:" + sell.InvTran.FiTID.String(), RemoteId: "ofx:" + sell.InvTran.FiTID.String(),
Memo: memo, Memo: memo,
Amount: units.FloatString(security.Precision), Amount: models.Amount{units},
}) })
units.Neg(&units) units.Neg(&units)
t.Splits = append(t.Splits, &models.Split{ t.Splits = append(t.Splits, &models.Split{
@ -811,7 +814,7 @@ func (i *OFXImport) GetInvSellTran(sell *ofxgo.InvSell, curdef *models.Security,
SecurityId: security.SecurityId, SecurityId: security.SecurityId,
RemoteId: "ofx:" + sell.InvTran.FiTID.String(), RemoteId: "ofx:" + sell.InvTran.FiTID.String(),
Memo: memo, Memo: memo,
Amount: units.FloatString(security.Precision), Amount: models.Amount{units},
}) })
return &t, nil return &t, nil
@ -842,7 +845,7 @@ func (i *OFXImport) GetTransferTran(transfer *ofxgo.Transfer, account *models.Ac
SecurityId: security.SecurityId, SecurityId: security.SecurityId,
RemoteId: "ofx:" + transfer.InvTran.FiTID.String(), RemoteId: "ofx:" + transfer.InvTran.FiTID.String(),
Memo: memo, Memo: memo,
Amount: units.FloatString(security.Precision), Amount: models.Amount{units},
}) })
units.Neg(&units) units.Neg(&units)
t.Splits = append(t.Splits, &models.Split{ t.Splits = append(t.Splits, &models.Split{
@ -853,7 +856,7 @@ func (i *OFXImport) GetTransferTran(transfer *ofxgo.Transfer, account *models.Ac
SecurityId: security.SecurityId, SecurityId: security.SecurityId,
RemoteId: "ofx:" + transfer.InvTran.FiTID.String(), RemoteId: "ofx:" + transfer.InvTran.FiTID.String(),
Memo: memo, Memo: memo,
Amount: units.FloatString(security.Precision), Amount: models.Amount{units},
}) })
return &t, nil return &t, nil

View File

@ -31,9 +31,8 @@ func GetTransactionImbalances(tx store.Tx, t *models.Transaction) (map[int64]big
} }
securityid = account.SecurityId securityid = account.SecurityId
} }
amount, _ := t.Splits[i].GetAmount()
sum := sums[securityid] sum := sums[securityid]
(&sum).Add(&sum, amount) (&sum).Add(&sum, &t.Splits[i].Amount.Rat)
sums[securityid] = sum sums[securityid] = sum
} }
return sums, nil return sums, nil

View File

@ -202,6 +202,19 @@ func uploadFile(client *http.Client, filename, urlsuffix string) error {
return nil return nil
} }
func NewAmount(amt string) models.Amount {
var a models.Amount
if _, ok := a.SetString(amt); !ok {
panic("Unable to call Amount.SetString()")
}
return a
}
func amountsMatch(a models.Amount, amt string) bool {
cmp := NewAmount(amt)
return a.Cmp(&cmp.Rat) == 0
}
func accountBalanceHelper(t *testing.T, client *http.Client, account *models.Account, balance string) { func accountBalanceHelper(t *testing.T, client *http.Client, account *models.Account, balance string) {
t.Helper() t.Helper()
transactions, err := getAccountTransactions(client, account.AccountId, 0, 0, "") transactions, err := getAccountTransactions(client, account.AccountId, 0, 0, "")
@ -209,7 +222,7 @@ func accountBalanceHelper(t *testing.T, client *http.Client, account *models.Acc
t.Fatalf("Couldn't fetch account transactions for '%s': %s\n", account.Name, err) t.Fatalf("Couldn't fetch account transactions for '%s': %s\n", account.Name, err)
} }
if transactions.EndingBalance != balance { if !amountsMatch(transactions.EndingBalance, balance) {
t.Errorf("Expected ending balance for '%s' to be '%s', but found %s\n", account.Name, balance, transactions.EndingBalance) t.Errorf("Expected ending balance for '%s' to be '%s', but found %s\n", account.Name, balance, transactions.EndingBalance)
} }
} }

View File

@ -114,11 +114,11 @@ func TestImportGnucash(t *testing.T) {
} }
var p1787, p2894, p3170 bool var p1787, p2894, p3170 bool
for _, price := range *prices.Prices { for _, price := range *prices.Prices {
if price.CurrencyId == d.securities[0].SecurityId && price.Value == "17.87" { if price.CurrencyId == d.securities[0].SecurityId && amountsMatch(price.Value, "17.87") {
p1787 = true p1787 = true
} else if price.CurrencyId == d.securities[0].SecurityId && price.Value == "28.94" { } else if price.CurrencyId == d.securities[0].SecurityId && amountsMatch(price.Value, "28.94") {
p2894 = true p2894 = true
} else if price.CurrencyId == d.securities[0].SecurityId && price.Value == "31.70" { } else if price.CurrencyId == d.securities[0].SecurityId && amountsMatch(price.Value, "31.70") {
p3170 = true p3170 = true
} }
} }

View File

@ -68,7 +68,7 @@ func TestCreatePrice(t *testing.T) {
if !p.Date.Equal(orig.Date) { if !p.Date.Equal(orig.Date) {
t.Errorf("Date doesn't match") t.Errorf("Date doesn't match")
} }
if p.Value != orig.Value { if p.Value.Cmp(&orig.Value.Rat) != 0 {
t.Errorf("Value doesn't match") t.Errorf("Value doesn't match")
} }
if p.RemoteId != orig.RemoteId { if p.RemoteId != orig.RemoteId {
@ -98,7 +98,7 @@ func TestGetPrice(t *testing.T) {
if !p.Date.Equal(orig.Date) { if !p.Date.Equal(orig.Date) {
t.Errorf("Date doesn't match") t.Errorf("Date doesn't match")
} }
if p.Value != orig.Value { if p.Value.Cmp(&orig.Value.Rat) != 0 {
t.Errorf("Value doesn't match") t.Errorf("Value doesn't match")
} }
if p.RemoteId != orig.RemoteId { if p.RemoteId != orig.RemoteId {
@ -132,7 +132,7 @@ func TestGetPrices(t *testing.T) {
found := false found := false
for _, p := range *pl.Prices { for _, p := range *pl.Prices {
if p.SecurityId == d.securities[orig.SecurityId].SecurityId && p.CurrencyId == d.securities[orig.CurrencyId].SecurityId && p.Date.Equal(orig.Date) && p.Value == orig.Value && p.RemoteId == orig.RemoteId { if p.SecurityId == d.securities[orig.SecurityId].SecurityId && p.CurrencyId == d.securities[orig.CurrencyId].SecurityId && p.Date.Equal(orig.Date) && p.Value.Cmp(&orig.Value.Rat) == 0 && p.RemoteId == orig.RemoteId {
if _, ok := foundIds[p.PriceId]; ok { if _, ok := foundIds[p.PriceId]; ok {
continue continue
} }
@ -146,7 +146,11 @@ func TestGetPrices(t *testing.T) {
} }
} }
if numprices != len(*pl.Prices) { if pl.Prices == nil {
if numprices != 0 {
t.Fatalf("Expected %d prices, received 0", numprices)
}
} else if numprices != len(*pl.Prices) {
t.Fatalf("Expected %d prices, received %d", numprices, len(*pl.Prices)) t.Fatalf("Expected %d prices, received %d", numprices, len(*pl.Prices))
} }
} }
@ -162,7 +166,7 @@ func TestUpdatePrice(t *testing.T) {
tmp := curr.SecurityId tmp := curr.SecurityId
curr.SecurityId = curr.CurrencyId curr.SecurityId = curr.CurrencyId
curr.CurrencyId = tmp curr.CurrencyId = tmp
curr.Value = "5.55" curr.Value = NewAmount("5.55")
curr.Date = time.Date(2019, time.June, 5, 12, 5, 6, 7, time.UTC) curr.Date = time.Date(2019, time.June, 5, 12, 5, 6, 7, time.UTC)
curr.RemoteId = "something" curr.RemoteId = "something"
@ -181,7 +185,7 @@ func TestUpdatePrice(t *testing.T) {
if !p.Date.Equal(curr.Date) { if !p.Date.Equal(curr.Date) {
t.Errorf("Date doesn't match") t.Errorf("Date doesn't match")
} }
if p.Value != curr.Value { if p.Value.Cmp(&curr.Value.Rat) != 0 {
t.Errorf("Value doesn't match") t.Errorf("Value doesn't match")
} }
if p.RemoteId != curr.RemoteId { if p.RemoteId != curr.RemoteId {

View File

@ -213,35 +213,35 @@ var data = []TestData{
SecurityId: 1, SecurityId: 1,
CurrencyId: 0, CurrencyId: 0,
Date: time.Date(2017, time.January, 2, 21, 0, 0, 0, time.UTC), Date: time.Date(2017, time.January, 2, 21, 0, 0, 0, time.UTC),
Value: "225.24", Value: NewAmount("225.24"),
RemoteId: "12387-129831-1238", RemoteId: "12387-129831-1238",
}, },
{ {
SecurityId: 1, SecurityId: 1,
CurrencyId: 0, CurrencyId: 0,
Date: time.Date(2017, time.January, 3, 21, 0, 0, 0, time.UTC), Date: time.Date(2017, time.January, 3, 21, 0, 0, 0, time.UTC),
Value: "226.58", Value: NewAmount("226.58"),
RemoteId: "12387-129831-1239", RemoteId: "12387-129831-1239",
}, },
{ {
SecurityId: 1, SecurityId: 1,
CurrencyId: 0, CurrencyId: 0,
Date: time.Date(2017, time.January, 4, 21, 0, 0, 0, time.UTC), Date: time.Date(2017, time.January, 4, 21, 0, 0, 0, time.UTC),
Value: "226.40", Value: NewAmount("226.40"),
RemoteId: "12387-129831-1240", RemoteId: "12387-129831-1240",
}, },
{ {
SecurityId: 1, SecurityId: 1,
CurrencyId: 0, CurrencyId: 0,
Date: time.Date(2017, time.January, 5, 21, 0, 0, 0, time.UTC), Date: time.Date(2017, time.January, 5, 21, 0, 0, 0, time.UTC),
Value: "227.21", Value: NewAmount("227.21"),
RemoteId: "12387-129831-1241", RemoteId: "12387-129831-1241",
}, },
{ {
SecurityId: 0, SecurityId: 0,
CurrencyId: 3, CurrencyId: 3,
Date: time.Date(2017, time.November, 16, 18, 49, 53, 0, time.UTC), Date: time.Date(2017, time.November, 16, 18, 49, 53, 0, time.UTC),
Value: "0.85", Value: NewAmount("0.85"),
RemoteId: "USDEUR819298714", RemoteId: "USDEUR819298714",
}, },
}, },
@ -313,13 +313,13 @@ var data = []TestData{
Status: models.Reconciled, Status: models.Reconciled,
AccountId: 1, AccountId: 1,
SecurityId: -1, SecurityId: -1,
Amount: "-5.6", Amount: NewAmount("-5.6"),
}, },
{ {
Status: models.Reconciled, Status: models.Reconciled,
AccountId: 3, AccountId: 3,
SecurityId: -1, SecurityId: -1,
Amount: "5.6", Amount: NewAmount("5.6"),
}, },
}, },
}, },
@ -332,13 +332,13 @@ var data = []TestData{
Status: models.Reconciled, Status: models.Reconciled,
AccountId: 1, AccountId: 1,
SecurityId: -1, SecurityId: -1,
Amount: "-81.59", Amount: NewAmount("-81.59"),
}, },
{ {
Status: models.Reconciled, Status: models.Reconciled,
AccountId: 3, AccountId: 3,
SecurityId: -1, SecurityId: -1,
Amount: "81.59", Amount: NewAmount("81.59"),
}, },
}, },
}, },
@ -351,13 +351,13 @@ var data = []TestData{
Status: models.Reconciled, Status: models.Reconciled,
AccountId: 1, AccountId: 1,
SecurityId: -1, SecurityId: -1,
Amount: "-39.99", Amount: NewAmount("-39.99"),
}, },
{ {
Status: models.Entered, Status: models.Entered,
AccountId: 4, AccountId: 4,
SecurityId: -1, SecurityId: -1,
Amount: "39.99", Amount: NewAmount("39.99"),
}, },
}, },
}, },
@ -370,13 +370,13 @@ var data = []TestData{
Status: models.Reconciled, Status: models.Reconciled,
AccountId: 5, AccountId: 5,
SecurityId: -1, SecurityId: -1,
Amount: "-24.56", Amount: NewAmount("-24.56"),
}, },
{ {
Status: models.Entered, Status: models.Entered,
AccountId: 6, AccountId: 6,
SecurityId: -1, SecurityId: -1,
Amount: "24.56", Amount: NewAmount("24.56"),
}, },
}, },
}, },

View File

@ -120,7 +120,7 @@ func ensureTransactionsMatch(t *testing.T, expected, tran *models.Transaction, a
origsplit.RemoteId == origsplit.RemoteId && origsplit.RemoteId == origsplit.RemoteId &&
origsplit.Number == s.Number && origsplit.Number == s.Number &&
origsplit.Memo == s.Memo && origsplit.Memo == s.Memo &&
origsplit.Amount == s.Amount && origsplit.Amount.Cmp(&s.Amount.Rat) == 0 &&
(!matchsplitids || origsplit.SplitId == s.SplitId) { (!matchsplitids || origsplit.SplitId == s.SplitId) {
if _, ok := foundIds[s.SplitId]; ok { if _, ok := foundIds[s.SplitId]; ok {
@ -187,13 +187,13 @@ func TestCreateTransaction(t *testing.T) {
Status: models.Reconciled, Status: models.Reconciled,
AccountId: d.accounts[1].AccountId, AccountId: d.accounts[1].AccountId,
SecurityId: -1, SecurityId: -1,
Amount: "-39.98", Amount: NewAmount("-39.98"),
}, },
{ {
Status: models.Entered, Status: models.Entered,
AccountId: d.accounts[4].AccountId, AccountId: d.accounts[4].AccountId,
SecurityId: -1, SecurityId: -1,
Amount: "39.99", Amount: NewAmount("39.99"),
}, },
}, },
} }
@ -333,7 +333,7 @@ func TestUpdateTransaction(t *testing.T) {
tran.UserId = curr.UserId tran.UserId = curr.UserId
// Make sure we can't create an unbalanced transaction // Make sure we can't create an unbalanced transaction
tran.Splits[len(tran.Splits)-1].Amount = "42" tran.Splits[len(tran.Splits)-1].Amount = NewAmount("42")
_, err = updateTransaction(d.clients[orig.UserId], tran) _, err = updateTransaction(d.clients[orig.UserId], tran)
if err == nil { if err == nil {
t.Fatalf("Expected error updating imbalanced transaction") t.Fatalf("Expected error updating imbalanced transaction")

156
internal/models/amounts.go Normal file
View File

@ -0,0 +1,156 @@
package models
import (
"encoding/json"
"fmt"
"math"
"math/big"
"strings"
)
type Amount struct {
big.Rat
}
type PrecisionError struct {
message string
}
func (p PrecisionError) Error() string {
return p.message
}
// Whole returns the integral portion of the Amount
func (amount Amount) Whole() (int64, error) {
var whole big.Int
whole.Quo(amount.Num(), amount.Denom())
if whole.IsInt64() {
return whole.Int64(), nil
}
return 0, PrecisionError{"integral portion of Amount cannot be represented as an int64"}
}
// Fractional returns the fractional portion of the Amount, multiplied by
// 10^precision
func (amount Amount) Fractional(precision uint64) (int64, error) {
if precision < amount.Precision() {
return 0, PrecisionError{"Fractional portion of Amount cannot be represented with the given precision"}
}
// Reduce the fraction to its simplest form
var r, gcd, d, n big.Int
r.Rem(amount.Num(), amount.Denom())
gcd.GCD(nil, nil, &r, amount.Denom())
if gcd.Sign() != 0 {
n.Quo(&r, &gcd)
d.Quo(amount.Denom(), &gcd)
} else {
n.Set(&r)
d.Set(amount.Denom())
}
// Figure out what we need to multiply the numerator by to get the
// denominator to be 10^precision
var prec, multiplier big.Int
prec.SetUint64(precision)
multiplier.SetInt64(10)
multiplier.Exp(&multiplier, &prec, nil)
multiplier.Quo(&multiplier, &d)
n.Mul(&n, &multiplier)
if n.IsInt64() {
return n.Int64(), nil
}
return 0, fmt.Errorf("Fractional portion of Amount does not fit in int64 with given precision")
}
// FromParts re-assembles an Amount from the results from previous calls to
// Whole and Fractional
func (amount *Amount) FromParts(whole, fractional int64, precision uint64) {
var fracnum, fracdenom, power big.Int
fracnum.SetInt64(fractional)
fracdenom.SetInt64(10)
power.SetUint64(precision)
fracdenom.Exp(&fracdenom, &power, nil)
var fracrat big.Rat
fracrat.SetFrac(&fracnum, &fracdenom)
amount.Rat.SetInt64(whole)
amount.Rat.Add(&amount.Rat, &fracrat)
}
// Round rounds the given Amount to the given precision
func (amount *Amount) Round(precision uint64) {
// This probably isn't exactly the most efficient way to do this...
amount.SetString(amount.FloatString(int(precision)))
}
func (amount Amount) String() string {
return amount.FloatString(int(amount.Precision()))
}
func (amount *Amount) UnmarshalJSON(bytes []byte) error {
var value string
if err := json.Unmarshal(bytes, &value); err != nil {
return err
}
value = strings.TrimSpace(value)
if _, ok := amount.SetString(value); !ok {
return fmt.Errorf("Failed to parse '%s' into Amount", value)
}
return nil
}
func (amount Amount) MarshalJSON() ([]byte, error) {
return json.Marshal(amount.String())
}
// Precision returns the minimum positive integer p such that if you multiplied
// this Amount by 10^p, it would become an integer
func (amount Amount) Precision() uint64 {
if amount.IsInt() || amount.Sign() == 0 {
return 0
}
// Find d, the denominator of the reduced fractional portion of 'amount'
var r, gcd, d big.Int
r.Rem(amount.Num(), amount.Denom())
gcd.GCD(nil, nil, &r, amount.Denom())
if gcd.Sign() != 0 {
d.Quo(amount.Denom(), &gcd)
} else {
d.Set(amount.Denom())
}
d.Abs(&d)
var power, result big.Int
one := big.NewInt(1)
ten := big.NewInt(10)
// Estimate an initial power
if d.IsUint64() {
power.SetInt64(int64(math.Log10(float64(d.Uint64()))))
} else {
// If the simplified denominator wasn't a uint64, its > 10^19
power.SetInt64(19)
}
// If the initial estimate was too high, bring it down
result.Exp(ten, &power, nil)
for result.Cmp(&d) > 0 {
power.Sub(&power, one)
result.Exp(ten, &power, nil)
}
// If it was too low, bring it up
for result.Cmp(&d) < 0 {
power.Add(&power, one)
result.Exp(ten, &power, nil)
}
if !power.IsUint64() {
panic("Unable to represent Amount's precision as a uint64")
}
return power.Uint64()
}

View File

@ -0,0 +1,159 @@
package models_test
import (
"github.com/aclindsa/moneygo/internal/models"
"testing"
)
func expectedPrecision(t *testing.T, amount *models.Amount, precision uint64) {
t.Helper()
if amount.Precision() != precision {
t.Errorf("Expected precision %d for %s, found %d", precision, amount.String(), amount.Precision())
}
}
func TestAmountPrecision(t *testing.T) {
var a models.Amount
a.SetString("1.1928712")
expectedPrecision(t, &a, 7)
a.SetString("0")
expectedPrecision(t, &a, 0)
a.SetString("-0.7")
expectedPrecision(t, &a, 1)
a.SetString("-1.1837281037509137509173049173052130957210361309572047598275398265926351231426357130289523647634895285603247284245928712")
expectedPrecision(t, &a, 118)
a.SetInt64(1050)
expectedPrecision(t, &a, 0)
}
func TestAmountRound(t *testing.T) {
var a models.Amount
tests := []struct {
String string
RoundTo uint64
Expected string
}{
{"0", 5, "0"},
{"929.92928", 2, "929.93"},
{"-105.499999", 4, "-105.5"},
{"0.5111111", 1, "0.5"},
{"0.5111111", 0, "1"},
{"9.876456", 3, "9.876"},
}
for _, test := range tests {
a.SetString(test.String)
a.Round(test.RoundTo)
if a.String() != test.Expected {
t.Errorf("Expected '%s' after Round(%d) to be %s intead of %s\n", test.String, test.RoundTo, test.Expected, a.String())
}
}
}
func TestAmountString(t *testing.T) {
var a models.Amount
for _, s := range []string{
"1.1928712",
"0",
"-0.7",
"-1.1837281037509137509173049173052130957210361309572047598275398265926351231426357130289523647634895285603247284245928712",
"1050",
} {
a.SetString(s)
if s != a.String() {
t.Errorf("Expected '%s', found '%s'", s, a.String())
}
}
a.SetString("+182.27")
if "182.27" != a.String() {
t.Errorf("Expected '182.27', found '%s'", a.String())
}
a.SetString("-0")
if "0" != a.String() {
t.Errorf("Expected '0', found '%s'", a.String())
}
}
func TestWhole(t *testing.T) {
var a models.Amount
tests := []struct {
String string
Whole int64
}{
{"0", 0},
{"-0", 0},
{"181.1293871230", 181},
{"-0.1821", 0},
{"99992737.9", 99992737},
{"-7380.000009", -7380},
{"4108740192740912741", 4108740192740912741},
}
for _, test := range tests {
a.SetString(test.String)
val, err := a.Whole()
if err != nil {
t.Errorf("Unexpected error: %s\n", err)
} else if val != test.Whole {
t.Errorf("Expected '%s'.Whole() to return %d intead of %d\n", test.String, test.Whole, val)
}
}
a.SetString("81367662642302823790328492349823472634926342")
_, err := a.Whole()
if err == nil {
t.Errorf("Expected error for overflowing int64")
}
}
func TestFractional(t *testing.T) {
var a models.Amount
tests := []struct {
String string
Precision uint64
Fractional int64
}{
{"0", 5, 0},
{"181.1293871230", 9, 129387123},
{"181.1293871230", 10, 1293871230},
{"181.1293871230", 15, 129387123000000},
{"1828.37", 7, 3700000},
{"-0.748", 5, -74800},
{"-9", 5, 0},
{"-9.9", 1, -9},
}
for _, test := range tests {
a.SetString(test.String)
val, err := a.Fractional(test.Precision)
if err != nil {
t.Errorf("Unexpected error: %s\n", err)
} else if val != test.Fractional {
t.Errorf("Expected '%s'.Fractional(%d) to return %d intead of %d\n", test.String, test.Precision, test.Fractional, val)
}
}
}
func TestFromParts(t *testing.T) {
var a models.Amount
tests := []struct {
Whole int64
Fractional int64
Precision uint64
Result string
}{
{839, 9080, 4, "839.908"},
{-10, 0, 5, "-10"},
{0, 900, 10, "0.00000009"},
{9128713621, 87272727, 20, "9128713621.00000000000087272727"},
{89, 1, 0, "90"}, // Not sure if this should really be supported, but it is
}
for _, test := range tests {
a.FromParts(test.Whole, test.Fractional, test.Precision)
if a.String() != test.Result {
t.Errorf("Expected Amount.FromParts(%d, %d, %d) to return %s intead of %s\n", test.Whole, test.Fractional, test.Precision, test.Result, a.String())
}
}
}

View File

@ -12,7 +12,7 @@ type Price struct {
SecurityId int64 SecurityId int64
CurrencyId int64 CurrencyId int64
Date time.Time Date time.Time
Value string // String representation of decimal price of Security in Currency units, suitable for passing to big.Rat.SetString() Value Amount // price of Security in Currency units
RemoteId string // unique ID from source, for detecting duplicates RemoteId string // unique ID from source, for detecting duplicates
} }

View File

@ -23,6 +23,9 @@ func GetSecurityType(typestring string) SecurityType {
} }
} }
// MaxPrexision denotes the maximum valid value for Security.Precision
const MaxPrecision uint64 = 15
type Security struct { type Security struct {
SecurityId int64 SecurityId int64
UserId int64 UserId int64
@ -31,7 +34,7 @@ type Security struct {
Symbol string Symbol string
// Number of decimal digits (to the right of the decimal point) this // Number of decimal digits (to the right of the decimal point) this
// security is precise to // security is precise to
Precision int `db:"Preciseness"` Precision uint64 `db:"Preciseness"`
Type SecurityType Type SecurityType
// AlternateId is CUSIP for Type=Stock, ISO4217 for Type=Currency // AlternateId is CUSIP for Type=Stock, ISO4217 for Type=Currency
AlternateId string AlternateId string

View File

@ -2,8 +2,6 @@ package models
import ( import (
"encoding/json" "encoding/json"
"errors"
"math/big"
"net/http" "net/http"
"strings" "strings"
"time" "time"
@ -49,28 +47,11 @@ type Split struct {
RemoteId string // unique ID from server, for detecting duplicates RemoteId string // unique ID from server, for detecting duplicates
Number string // Check or reference number Number string // Check or reference number
Memo string Memo string
Amount string // String representation of decimal, suitable for passing to big.Rat.SetString() Amount Amount
}
func GetBigAmount(amt string) (*big.Rat, error) {
var r big.Rat
_, success := r.SetString(amt)
if !success {
return nil, errors.New("Couldn't convert string amount to big.Rat via SetString()")
}
return &r, nil
}
func (s *Split) GetAmount() (*big.Rat, error) {
return GetBigAmount(s.Amount)
} }
func (s *Split) Valid() bool { func (s *Split) Valid() bool {
if (s.AccountId == -1) == (s.SecurityId == -1) { return (s.AccountId == -1) != (s.SecurityId == -1)
return false
}
_, err := s.GetAmount()
return err == nil
} }
type Transaction struct { type Transaction struct {
@ -89,8 +70,8 @@ type AccountTransactionsList struct {
Account *Account Account *Account
Transactions *[]*Transaction Transactions *[]*Transaction
TotalTransactions int64 TotalTransactions int64
BeginningBalance string BeginningBalance Amount
EndingBalance string EndingBalance Amount
} }
func (t *Transaction) Write(w http.ResponseWriter) error { func (t *Transaction) Write(w http.ResponseWriter) error {

View File

@ -6,7 +6,6 @@ import (
"github.com/aclindsa/moneygo/internal/models" "github.com/aclindsa/moneygo/internal/models"
"github.com/aclindsa/moneygo/internal/store" "github.com/aclindsa/moneygo/internal/store"
"github.com/yuin/gopher-lua" "github.com/yuin/gopher-lua"
"math/big"
"strings" "strings"
) )
@ -147,20 +146,6 @@ func luaAccount__index(L *lua.LState) int {
return 1 return 1
} }
func balanceFromSplits(splits *[]*models.Split) (*big.Rat, error) {
var balance, tmp big.Rat
for _, s := range *splits {
rat_amount, err := models.GetBigAmount(s.Amount)
if err != nil {
return nil, err
}
tmp.Add(&balance, rat_amount)
balance.Set(&tmp)
}
return &balance, nil
}
func luaAccountBalance(L *lua.LState) int { func luaAccountBalance(L *lua.LState) int {
a := luaCheckAccount(L, 1) a := luaCheckAccount(L, 1)
@ -182,28 +167,25 @@ func luaAccountBalance(L *lua.LState) int {
panic("SecurityId not in lua security_map") panic("SecurityId not in lua security_map")
} }
date := luaWeakCheckTime(L, 2) date := luaWeakCheckTime(L, 2)
var splits *[]*models.Split var balance *models.Amount
if date != nil { if date != nil {
end := luaWeakCheckTime(L, 3) end := luaWeakCheckTime(L, 3)
if end != nil { if end != nil {
splits, err = tx.GetAccountSplitsDateRange(user, a.AccountId, date, end) balance, err = tx.GetAccountBalanceDateRange(user, a.AccountId, date, end)
} else { } else {
splits, err = tx.GetAccountSplitsDate(user, a.AccountId, date) balance, err = tx.GetAccountBalanceDate(user, a.AccountId, date)
} }
} else { } else {
splits, err = tx.GetAccountSplits(user, a.AccountId) balance, err = tx.GetAccountBalance(user, a.AccountId)
} }
if err != nil { if err != nil {
panic("Failed to fetch splits for account:" + err.Error()) panic("Failed to fetch balance for account:" + err.Error())
}
rat, err := balanceFromSplits(splits)
if err != nil {
panic("Failed to calculate balance for account:" + err.Error())
} }
b := &Balance{ b := &Balance{
Amount: rat, Amount: *balance,
Security: security, Security: security,
} }
L.Push(BalanceToLua(L, b)) L.Push(BalanceToLua(L, b))
return 1 return 1

View File

@ -3,12 +3,11 @@ package reports
import ( import (
"github.com/aclindsa/moneygo/internal/models" "github.com/aclindsa/moneygo/internal/models"
"github.com/yuin/gopher-lua" "github.com/yuin/gopher-lua"
"math/big"
) )
type Balance struct { type Balance struct {
Security *models.Security Security *models.Security
Amount *big.Rat Amount models.Amount
} }
const luaBalanceTypeName = "balance" const luaBalanceTypeName = "balance"
@ -66,10 +65,8 @@ func luaGetBalanceOperands(L *lua.LState, m int, n int) (*Balance, *Balance) {
} else if bm != nil { } else if bm != nil {
nn := L.CheckNumber(n) nn := L.CheckNumber(n)
var balance Balance var balance Balance
var rat big.Rat
balance.Security = bm.Security balance.Security = bm.Security
balance.Amount = rat.SetFloat64(float64(nn)) if balance.Amount.SetFloat64(float64(nn)) == nil {
if balance.Amount == nil {
L.ArgError(n, "non-finite float invalid for operand to balance arithemetic") L.ArgError(n, "non-finite float invalid for operand to balance arithemetic")
return nil, nil return nil, nil
} }
@ -77,10 +74,8 @@ func luaGetBalanceOperands(L *lua.LState, m int, n int) (*Balance, *Balance) {
} else if bn != nil { } else if bn != nil {
nm := L.CheckNumber(m) nm := L.CheckNumber(m)
var balance Balance var balance Balance
var rat big.Rat
balance.Security = bn.Security balance.Security = bn.Security
balance.Amount = rat.SetFloat64(float64(nm)) if balance.Amount.SetFloat64(float64(nm)) == nil {
if balance.Amount == nil {
L.ArgError(m, "non-finite float invalid for operand to balance arithemetic") L.ArgError(m, "non-finite float invalid for operand to balance arithemetic")
return nil, nil return nil, nil
} }
@ -110,7 +105,7 @@ func luaBalance__index(L *lua.LState) int {
func luaBalance__tostring(L *lua.LState) int { func luaBalance__tostring(L *lua.LState) int {
b := luaCheckBalance(L, 1) b := luaCheckBalance(L, 1)
L.Push(lua.LString(b.Security.Symbol + " " + b.Amount.FloatString(b.Security.Precision))) L.Push(lua.LString(b.Security.Symbol + " " + b.Amount.FloatString(int(b.Security.Precision))))
return 1 return 1
} }
@ -119,7 +114,7 @@ func luaBalance__eq(L *lua.LState) int {
a := luaCheckBalance(L, 1) a := luaCheckBalance(L, 1)
b := luaCheckBalance(L, 2) b := luaCheckBalance(L, 2)
L.Push(lua.LBool(a.Security.SecurityId == b.Security.SecurityId && a.Amount.Cmp(b.Amount) == 0)) L.Push(lua.LBool(a.Security.SecurityId == b.Security.SecurityId && a.Amount.Cmp(&b.Amount.Rat) == 0))
return 1 return 1
} }
@ -131,7 +126,7 @@ func luaBalance__lt(L *lua.LState) int {
L.ArgError(2, "Can't compare balances with different securities") L.ArgError(2, "Can't compare balances with different securities")
} }
L.Push(lua.LBool(a.Amount.Cmp(b.Amount) < 0)) L.Push(lua.LBool(a.Amount.Cmp(&b.Amount.Rat) < 0))
return 1 return 1
} }
@ -143,7 +138,7 @@ func luaBalance__le(L *lua.LState) int {
L.ArgError(2, "Can't compare balances with different securities") L.ArgError(2, "Can't compare balances with different securities")
} }
L.Push(lua.LBool(a.Amount.Cmp(b.Amount) <= 0)) L.Push(lua.LBool(a.Amount.Cmp(&b.Amount.Rat) <= 0))
return 1 return 1
} }
@ -156,9 +151,8 @@ func luaBalance__add(L *lua.LState) int {
} }
var balance Balance var balance Balance
var rat big.Rat
balance.Security = a.Security balance.Security = a.Security
balance.Amount = rat.Add(a.Amount, b.Amount) balance.Amount.Add(&a.Amount.Rat, &b.Amount.Rat)
L.Push(BalanceToLua(L, &balance)) L.Push(BalanceToLua(L, &balance))
return 1 return 1
@ -172,9 +166,8 @@ func luaBalance__sub(L *lua.LState) int {
} }
var balance Balance var balance Balance
var rat big.Rat
balance.Security = a.Security balance.Security = a.Security
balance.Amount = rat.Sub(a.Amount, b.Amount) balance.Amount.Sub(&a.Amount.Rat, &b.Amount.Rat)
L.Push(BalanceToLua(L, &balance)) L.Push(BalanceToLua(L, &balance))
return 1 return 1
@ -188,9 +181,8 @@ func luaBalance__mul(L *lua.LState) int {
} }
var balance Balance var balance Balance
var rat big.Rat
balance.Security = a.Security balance.Security = a.Security
balance.Amount = rat.Mul(a.Amount, b.Amount) balance.Amount.Mul(&a.Amount.Rat, &b.Amount.Rat)
L.Push(BalanceToLua(L, &balance)) L.Push(BalanceToLua(L, &balance))
return 1 return 1
@ -204,9 +196,8 @@ func luaBalance__div(L *lua.LState) int {
} }
var balance Balance var balance Balance
var rat big.Rat
balance.Security = a.Security balance.Security = a.Security
balance.Amount = rat.Quo(a.Amount, b.Amount) balance.Amount.Quo(&a.Amount.Rat, &b.Amount.Rat)
L.Push(BalanceToLua(L, &balance)) L.Push(BalanceToLua(L, &balance))
return 1 return 1
@ -216,9 +207,8 @@ func luaBalance__unm(L *lua.LState) int {
b := luaCheckBalance(L, 1) b := luaCheckBalance(L, 1)
var balance Balance var balance Balance
var rat big.Rat
balance.Security = b.Security balance.Security = b.Security
balance.Amount = rat.Neg(b.Amount) balance.Amount.Neg(&b.Amount.Rat)
L.Push(BalanceToLua(L, &balance)) L.Push(BalanceToLua(L, &balance))
return 1 return 1

View File

@ -60,11 +60,7 @@ func luaPrice__index(L *lua.LState) int {
} }
L.Push(SecurityToLua(L, c)) L.Push(SecurityToLua(L, c))
case "Value", "value": case "Value", "value":
amt, err := models.GetBigAmount(p.Value) float, _ := p.Value.Float64()
if err != nil {
panic(err)
}
float, _ := amt.Float64()
L.Push(lua.LNumber(float)) L.Push(lua.LNumber(float))
default: default:
L.ArgError(2, "unexpected price attribute: "+field) L.ArgError(2, "unexpected price attribute: "+field)
@ -86,7 +82,7 @@ func luaPrice__tostring(L *lua.LState) int {
panic("Price's currency or security not found for user") panic("Price's currency or security not found for user")
} }
L.Push(lua.LString(p.Value + " " + c.Symbol + " (" + s.Symbol + ")")) L.Push(lua.LString(p.Value.String() + " " + c.Symbol + " (" + s.Symbol + ")"))
return 1 return 1
} }

View File

@ -40,10 +40,10 @@ func getDbMap(db *sql.DB, dbtype config.DbType) (*gorp.DbMap, error) {
dbmap.AddTableWithName(models.User{}, "users").SetKeys(true, "UserId") dbmap.AddTableWithName(models.User{}, "users").SetKeys(true, "UserId")
dbmap.AddTableWithName(models.Session{}, "sessions").SetKeys(true, "SessionId") dbmap.AddTableWithName(models.Session{}, "sessions").SetKeys(true, "SessionId")
dbmap.AddTableWithName(models.Security{}, "securities").SetKeys(true, "SecurityId") dbmap.AddTableWithName(models.Security{}, "securities").SetKeys(true, "SecurityId")
dbmap.AddTableWithName(models.Price{}, "prices").SetKeys(true, "PriceId") dbmap.AddTableWithName(Price{}, "prices").SetKeys(true, "PriceId")
dbmap.AddTableWithName(models.Account{}, "accounts").SetKeys(true, "AccountId") dbmap.AddTableWithName(models.Account{}, "accounts").SetKeys(true, "AccountId")
dbmap.AddTableWithName(models.Transaction{}, "transactions").SetKeys(true, "TransactionId") dbmap.AddTableWithName(models.Transaction{}, "transactions").SetKeys(true, "TransactionId")
dbmap.AddTableWithName(models.Split{}, "splits").SetKeys(true, "SplitId") dbmap.AddTableWithName(Split{}, "splits").SetKeys(true, "SplitId")
rtable := dbmap.AddTableWithName(models.Report{}, "reports").SetKeys(true, "ReportId") rtable := dbmap.AddTableWithName(models.Report{}, "reports").SetKeys(true, "ReportId")
rtable.ColMap("Lua").SetMaxSize(models.LuaMaxLength + luaMaxLengthBuffer) rtable.ColMap("Lua").SetMaxSize(models.LuaMaxLength + luaMaxLengthBuffer)

View File

@ -6,73 +6,150 @@ import (
"time" "time"
) )
// Price is a mirror of models.Price with the Value broken out into whole and
// fractional components
type Price struct {
PriceId int64
SecurityId int64
CurrencyId int64
Date time.Time
WholeValue int64
FractionalValue int64
RemoteId string // unique ID from source, for detecting duplicates
}
func NewPrice(p *models.Price) (*Price, error) {
whole, err := p.Value.Whole()
if err != nil {
return nil, err
}
fractional, err := p.Value.Fractional(MaxPrecision)
if err != nil {
return nil, err
}
return &Price{
PriceId: p.PriceId,
SecurityId: p.SecurityId,
CurrencyId: p.CurrencyId,
Date: p.Date,
WholeValue: whole,
FractionalValue: fractional,
RemoteId: p.RemoteId,
}, nil
}
func (p Price) Price() *models.Price {
price := &models.Price{
PriceId: p.PriceId,
SecurityId: p.SecurityId,
CurrencyId: p.CurrencyId,
Date: p.Date,
RemoteId: p.RemoteId,
}
price.Value.FromParts(p.WholeValue, p.FractionalValue, MaxPrecision)
return price
}
func (tx *Tx) PriceExists(price *models.Price) (bool, error) { func (tx *Tx) PriceExists(price *models.Price) (bool, error) {
var prices []*models.Price p, err := NewPrice(price)
_, err := tx.Select(&prices, "SELECT * from prices where SecurityId=? AND CurrencyId=? AND Date=? AND Value=?", price.SecurityId, price.CurrencyId, price.Date, price.Value) if err != nil {
return false, err
}
var prices []*Price
_, err = tx.Select(&prices, "SELECT * from prices where SecurityId=? AND CurrencyId=? AND Date=? AND WholeValue=? AND FractionalValue=?", p.SecurityId, p.CurrencyId, p.Date, p.WholeValue, p.FractionalValue)
return len(prices) > 0, err return len(prices) > 0, err
} }
func (tx *Tx) InsertPrice(price *models.Price) error { func (tx *Tx) InsertPrice(price *models.Price) error {
return tx.Insert(price) p, err := NewPrice(price)
if err != nil {
return err
}
err = tx.Insert(p)
if err != nil {
return err
}
*price = *p.Price()
return nil
} }
func (tx *Tx) GetPrice(priceid, securityid int64) (*models.Price, error) { func (tx *Tx) GetPrice(priceid, securityid int64) (*models.Price, error) {
var price models.Price var price Price
err := tx.SelectOne(&price, "SELECT * from prices where PriceId=? AND SecurityId=?", priceid, securityid) err := tx.SelectOne(&price, "SELECT * from prices where PriceId=? AND SecurityId=?", priceid, securityid)
if err != nil { if err != nil {
return nil, err return nil, err
} }
return &price, nil return price.Price(), nil
} }
func (tx *Tx) GetPrices(securityid int64) (*[]*models.Price, error) { func (tx *Tx) GetPrices(securityid int64) (*[]*models.Price, error) {
var prices []*models.Price var prices []*Price
var modelprices []*models.Price
_, err := tx.Select(&prices, "SELECT * from prices where SecurityId=?", securityid) _, err := tx.Select(&prices, "SELECT * from prices where SecurityId=?", securityid)
if err != nil { if err != nil {
return nil, err return nil, err
} }
return &prices, nil
for _, p := range prices {
modelprices = append(modelprices, p.Price())
}
return &modelprices, nil
} }
// Return the latest price for security in currency units before date // Return the latest price for security in currency units before date
func (tx *Tx) GetLatestPrice(security, currency *models.Security, date *time.Time) (*models.Price, error) { func (tx *Tx) GetLatestPrice(security, currency *models.Security, date *time.Time) (*models.Price, error) {
var price models.Price var price Price
err := tx.SelectOne(&price, "SELECT * from prices where SecurityId=? AND CurrencyId=? AND Date <= ? ORDER BY Date DESC LIMIT 1", security.SecurityId, currency.SecurityId, date) err := tx.SelectOne(&price, "SELECT * from prices where SecurityId=? AND CurrencyId=? AND Date <= ? ORDER BY Date DESC LIMIT 1", security.SecurityId, currency.SecurityId, date)
if err != nil { if err != nil {
return nil, err return nil, err
} }
return &price, nil return price.Price(), nil
} }
// Return the earliest price for security in currency units after date // Return the earliest price for security in currency units after date
func (tx *Tx) GetEarliestPrice(security, currency *models.Security, date *time.Time) (*models.Price, error) { func (tx *Tx) GetEarliestPrice(security, currency *models.Security, date *time.Time) (*models.Price, error) {
var price models.Price var price Price
err := tx.SelectOne(&price, "SELECT * from prices where SecurityId=? AND CurrencyId=? AND Date >= ? ORDER BY Date ASC LIMIT 1", security.SecurityId, currency.SecurityId, date) err := tx.SelectOne(&price, "SELECT * from prices where SecurityId=? AND CurrencyId=? AND Date >= ? ORDER BY Date ASC LIMIT 1", security.SecurityId, currency.SecurityId, date)
if err != nil { if err != nil {
return nil, err return nil, err
} }
return &price, nil return price.Price(), nil
} }
func (tx *Tx) UpdatePrice(price *models.Price) error { func (tx *Tx) UpdatePrice(price *models.Price) error {
count, err := tx.Update(price) p, err := NewPrice(price)
if err != nil {
return err
}
count, err := tx.Update(p)
if err != nil { if err != nil {
return err return err
} }
if count != 1 { if count != 1 {
return fmt.Errorf("Expected to update 1 price, was going to update %d", count) return fmt.Errorf("Expected to update 1 price, was going to update %d", count)
} }
*price = *p.Price()
return nil return nil
} }
func (tx *Tx) DeletePrice(price *models.Price) error { func (tx *Tx) DeletePrice(price *models.Price) error {
count, err := tx.Delete(price) p, err := NewPrice(price)
if err != nil {
return err
}
count, err := tx.Delete(p)
if err != nil { if err != nil {
return err return err
} }
if count != 1 { if count != 1 {
return fmt.Errorf("Expected to delete 1 price, was going to delete %d", count) return fmt.Errorf("Expected to delete 1 price, was going to delete %d", count)
} }
*price = *p.Price()
return nil return nil
} }

View File

@ -6,6 +6,17 @@ import (
"github.com/aclindsa/moneygo/internal/store" "github.com/aclindsa/moneygo/internal/store"
) )
// MaxPrexision denotes the maximum valid value for models.Security.Precision.
// This constant is used when storing amounts in securities into the database,
// so it must not be changed without appropriately migrating the database.
const MaxPrecision uint64 = 15
func init() {
if MaxPrecision < models.MaxPrecision {
panic("db.MaxPrecision must be >= models.MaxPrecision")
}
}
func (tx *Tx) GetSecurity(securityid int64, userid int64) (*models.Security, error) { func (tx *Tx) GetSecurity(securityid int64, userid int64) (*models.Security, error) {
var s models.Security var s models.Security

View File

@ -9,6 +9,71 @@ import (
"time" "time"
) )
// Split is a mirror of models.Split with the Amount broken out into whole and
// fractional components
type Split struct {
SplitId int64
TransactionId int64
Status int64
ImportSplitType int64
// One of AccountId and SecurityId must be -1
// In normal splits, AccountId will be valid and SecurityId will be -1. The
// only case where this is reversed is for transactions that have been
// imported and not yet associated with an account.
AccountId int64
SecurityId int64
RemoteId string // unique ID from server, for detecting duplicates
Number string // Check or reference number
Memo string
// Amount.Whole and Amount.Fractional(MaxPrecision)
WholeAmount int64
FractionalAmount int64
}
func NewSplit(s *models.Split) (*Split, error) {
whole, err := s.Amount.Whole()
if err != nil {
return nil, err
}
fractional, err := s.Amount.Fractional(MaxPrecision)
if err != nil {
return nil, err
}
return &Split{
SplitId: s.SplitId,
TransactionId: s.TransactionId,
Status: s.Status,
ImportSplitType: s.ImportSplitType,
AccountId: s.AccountId,
SecurityId: s.SecurityId,
RemoteId: s.RemoteId,
Number: s.Number,
Memo: s.Memo,
WholeAmount: whole,
FractionalAmount: fractional,
}, nil
}
func (s Split) Split() *models.Split {
split := &models.Split{
SplitId: s.SplitId,
TransactionId: s.TransactionId,
Status: s.Status,
ImportSplitType: s.ImportSplitType,
AccountId: s.AccountId,
SecurityId: s.SecurityId,
RemoteId: s.RemoteId,
Number: s.Number,
Memo: s.Memo,
}
split.Amount.FromParts(s.WholeAmount, s.FractionalAmount, MaxPrecision)
return split
}
func (tx *Tx) incrementAccountVersions(user *models.User, accountids []int64) error { func (tx *Tx) incrementAccountVersions(user *models.User, accountids []int64) error {
for i := range accountids { for i := range accountids {
account, err := tx.GetAccount(accountids[i], user.UserId) account, err := tx.GetAccount(accountids[i], user.UserId)
@ -68,10 +133,15 @@ func (tx *Tx) InsertTransaction(t *models.Transaction, user *models.User) error
for i := range t.Splits { for i := range t.Splits {
t.Splits[i].TransactionId = t.TransactionId t.Splits[i].TransactionId = t.TransactionId
t.Splits[i].SplitId = -1 t.Splits[i].SplitId = -1
err = tx.Insert(t.Splits[i]) s, err := NewSplit(t.Splits[i])
if err != nil { if err != nil {
return err return err
} }
err = tx.Insert(s)
if err != nil {
return err
}
*t.Splits[i] = *s.Split()
} }
return nil return nil
@ -84,17 +154,22 @@ func (tx *Tx) SplitExists(s *models.Split) (bool, error) {
func (tx *Tx) GetTransaction(transactionid int64, userid int64) (*models.Transaction, error) { func (tx *Tx) GetTransaction(transactionid int64, userid int64) (*models.Transaction, error) {
var t models.Transaction var t models.Transaction
var splits []*Split
err := tx.SelectOne(&t, "SELECT * from transactions where UserId=? AND TransactionId=?", userid, transactionid) err := tx.SelectOne(&t, "SELECT * from transactions where UserId=? AND TransactionId=?", userid, transactionid)
if err != nil { if err != nil {
return nil, err return nil, err
} }
_, err = tx.Select(&t.Splits, "SELECT * from splits where TransactionId=?", transactionid) _, err = tx.Select(&splits, "SELECT * from splits where TransactionId=?", transactionid)
if err != nil { if err != nil {
return nil, err return nil, err
} }
for _, split := range splits {
t.Splits = append(t.Splits, split.Split())
}
return &t, nil return &t, nil
} }
@ -107,17 +182,21 @@ func (tx *Tx) GetTransactions(userid int64) (*[]*models.Transaction, error) {
} }
for i := range transactions { for i := range transactions {
_, err := tx.Select(&transactions[i].Splits, "SELECT * from splits where TransactionId=?", transactions[i].TransactionId) var splits []*Split
_, err := tx.Select(&splits, "SELECT * from splits where TransactionId=?", transactions[i].TransactionId)
if err != nil { if err != nil {
return nil, err return nil, err
} }
for _, split := range splits {
transactions[i].Splits = append(transactions[i].Splits, split.Split())
}
} }
return &transactions, nil return &transactions, nil
} }
func (tx *Tx) UpdateTransaction(t *models.Transaction, user *models.User) error { func (tx *Tx) UpdateTransaction(t *models.Transaction, user *models.User) error {
var existing_splits []*models.Split var existing_splits []*Split
_, err := tx.Select(&existing_splits, "SELECT * from splits where TransactionId=?", t.TransactionId) _, err := tx.Select(&existing_splits, "SELECT * from splits where TransactionId=?", t.TransactionId)
if err != nil { if err != nil {
@ -136,25 +215,30 @@ func (tx *Tx) UpdateTransaction(t *models.Transaction, user *models.User) error
// Insert splits, updating any pre-existing ones // Insert splits, updating any pre-existing ones
for i := range t.Splits { for i := range t.Splits {
t.Splits[i].TransactionId = t.TransactionId t.Splits[i].TransactionId = t.TransactionId
_, ok := s_map[t.Splits[i].SplitId] s, err := NewSplit(t.Splits[i])
if err != nil {
return err
}
_, ok := s_map[s.SplitId]
if ok { if ok {
count, err := tx.Update(t.Splits[i]) count, err := tx.Update(s)
if err != nil { if err != nil {
return err return err
} }
if count > 1 { if count > 1 {
return fmt.Errorf("Updated %d transaction splits while attempting to update only 1", count) return fmt.Errorf("Updated %d transaction splits while attempting to update only 1", count)
} }
delete(s_map, t.Splits[i].SplitId) delete(s_map, s.SplitId)
} else { } else {
t.Splits[i].SplitId = -1 s.SplitId = -1
err := tx.Insert(t.Splits[i]) err := tx.Insert(s)
if err != nil { if err != nil {
return err return err
} }
} }
*t.Splits[i] = *s.Split()
if t.Splits[i].AccountId != -1 { if t.Splits[i].AccountId != -1 {
a_map[t.Splits[i].AccountId] = true a_map[s.AccountId] = true
} }
} }
@ -221,58 +305,58 @@ func (tx *Tx) DeleteTransaction(t *models.Transaction, user *models.User) error
return nil return nil
} }
func (tx *Tx) GetAccountSplits(user *models.User, accountid int64) (*[]*models.Split, error) {
var splits []*models.Split
sql := "SELECT DISTINCT splits.* FROM splits INNER JOIN transactions ON transactions.TransactionId = splits.TransactionId WHERE splits.AccountId=? AND transactions.UserId=?"
_, err := tx.Select(&splits, sql, accountid, user.UserId)
if err != nil {
return nil, err
}
return &splits, nil
}
// Assumes accountid is valid and is owned by the current user // Assumes accountid is valid and is owned by the current user
func (tx *Tx) GetAccountSplitsDate(user *models.User, accountid int64, date *time.Time) (*[]*models.Split, error) { func (tx *Tx) getAccountBalance(xtrasql string, args ...interface{}) (*models.Amount, error) {
var splits []*models.Split var balance models.Amount
sql := "SELECT DISTINCT splits.* FROM splits INNER JOIN transactions ON transactions.TransactionId = splits.TransactionId WHERE splits.AccountId=? AND transactions.UserId=? AND transactions.Date < ?" sql := "FROM splits INNER JOIN transactions ON transactions.TransactionId = splits.TransactionId WHERE splits.AccountId=? AND transactions.UserId=?" + xtrasql
_, err := tx.Select(&splits, sql, accountid, user.UserId, date) count, err := tx.SelectInt("SELECT splits.SplitId "+sql+" LIMIT 1", args...)
if err != nil { if err != nil {
return nil, err return nil, err
} }
return &splits, err if count > 0 {
type bal struct {
Whole, Fractional int64
}
var b bal
err := tx.SelectOne(&b, "SELECT sum(splits.WholeAmount) AS Whole, sum(splits.FractionalAmount) AS Fractional "+sql, args...)
if err != nil {
return nil, err
}
balance.FromParts(b.Whole, b.Fractional, MaxPrecision)
}
return &balance, nil
} }
func (tx *Tx) GetAccountSplitsDateRange(user *models.User, accountid int64, begin, end *time.Time) (*[]*models.Split, error) { func (tx *Tx) GetAccountBalance(user *models.User, accountid int64) (*models.Amount, error) {
var splits []*models.Split return tx.getAccountBalance("", accountid, user.UserId)
}
sql := "SELECT DISTINCT splits.* FROM splits INNER JOIN transactions ON transactions.TransactionId = splits.TransactionId WHERE splits.AccountId=? AND transactions.UserId=? AND transactions.Date >= ? AND transactions.Date < ?" func (tx *Tx) GetAccountBalanceDate(user *models.User, accountid int64, date *time.Time) (*models.Amount, error) {
_, err := tx.Select(&splits, sql, accountid, user.UserId, begin, end) return tx.getAccountBalance(" AND transactions.date < ?", accountid, user.UserId, date)
if err != nil { }
return nil, err
} func (tx *Tx) GetAccountBalanceDateRange(user *models.User, accountid int64, begin, end *time.Time) (*models.Amount, error) {
return &splits, nil return tx.getAccountBalance(" AND transactions.date >= ? AND transactions.Date < ?", accountid, user.UserId, begin, end)
} }
func (tx *Tx) transactionsBalanceDifference(accountid int64, transactions []*models.Transaction) (*big.Rat, error) { func (tx *Tx) transactionsBalanceDifference(accountid int64, transactions []*models.Transaction) (*big.Rat, error) {
var pageDifference, tmp big.Rat var pageDifference big.Rat
for i := range transactions { for i := range transactions {
_, err := tx.Select(&transactions[i].Splits, "SELECT * FROM splits where TransactionId=?", transactions[i].TransactionId) var splits []*Split
_, err := tx.Select(&splits, "SELECT * FROM splits where TransactionId=?", transactions[i].TransactionId)
if err != nil { if err != nil {
return nil, err return nil, err
} }
// Sum up the amounts from the splits we're returning so we can return // Sum up the amounts from the splits we're returning so we can return
// an ending balance // an ending balance
for j := range transactions[i].Splits { for j, s := range splits {
transactions[i].Splits = append(transactions[i].Splits, s.Split())
if transactions[i].Splits[j].AccountId == accountid { if transactions[i].Splits[j].AccountId == accountid {
rat_amount, err := models.GetBigAmount(transactions[i].Splits[j].Amount) pageDifference.Add(&pageDifference, &transactions[i].Splits[j].Amount.Rat)
if err != nil {
return nil, err
}
tmp.Add(&pageDifference, rat_amount)
pageDifference.Set(&tmp)
} }
} }
} }
@ -338,24 +422,31 @@ func (tx *Tx) GetAccountTransactions(user *models.User, accountid int64, sort st
// Sum all the splits for all transaction splits for this account that // Sum all the splits for all transaction splits for this account that
// occurred before the page we're returning // occurred before the page we're returning
var amounts []string sql = "FROM splits AS s INNER JOIN (SELECT DISTINCT transactions.Date, transactions.TransactionId FROM transactions INNER JOIN splits ON transactions.TransactionId = splits.TransactionId WHERE transactions.UserId=? AND splits.AccountId=?" + sqlsort + balanceLimitOffset + ") as t ON s.TransactionId = t.TransactionId WHERE s.AccountId=?"
sql = "SELECT s.Amount FROM splits AS s INNER JOIN (SELECT DISTINCT transactions.Date, transactions.TransactionId FROM transactions INNER JOIN splits ON transactions.TransactionId = splits.TransactionId WHERE transactions.UserId=? AND splits.AccountId=?" + sqlsort + balanceLimitOffset + ") as t ON s.TransactionId = t.TransactionId WHERE s.AccountId=?" count, err = tx.SelectInt("SELECT count(*) "+sql, user.UserId, accountid, balanceLimitOffsetArg, accountid)
_, err = tx.Select(&amounts, sql, user.UserId, accountid, balanceLimitOffsetArg, accountid)
if err != nil { if err != nil {
return nil, err return nil, err
} }
var tmp, balance big.Rat var balance models.Amount
for _, amount := range amounts {
rat_amount, err := models.GetBigAmount(amount) // Don't attempt to 'sum()' the splits if none exist, because it is
// supposed to return null/nil in this case, which makes gorp angry since
// we're using SelectInt()
if count > 0 {
whole, err := tx.SelectInt("SELECT sum(s.WholeAmount) "+sql, user.UserId, accountid, balanceLimitOffsetArg, accountid)
if err != nil { if err != nil {
return nil, err return nil, err
} }
tmp.Add(&balance, rat_amount) fractional, err := tx.SelectInt("SELECT sum(s.FractionalAmount) "+sql, user.UserId, accountid, balanceLimitOffsetArg, accountid)
balance.Set(&tmp) if err != nil {
return nil, err
} }
atl.BeginningBalance = balance.FloatString(security.Precision) balance.FromParts(whole, fractional, MaxPrecision)
atl.EndingBalance = tmp.Add(&balance, pageDifference).FloatString(security.Precision) }
atl.BeginningBalance = balance
atl.EndingBalance.Rat.Add(&balance.Rat, pageDifference)
return &atl, nil return &atl, nil
} }

View File

@ -89,9 +89,9 @@ type TransactionStore interface {
GetTransactions(userid int64) (*[]*models.Transaction, error) GetTransactions(userid int64) (*[]*models.Transaction, error)
UpdateTransaction(t *models.Transaction, user *models.User) error UpdateTransaction(t *models.Transaction, user *models.User) error
DeleteTransaction(t *models.Transaction, user *models.User) error DeleteTransaction(t *models.Transaction, user *models.User) error
GetAccountSplits(user *models.User, accountid int64) (*[]*models.Split, error) GetAccountBalance(user *models.User, accountid int64) (*models.Amount, error)
GetAccountSplitsDate(user *models.User, accountid int64, date *time.Time) (*[]*models.Split, error) GetAccountBalanceDate(user *models.User, accountid int64, date *time.Time) (*models.Amount, error)
GetAccountSplitsDateRange(user *models.User, accountid int64, begin, end *time.Time) (*[]*models.Split, error) GetAccountBalanceDateRange(user *models.User, accountid int64, begin, end *time.Time) (*models.Amount, error)
GetAccountTransactions(user *models.User, accountid int64, sort string, page uint64, limit uint64) (*models.AccountTransactionsList, error) GetAccountTransactions(user *models.User, accountid int64, sort string, page uint64, limit uint64) (*models.AccountTransactionsList, error)
} }