Add the ability to marshal a Response to SGML/XML and test it

This allows for ofxgo to be used to create well-formatted OFX from poor
OFX, or even be used to generate OFX from other formats for easier
importing into financial management software.

Test this functionality by adding "round trip" testing to all existing
tests - ensure that responses' content is the same after a round trip of
marshalling and unmarshalling them.
This commit is contained in:
Aaron Lindsay 2019-03-01 22:40:49 -05:00
parent 286e619071
commit 35c7116654
10 changed files with 375 additions and 1 deletions

View File

@ -283,4 +283,5 @@ func TestUnmarshalBankStatementResponse(t *testing.T) {
}
checkResponsesEqual(t, &expected, response)
checkResponseRoundTrip(t, response)
}

View File

@ -161,4 +161,5 @@ NEWFILEUID:NONE
}
checkResponsesEqual(t, &expected, response)
checkResponseRoundTrip(t, response)
}

View File

@ -465,6 +465,7 @@ type InvBankTransaction struct {
// security-related transactions themselves. It must be unmarshalled manually
// due to the structure (don't know what kind of InvTransaction is coming next)
type InvTranList struct {
XMLName xml.Name `xml:"INVTRANLIST"`
DtStart Date
DtEnd Date // This is the value that should be sent as <DTSTART> in the next InvStatementRequest to ensure that no transactions are missed
InvTransactions []InvTransaction
@ -630,6 +631,119 @@ func (l *InvTranList) UnmarshalXML(d *xml.Decoder, start xml.StartElement) error
}
}
// MarshalXML handles marshalling an InvTranList element to an SGML/XML string
func (l *InvTranList) MarshalXML(e *xml.Encoder, start xml.StartElement) error {
invTranListElement := xml.StartElement{Name: xml.Name{Local: "INVTRANLIST"}}
if err := e.EncodeToken(invTranListElement); err != nil {
return err
}
err := e.EncodeElement(&l.DtStart, xml.StartElement{Name: xml.Name{Local: "DTSTART"}})
if err != nil {
return err
}
err = e.EncodeElement(&l.DtEnd, xml.StartElement{Name: xml.Name{Local: "DTEND"}})
if err != nil {
return err
}
for _, t := range l.InvTransactions {
start := xml.StartElement{Name: xml.Name{Local: t.TransactionType()}}
switch tran := t.(type) {
case BuyDebt:
if err := e.EncodeElement(&tran, start); err != nil {
return err
}
case BuyMF:
if err := e.EncodeElement(&tran, start); err != nil {
return err
}
case BuyOpt:
if err := e.EncodeElement(&tran, start); err != nil {
return err
}
case BuyOther:
if err := e.EncodeElement(&tran, start); err != nil {
return err
}
case BuyStock:
if err := e.EncodeElement(&tran, start); err != nil {
return err
}
case ClosureOpt:
if err := e.EncodeElement(&tran, start); err != nil {
return err
}
case Income:
if err := e.EncodeElement(&tran, start); err != nil {
return err
}
case InvExpense:
if err := e.EncodeElement(&tran, start); err != nil {
return err
}
case JrnlFund:
if err := e.EncodeElement(&tran, start); err != nil {
return err
}
case JrnlSec:
if err := e.EncodeElement(&tran, start); err != nil {
return err
}
case MarginInterest:
if err := e.EncodeElement(&tran, start); err != nil {
return err
}
case Reinvest:
if err := e.EncodeElement(&tran, start); err != nil {
return err
}
case RetOfCap:
if err := e.EncodeElement(&tran, start); err != nil {
return err
}
case SellDebt:
if err := e.EncodeElement(&tran, start); err != nil {
return err
}
case SellMF:
if err := e.EncodeElement(&tran, start); err != nil {
return err
}
case SellOpt:
if err := e.EncodeElement(&tran, start); err != nil {
return err
}
case SellOther:
if err := e.EncodeElement(&tran, start); err != nil {
return err
}
case SellStock:
if err := e.EncodeElement(&tran, start); err != nil {
return err
}
case Split:
if err := e.EncodeElement(&tran, start); err != nil {
return err
}
case Transfer:
if err := e.EncodeElement(&tran, start); err != nil {
return err
}
default:
return errors.New("Invalid INVTRANLIST child type: " + tran.TransactionType())
}
}
for _, tran := range l.BankTransactions {
err = e.EncodeElement(&tran, xml.StartElement{Name: xml.Name{Local: "INVBANKTRAN"}})
if err != nil {
return err
}
}
if err := e.EncodeToken(invTranListElement.End()); err != nil {
return err
}
return nil
}
// InvPosition contains generic position information included in each of the
// other *Position types
type InvPosition struct {
@ -770,6 +884,45 @@ func (p *PositionList) UnmarshalXML(d *xml.Decoder, start xml.StartElement) erro
}
}
// MarshalXML handles marshalling a PositionList to an XML string
func (p *PositionList) MarshalXML(e *xml.Encoder, start xml.StartElement) error {
invPosListElement := xml.StartElement{Name: xml.Name{Local: "INVPOSLIST"}}
if err := e.EncodeToken(invPosListElement); err != nil {
return err
}
for _, position := range *p {
start := xml.StartElement{Name: xml.Name{Local: position.PositionType()}}
switch pos := position.(type) {
case DebtPosition:
if err := e.EncodeElement(&pos, start); err != nil {
return err
}
case MFPosition:
if err := e.EncodeElement(&pos, start); err != nil {
return err
}
case OptPosition:
if err := e.EncodeElement(&pos, start); err != nil {
return err
}
case OtherPosition:
if err := e.EncodeElement(&pos, start); err != nil {
return err
}
case StockPosition:
if err := e.EncodeElement(&pos, start); err != nil {
return err
}
default:
return errors.New("Invalid INVPOSLIST child type: " + pos.PositionType())
}
}
if err := e.EncodeToken(invPosListElement.End()); err != nil {
return err
}
return nil
}
// InvBalance contains three (or optionally four) specified balances as well as
// a free-form list of generic balance information which may be provided by an
// FI.
@ -1036,6 +1189,69 @@ func (o *OOList) UnmarshalXML(d *xml.Decoder, start xml.StartElement) error {
}
}
// MarshalXML handles marshalling an OOList to an XML string
func (o *OOList) MarshalXML(e *xml.Encoder, start xml.StartElement) error {
ooListElement := xml.StartElement{Name: xml.Name{Local: "INVOOLIST"}}
if err := e.EncodeToken(ooListElement); err != nil {
return err
}
for _, openorder := range *o {
start := xml.StartElement{Name: xml.Name{Local: openorder.OrderType()}}
switch oo := openorder.(type) {
case OOBuyDebt:
if err := e.EncodeElement(&oo, start); err != nil {
return err
}
case OOBuyMF:
if err := e.EncodeElement(&oo, start); err != nil {
return err
}
case OOBuyOpt:
if err := e.EncodeElement(&oo, start); err != nil {
return err
}
case OOBuyOther:
if err := e.EncodeElement(&oo, start); err != nil {
return err
}
case OOBuyStock:
if err := e.EncodeElement(&oo, start); err != nil {
return err
}
case OOSellDebt:
if err := e.EncodeElement(&oo, start); err != nil {
return err
}
case OOSellMF:
if err := e.EncodeElement(&oo, start); err != nil {
return err
}
case OOSellOpt:
if err := e.EncodeElement(&oo, start); err != nil {
return err
}
case OOSellOther:
if err := e.EncodeElement(&oo, start); err != nil {
return err
}
case OOSellStock:
if err := e.EncodeElement(&oo, start); err != nil {
return err
}
case OOSwitchMF:
if err := e.EncodeElement(&oo, start); err != nil {
return err
}
default:
return errors.New("Invalid OOLIST child type: " + oo.OrderType())
}
}
if err := e.EncodeToken(ooListElement.End()); err != nil {
return err
}
return nil
}
// ContribSecurity identifies current contribution allocation for a security in
// a 401(k) account
type ContribSecurity struct {

View File

@ -602,6 +602,7 @@ func TestUnmarshalInvStatementResponse(t *testing.T) {
}
checkResponsesEqual(t, &expected, response)
checkResponseRoundTrip(t, response)
}
func TestUnmarshalInvStatementResponse102(t *testing.T) {
@ -957,6 +958,7 @@ NEWFILEUID: NONE
}
checkResponsesEqual(t, &expected, response)
checkResponseRoundTrip(t, response)
}
func TestUnmarshalInvTranList(t *testing.T) {

View File

@ -3,6 +3,7 @@ package ofxgo
import (
"errors"
"github.com/aclindsa/xml"
"strings"
)
// ProfileRequest represents a request for a server to provide a profile of its
@ -126,6 +127,35 @@ func (msl *MessageSetList) UnmarshalXML(d *xml.Decoder, start xml.StartElement)
}
}
// MarshalXML handles marshalling a MessageSetList element to an XML string
func (msl *MessageSetList) MarshalXML(e *xml.Encoder, start xml.StartElement) error {
messageSetListElement := xml.StartElement{Name: xml.Name{Local: "MSGSETLIST"}}
if err := e.EncodeToken(messageSetListElement); err != nil {
return err
}
for _, messageset := range *msl {
if !strings.HasSuffix(messageset.Name, "V1") {
return errors.New("Expected MessageSet.Name to end with \"V1\"")
}
messageSetName := strings.TrimSuffix(messageset.Name, "V1")
messageSetElement := xml.StartElement{Name: xml.Name{Local: messageSetName}}
if err := e.EncodeToken(messageSetElement); err != nil {
return err
}
start := xml.StartElement{Name: xml.Name{Local: messageset.Name}}
if err := e.EncodeElement(&messageset, start); err != nil {
return err
}
if err := e.EncodeToken(messageSetElement.End()); err != nil {
return err
}
}
if err := e.EncodeToken(messageSetListElement.End()); err != nil {
return err
}
return nil
}
// ProfileResponse contains a requested profile of the server's capabilities
// (which message sets and versions it supports, how to access them, which
// languages and which types of synchronization they support, etc.). Note that

View File

@ -325,4 +325,5 @@ NEWFILEUID:NONE
}
checkResponsesEqual(t, &expected, response)
checkResponseRoundTrip(t, response)
}

View File

@ -369,3 +369,70 @@ func ParseResponse(reader io.Reader) (*Response, error) {
}
}
}
// Marshal this Response into its SGML/XML representation held in a bytes.Buffer
//
// If error is non-nil, this bytes.Buffer is ready to be sent to an OFX client
func (or *Response) Marshal() (*bytes.Buffer, error) {
var b bytes.Buffer
// Write the header appropriate to our version
writeHeader(&b, or.Version)
encoder := xml.NewEncoder(&b)
encoder.Indent("", " ")
ofxElement := xml.StartElement{Name: xml.Name{Local: "OFX"}}
if err := encoder.EncodeToken(ofxElement); err != nil {
return nil, err
}
if ok, err := or.Signon.Valid(or.Version); !ok {
return nil, err
}
signonMsgSet := xml.StartElement{Name: xml.Name{Local: SignonRs.String()}}
if err := encoder.EncodeToken(signonMsgSet); err != nil {
return nil, err
}
if err := encoder.Encode(&or.Signon); err != nil {
return nil, err
}
if err := encoder.EncodeToken(signonMsgSet.End()); err != nil {
return nil, err
}
messageSets := []struct {
Messages []Message
Type messageType
}{
{or.Signup, SignupRs},
{or.Bank, BankRs},
{or.CreditCard, CreditCardRs},
{or.Loan, LoanRs},
{or.InvStmt, InvStmtRs},
{or.InterXfer, InterXferRs},
{or.WireXfer, WireXferRs},
{or.Billpay, BillpayRs},
{or.Email, EmailRs},
{or.SecList, SecListRs},
{or.PresDir, PresDirRs},
{or.PresDlv, PresDlvRs},
{or.Prof, ProfRs},
{or.Image, ImageRs},
}
for _, set := range messageSets {
if err := encodeMessageSet(encoder, set.Messages, set.Type, or.Version); err != nil {
return nil, err
}
}
if err := encoder.EncodeToken(ofxElement.End()); err != nil {
return nil, err
}
if err := encoder.Flush(); err != nil {
return nil, err
}
return &b, nil
}

View File

@ -136,6 +136,20 @@ func checkResponsesEqual(t *testing.T, expected, actual *ofxgo.Response) {
checkEqual(t, "", reflect.ValueOf(expected), reflect.ValueOf(actual))
}
func checkResponseRoundTrip(t *testing.T, response *ofxgo.Response) {
b, err := response.Marshal()
if err != nil {
t.Fatalf("Unexpected error re-marshaling OFX response: %s\n", err)
}
roundtripped, err := ofxgo.ParseResponse(b)
if err != nil {
t.Fatalf("Unexpected error re-parsing OFX response: %s\n", err)
}
checkResponsesEqual(t, response, roundtripped)
}
// Ensure that these samples both parse without errors, and can be converted
// back and forth without changing.
func TestValidSamples(t *testing.T) {
fn := func(path string, info os.FileInfo, err error) error {
if info.IsDir() {
@ -147,10 +161,11 @@ func TestValidSamples(t *testing.T) {
if err != nil {
t.Fatalf("Unexpected error opening %s: %s\n", path, err)
}
_, err = ofxgo.ParseResponse(file)
response, err := ofxgo.ParseResponse(file)
if err != nil {
t.Fatalf("Unexpected error parsing OFX response in %s: %s\n", path, err)
}
checkResponseRoundTrip(t, response)
return nil
}
filepath.Walk("samples/valid_responses", fn)

View File

@ -221,6 +221,7 @@ func (i StockInfo) SecurityType() string {
// SecurityList is a container for Security objects containaing information
// about securities
type SecurityList struct {
XMLName xml.Name `xml:"SECLIST"`
Securities []Security
}
@ -290,3 +291,42 @@ func (r *SecurityList) UnmarshalXML(d *xml.Decoder, start xml.StartElement) erro
}
}
}
// MarshalXML handles marshalling a SecurityList to an SGML/XML string
func (r *SecurityList) MarshalXML(e *xml.Encoder, start xml.StartElement) error {
secListElement := xml.StartElement{Name: xml.Name{Local: "SECLIST"}}
if err := e.EncodeToken(secListElement); err != nil {
return err
}
for _, s := range r.Securities {
start := xml.StartElement{Name: xml.Name{Local: s.SecurityType()}}
switch sec := s.(type) {
case DebtInfo:
if err := e.EncodeElement(&sec, start); err != nil {
return err
}
case MFInfo:
if err := e.EncodeElement(&sec, start); err != nil {
return err
}
case OptInfo:
if err := e.EncodeElement(&sec, start); err != nil {
return err
}
case OtherInfo:
if err := e.EncodeElement(&sec, start); err != nil {
return err
}
case StockInfo:
if err := e.EncodeElement(&sec, start); err != nil {
return err
}
default:
return errors.New("Invalid SECLIST child type: " + sec.SecurityType())
}
}
if err := e.EncodeToken(secListElement.End()); err != nil {
return err
}
return nil
}

View File

@ -155,4 +155,5 @@ func TestUnmarshalAcctInfoResponse(t *testing.T) {
}
checkResponsesEqual(t, &expected, response)
checkResponseRoundTrip(t, response)
}