Merge pull request #16 from aclindsa/writable_responses

Allow for marshalling Response objects to strings containing XML/SGML
This commit is contained in:
Aaron Lindsay 2019-03-02 15:33:38 -05:00 committed by GitHub
commit 3e8a9c5a53
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
14 changed files with 407 additions and 23 deletions

View File

@ -283,4 +283,5 @@ func TestUnmarshalBankStatementResponse(t *testing.T) {
}
checkResponsesEqual(t, &expected, response)
checkResponseRoundTrip(t, response)
}

View File

@ -3,10 +3,36 @@ package ofxgo
//go:generate ./generate_constants.py
import (
"bytes"
"errors"
"fmt"
"github.com/aclindsa/xml"
)
func writeHeader(b *bytes.Buffer, v ofxVersion) error {
// Write the header appropriate to our version
switch v {
case OfxVersion102, OfxVersion103, OfxVersion151, OfxVersion160:
b.WriteString(`OFXHEADER:100
DATA:OFXSGML
VERSION:` + v.String() + `
SECURITY:NONE
ENCODING:USASCII
CHARSET:1252
COMPRESSION:NONE
OLDFILEUID:NONE
NEWFILEUID:NONE
`)
case OfxVersion200, OfxVersion201, OfxVersion202, OfxVersion203, OfxVersion210, OfxVersion211, OfxVersion220:
b.WriteString(`<?xml version="1.0" encoding="UTF-8" standalone="no"?>` + "\n")
b.WriteString(`<?OFX OFXHEADER="200" VERSION="` + v.String() + `" SECURITY="NONE" OLDFILEUID="NONE" NEWFILEUID="NONE"?>` + "\n")
default:
return fmt.Errorf("%d is not a valid OFX version string", v)
}
return nil
}
// Message represents an OFX message in a message set. it is used to ease
// marshalling and unmarshalling.
type Message interface {

View File

@ -161,4 +161,5 @@ NEWFILEUID:NONE
}
checkResponsesEqual(t, &expected, response)
checkResponseRoundTrip(t, response)
}

1
go.mod
View File

@ -2,6 +2,7 @@ module github.com/aclindsa/ofxgo
require (
github.com/aclindsa/xml v0.0.0-20171002130543-5d4402bb4a20
github.com/howeyc/gopass v0.0.0-20170109162249-bf9dde6d0d2c
golang.org/x/crypto v0.0.0-20181001203147-e3636079e1a4 // indirect
golang.org/x/sys v0.0.0-20180928133829-e4b3c5e90611 // indirect
golang.org/x/text v0.0.0-20180911161511-905a57155faa

2
go.sum
View File

@ -1,5 +1,7 @@
github.com/aclindsa/xml v0.0.0-20171002130543-5d4402bb4a20 h1:wN3KlzWq56AIgOqFzYLYVih4zVyPDViCUeG5uZxJHq4=
github.com/aclindsa/xml v0.0.0-20171002130543-5d4402bb4a20/go.mod h1:DiEHtTD+e6zS3+R95F05Bfbcsfv13wZTi2M4LfAFLBE=
github.com/howeyc/gopass v0.0.0-20170109162249-bf9dde6d0d2c h1:kQWxfPIHVLbgLzphqk3QUflDy9QdksZR4ygR807bpy0=
github.com/howeyc/gopass v0.0.0-20170109162249-bf9dde6d0d2c/go.mod h1:lADxMC39cJJqL93Duh1xhAs4I2Zs8mKS89XWXFGp9cs=
golang.org/x/crypto v0.0.0-20181001203147-e3636079e1a4 h1:Vk3wNqEZwyGyei9yq5ekj7frek2u7HUfffJ1/opblzc=
golang.org/x/crypto v0.0.0-20181001203147-e3636079e1a4/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4=
golang.org/x/sys v0.0.0-20180928133829-e4b3c5e90611 h1:O33LKL7WyJgjN9CvxfTIomjIClbd/Kq86/iipowHQU0=

View File

@ -465,6 +465,7 @@ type InvBankTransaction struct {
// security-related transactions themselves. It must be unmarshalled manually
// due to the structure (don't know what kind of InvTransaction is coming next)
type InvTranList struct {
XMLName xml.Name `xml:"INVTRANLIST"`
DtStart Date
DtEnd Date // This is the value that should be sent as <DTSTART> in the next InvStatementRequest to ensure that no transactions are missed
InvTransactions []InvTransaction
@ -630,6 +631,119 @@ func (l *InvTranList) UnmarshalXML(d *xml.Decoder, start xml.StartElement) error
}
}
// MarshalXML handles marshalling an InvTranList element to an SGML/XML string
func (l *InvTranList) MarshalXML(e *xml.Encoder, start xml.StartElement) error {
invTranListElement := xml.StartElement{Name: xml.Name{Local: "INVTRANLIST"}}
if err := e.EncodeToken(invTranListElement); err != nil {
return err
}
err := e.EncodeElement(&l.DtStart, xml.StartElement{Name: xml.Name{Local: "DTSTART"}})
if err != nil {
return err
}
err = e.EncodeElement(&l.DtEnd, xml.StartElement{Name: xml.Name{Local: "DTEND"}})
if err != nil {
return err
}
for _, t := range l.InvTransactions {
start := xml.StartElement{Name: xml.Name{Local: t.TransactionType()}}
switch tran := t.(type) {
case BuyDebt:
if err := e.EncodeElement(&tran, start); err != nil {
return err
}
case BuyMF:
if err := e.EncodeElement(&tran, start); err != nil {
return err
}
case BuyOpt:
if err := e.EncodeElement(&tran, start); err != nil {
return err
}
case BuyOther:
if err := e.EncodeElement(&tran, start); err != nil {
return err
}
case BuyStock:
if err := e.EncodeElement(&tran, start); err != nil {
return err
}
case ClosureOpt:
if err := e.EncodeElement(&tran, start); err != nil {
return err
}
case Income:
if err := e.EncodeElement(&tran, start); err != nil {
return err
}
case InvExpense:
if err := e.EncodeElement(&tran, start); err != nil {
return err
}
case JrnlFund:
if err := e.EncodeElement(&tran, start); err != nil {
return err
}
case JrnlSec:
if err := e.EncodeElement(&tran, start); err != nil {
return err
}
case MarginInterest:
if err := e.EncodeElement(&tran, start); err != nil {
return err
}
case Reinvest:
if err := e.EncodeElement(&tran, start); err != nil {
return err
}
case RetOfCap:
if err := e.EncodeElement(&tran, start); err != nil {
return err
}
case SellDebt:
if err := e.EncodeElement(&tran, start); err != nil {
return err
}
case SellMF:
if err := e.EncodeElement(&tran, start); err != nil {
return err
}
case SellOpt:
if err := e.EncodeElement(&tran, start); err != nil {
return err
}
case SellOther:
if err := e.EncodeElement(&tran, start); err != nil {
return err
}
case SellStock:
if err := e.EncodeElement(&tran, start); err != nil {
return err
}
case Split:
if err := e.EncodeElement(&tran, start); err != nil {
return err
}
case Transfer:
if err := e.EncodeElement(&tran, start); err != nil {
return err
}
default:
return errors.New("Invalid INVTRANLIST child type: " + tran.TransactionType())
}
}
for _, tran := range l.BankTransactions {
err = e.EncodeElement(&tran, xml.StartElement{Name: xml.Name{Local: "INVBANKTRAN"}})
if err != nil {
return err
}
}
if err := e.EncodeToken(invTranListElement.End()); err != nil {
return err
}
return nil
}
// InvPosition contains generic position information included in each of the
// other *Position types
type InvPosition struct {
@ -770,6 +884,45 @@ func (p *PositionList) UnmarshalXML(d *xml.Decoder, start xml.StartElement) erro
}
}
// MarshalXML handles marshalling a PositionList to an XML string
func (p *PositionList) MarshalXML(e *xml.Encoder, start xml.StartElement) error {
invPosListElement := xml.StartElement{Name: xml.Name{Local: "INVPOSLIST"}}
if err := e.EncodeToken(invPosListElement); err != nil {
return err
}
for _, position := range *p {
start := xml.StartElement{Name: xml.Name{Local: position.PositionType()}}
switch pos := position.(type) {
case DebtPosition:
if err := e.EncodeElement(&pos, start); err != nil {
return err
}
case MFPosition:
if err := e.EncodeElement(&pos, start); err != nil {
return err
}
case OptPosition:
if err := e.EncodeElement(&pos, start); err != nil {
return err
}
case OtherPosition:
if err := e.EncodeElement(&pos, start); err != nil {
return err
}
case StockPosition:
if err := e.EncodeElement(&pos, start); err != nil {
return err
}
default:
return errors.New("Invalid INVPOSLIST child type: " + pos.PositionType())
}
}
if err := e.EncodeToken(invPosListElement.End()); err != nil {
return err
}
return nil
}
// InvBalance contains three (or optionally four) specified balances as well as
// a free-form list of generic balance information which may be provided by an
// FI.
@ -1036,6 +1189,69 @@ func (o *OOList) UnmarshalXML(d *xml.Decoder, start xml.StartElement) error {
}
}
// MarshalXML handles marshalling an OOList to an XML string
func (o *OOList) MarshalXML(e *xml.Encoder, start xml.StartElement) error {
ooListElement := xml.StartElement{Name: xml.Name{Local: "INVOOLIST"}}
if err := e.EncodeToken(ooListElement); err != nil {
return err
}
for _, openorder := range *o {
start := xml.StartElement{Name: xml.Name{Local: openorder.OrderType()}}
switch oo := openorder.(type) {
case OOBuyDebt:
if err := e.EncodeElement(&oo, start); err != nil {
return err
}
case OOBuyMF:
if err := e.EncodeElement(&oo, start); err != nil {
return err
}
case OOBuyOpt:
if err := e.EncodeElement(&oo, start); err != nil {
return err
}
case OOBuyOther:
if err := e.EncodeElement(&oo, start); err != nil {
return err
}
case OOBuyStock:
if err := e.EncodeElement(&oo, start); err != nil {
return err
}
case OOSellDebt:
if err := e.EncodeElement(&oo, start); err != nil {
return err
}
case OOSellMF:
if err := e.EncodeElement(&oo, start); err != nil {
return err
}
case OOSellOpt:
if err := e.EncodeElement(&oo, start); err != nil {
return err
}
case OOSellOther:
if err := e.EncodeElement(&oo, start); err != nil {
return err
}
case OOSellStock:
if err := e.EncodeElement(&oo, start); err != nil {
return err
}
case OOSwitchMF:
if err := e.EncodeElement(&oo, start); err != nil {
return err
}
default:
return errors.New("Invalid OOLIST child type: " + oo.OrderType())
}
}
if err := e.EncodeToken(ooListElement.End()); err != nil {
return err
}
return nil
}
// ContribSecurity identifies current contribution allocation for a security in
// a 401(k) account
type ContribSecurity struct {

View File

@ -602,6 +602,7 @@ func TestUnmarshalInvStatementResponse(t *testing.T) {
}
checkResponsesEqual(t, &expected, response)
checkResponseRoundTrip(t, response)
}
func TestUnmarshalInvStatementResponse102(t *testing.T) {
@ -957,6 +958,7 @@ NEWFILEUID: NONE
}
checkResponsesEqual(t, &expected, response)
checkResponseRoundTrip(t, response)
}
func TestUnmarshalInvTranList(t *testing.T) {

View File

@ -3,6 +3,7 @@ package ofxgo
import (
"errors"
"github.com/aclindsa/xml"
"strings"
)
// ProfileRequest represents a request for a server to provide a profile of its
@ -126,6 +127,35 @@ func (msl *MessageSetList) UnmarshalXML(d *xml.Decoder, start xml.StartElement)
}
}
// MarshalXML handles marshalling a MessageSetList element to an XML string
func (msl *MessageSetList) MarshalXML(e *xml.Encoder, start xml.StartElement) error {
messageSetListElement := xml.StartElement{Name: xml.Name{Local: "MSGSETLIST"}}
if err := e.EncodeToken(messageSetListElement); err != nil {
return err
}
for _, messageset := range *msl {
if !strings.HasSuffix(messageset.Name, "V1") {
return errors.New("Expected MessageSet.Name to end with \"V1\"")
}
messageSetName := strings.TrimSuffix(messageset.Name, "V1")
messageSetElement := xml.StartElement{Name: xml.Name{Local: messageSetName}}
if err := e.EncodeToken(messageSetElement); err != nil {
return err
}
start := xml.StartElement{Name: xml.Name{Local: messageset.Name}}
if err := e.EncodeElement(&messageset, start); err != nil {
return err
}
if err := e.EncodeToken(messageSetElement.End()); err != nil {
return err
}
}
if err := e.EncodeToken(messageSetListElement.End()); err != nil {
return err
}
return nil
}
// ProfileResponse contains a requested profile of the server's capabilities
// (which message sets and versions it supports, how to access them, which
// languages and which types of synchronization they support, etc.). Note that

View File

@ -325,4 +325,5 @@ NEWFILEUID:NONE
}
checkResponsesEqual(t, &expected, response)
checkResponseRoundTrip(t, response)
}

View File

@ -3,7 +3,6 @@ package ofxgo
import (
"bytes"
"errors"
"fmt"
"github.com/aclindsa/xml"
"time"
)
@ -35,7 +34,7 @@ type Request struct {
indent bool // Whether to indent the marshaled XML
}
func marshalMessageSet(e *xml.Encoder, requests []Message, set messageType, version ofxVersion) error {
func encodeMessageSet(e *xml.Encoder, requests []Message, set messageType, version ofxVersion) error {
if len(requests) > 0 {
messageSetElement := xml.StartElement{Name: xml.Name{Local: set.String()}}
if err := e.EncodeToken(messageSetElement); err != nil {
@ -80,25 +79,7 @@ func (oq *Request) Marshal() (*bytes.Buffer, error) {
var b bytes.Buffer
// Write the header appropriate to our version
switch oq.Version {
case OfxVersion102, OfxVersion103, OfxVersion151, OfxVersion160:
b.WriteString(`OFXHEADER:100
DATA:OFXSGML
VERSION:` + oq.Version.String() + `
SECURITY:NONE
ENCODING:USASCII
CHARSET:1252
COMPRESSION:NONE
OLDFILEUID:NONE
NEWFILEUID:NONE
`)
case OfxVersion200, OfxVersion201, OfxVersion202, OfxVersion203, OfxVersion210, OfxVersion211, OfxVersion220:
b.WriteString(`<?xml version="1.0" encoding="UTF-8" standalone="no"?>` + "\n")
b.WriteString(`<?OFX OFXHEADER="200" VERSION="` + oq.Version.String() + `" SECURITY="NONE" OLDFILEUID="NONE" NEWFILEUID="NONE"?>` + "\n")
default:
return nil, fmt.Errorf("%d is not a valid OFX version string", oq.Version)
}
writeHeader(&b, oq.Version)
encoder := xml.NewEncoder(&b)
if oq.indent {
@ -145,7 +126,7 @@ NEWFILEUID:NONE
{oq.Image, ImageRq},
}
for _, set := range messageSets {
if err := marshalMessageSet(encoder, set.Messages, set.Type, oq.Version); err != nil {
if err := encodeMessageSet(encoder, set.Messages, set.Type, oq.Version); err != nil {
return nil, err
}
}

View File

@ -369,3 +369,70 @@ func ParseResponse(reader io.Reader) (*Response, error) {
}
}
}
// Marshal this Response into its SGML/XML representation held in a bytes.Buffer
//
// If error is non-nil, this bytes.Buffer is ready to be sent to an OFX client
func (or *Response) Marshal() (*bytes.Buffer, error) {
var b bytes.Buffer
// Write the header appropriate to our version
writeHeader(&b, or.Version)
encoder := xml.NewEncoder(&b)
encoder.Indent("", " ")
ofxElement := xml.StartElement{Name: xml.Name{Local: "OFX"}}
if err := encoder.EncodeToken(ofxElement); err != nil {
return nil, err
}
if ok, err := or.Signon.Valid(or.Version); !ok {
return nil, err
}
signonMsgSet := xml.StartElement{Name: xml.Name{Local: SignonRs.String()}}
if err := encoder.EncodeToken(signonMsgSet); err != nil {
return nil, err
}
if err := encoder.Encode(&or.Signon); err != nil {
return nil, err
}
if err := encoder.EncodeToken(signonMsgSet.End()); err != nil {
return nil, err
}
messageSets := []struct {
Messages []Message
Type messageType
}{
{or.Signup, SignupRs},
{or.Bank, BankRs},
{or.CreditCard, CreditCardRs},
{or.Loan, LoanRs},
{or.InvStmt, InvStmtRs},
{or.InterXfer, InterXferRs},
{or.WireXfer, WireXferRs},
{or.Billpay, BillpayRs},
{or.Email, EmailRs},
{or.SecList, SecListRs},
{or.PresDir, PresDirRs},
{or.PresDlv, PresDlvRs},
{or.Prof, ProfRs},
{or.Image, ImageRs},
}
for _, set := range messageSets {
if err := encodeMessageSet(encoder, set.Messages, set.Type, or.Version); err != nil {
return nil, err
}
}
if err := encoder.EncodeToken(ofxElement.End()); err != nil {
return nil, err
}
if err := encoder.Flush(); err != nil {
return nil, err
}
return &b, nil
}

View File

@ -136,6 +136,20 @@ func checkResponsesEqual(t *testing.T, expected, actual *ofxgo.Response) {
checkEqual(t, "", reflect.ValueOf(expected), reflect.ValueOf(actual))
}
func checkResponseRoundTrip(t *testing.T, response *ofxgo.Response) {
b, err := response.Marshal()
if err != nil {
t.Fatalf("Unexpected error re-marshaling OFX response: %s\n", err)
}
roundtripped, err := ofxgo.ParseResponse(b)
if err != nil {
t.Fatalf("Unexpected error re-parsing OFX response: %s\n", err)
}
checkResponsesEqual(t, response, roundtripped)
}
// Ensure that these samples both parse without errors, and can be converted
// back and forth without changing.
func TestValidSamples(t *testing.T) {
fn := func(path string, info os.FileInfo, err error) error {
if info.IsDir() {
@ -147,10 +161,11 @@ func TestValidSamples(t *testing.T) {
if err != nil {
t.Fatalf("Unexpected error opening %s: %s\n", path, err)
}
_, err = ofxgo.ParseResponse(file)
response, err := ofxgo.ParseResponse(file)
if err != nil {
t.Fatalf("Unexpected error parsing OFX response in %s: %s\n", path, err)
}
checkResponseRoundTrip(t, response)
return nil
}
filepath.Walk("samples/valid_responses", fn)

View File

@ -221,6 +221,7 @@ func (i StockInfo) SecurityType() string {
// SecurityList is a container for Security objects containaing information
// about securities
type SecurityList struct {
XMLName xml.Name `xml:"SECLIST"`
Securities []Security
}
@ -290,3 +291,42 @@ func (r *SecurityList) UnmarshalXML(d *xml.Decoder, start xml.StartElement) erro
}
}
}
// MarshalXML handles marshalling a SecurityList to an SGML/XML string
func (r *SecurityList) MarshalXML(e *xml.Encoder, start xml.StartElement) error {
secListElement := xml.StartElement{Name: xml.Name{Local: "SECLIST"}}
if err := e.EncodeToken(secListElement); err != nil {
return err
}
for _, s := range r.Securities {
start := xml.StartElement{Name: xml.Name{Local: s.SecurityType()}}
switch sec := s.(type) {
case DebtInfo:
if err := e.EncodeElement(&sec, start); err != nil {
return err
}
case MFInfo:
if err := e.EncodeElement(&sec, start); err != nil {
return err
}
case OptInfo:
if err := e.EncodeElement(&sec, start); err != nil {
return err
}
case OtherInfo:
if err := e.EncodeElement(&sec, start); err != nil {
return err
}
case StockInfo:
if err := e.EncodeElement(&sec, start); err != nil {
return err
}
default:
return errors.New("Invalid SECLIST child type: " + sec.SecurityType())
}
}
if err := e.EncodeToken(secListElement.End()); err != nil {
return err
}
return nil
}

View File

@ -155,4 +155,5 @@ func TestUnmarshalAcctInfoResponse(t *testing.T) {
}
checkResponsesEqual(t, &expected, response)
checkResponseRoundTrip(t, response)
}