From 257495a343a23cab4480f39b44352e1fd6c1683d Mon Sep 17 00:00:00 2001 From: Aaron Lindsay Date: Mon, 13 Mar 2017 21:09:15 -0400 Subject: [PATCH] Add testing for basic types, fix some bugs --- types.go | 75 +++++++++++--- types_test.go | 274 ++++++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 336 insertions(+), 13 deletions(-) create mode 100644 types_test.go diff --git a/types.go b/types.go index 8829f91..ef4d6ec 100644 --- a/types.go +++ b/types.go @@ -5,6 +5,7 @@ import ( "errors" "fmt" "github.com/golang/go/src/encoding/xml" + "math/big" "regexp" "strconv" "strings" @@ -13,23 +14,36 @@ import ( type Int int64 -func (oi *Int) UnmarshalXML(d *xml.Decoder, start xml.StartElement) error { +type Amount big.Rat + +func (a *Amount) UnmarshalXML(d *xml.Decoder, start xml.StartElement) error { var value string + var b big.Rat + err := d.DecodeElement(&value, &start) if err != nil { return err } - i, err := strconv.ParseInt(strings.TrimSpace(value), 10, 64) - if err != nil { - return err + + // The OFX spec allows the start of the fractional amount to be delineated + // by a comma, so fix that up before attempting to parse it into big.Rat + value = strings.Replace(value, ",", ".", 1) + + if _, ok := b.SetString(value); !ok { + return errors.New("Failed to parse OFX amount into big.Rat") } - *oi = (Int)(i) + *a = Amount(b) return nil } -type Amount string +func (a Amount) String() string { + var b big.Rat = big.Rat(a) + return strings.TrimRight(strings.TrimRight(b.FloatString(100), "0"), ".") +} -// TODO parse Amount into big.Rat? +func (a *Amount) MarshalXML(e *xml.Encoder, start xml.StartElement) error { + return e.EncodeElement(a.String(), start) +} type Date time.Time @@ -67,15 +81,30 @@ func (od *Date) UnmarshalXML(d *xml.Decoder, start xml.StartElement) error { return err } if len(matches[3]) > 0 { - zoneminutes, err = strconv.Atoi(matches[1]) + zoneminutes, err = strconv.Atoi(matches[3]) if err != nil { return err } zoneminutes = zoneminutes * 60 / 100 } zone = fmt.Sprintf(" %+03d%02d", zonehours, zoneminutes) + + // Get the time zone name if it's there, default to GMT if the offset + // is 0 and a name isn't supplied + if len(matches[5]) > 0 { + zone = zone + " " + matches[5] + zoneFormat = zoneFormat + " MST" + } else if zonehours == 0 && zoneminutes == 0 { + zone = zone + " GMT" + zoneFormat = zoneFormat + " MST" + } + } else { + // Default to GMT if no time zone was specified + zone = " +0000 GMT" + zoneFormat = " -0700 MST" } + // Try all the date formats, from longest to shortest for _, format := range ofxDateFormats { t, err := time.Parse(format+zoneFormat, value+zone) if err == nil { @@ -91,14 +120,23 @@ func (od Date) String() string { t := time.Time(od) format := t.Format(ofxDateFormats[0]) zonename, zoneoffset := t.Zone() - format += "[" + fmt.Sprintf("%+d", zoneoffset/3600) - fractionaloffset := (zoneoffset % 3600) / 360 + if zoneoffset < 0 { + format += "[" + fmt.Sprintf("%+d", zoneoffset/3600) + } else { + format += "[" + fmt.Sprintf("%d", zoneoffset/3600) + } + fractionaloffset := (zoneoffset % 3600) / 36 if fractionaloffset > 0 { format += "." + fmt.Sprintf("%02d", fractionaloffset) } else if fractionaloffset < 0 { format += "." + fmt.Sprintf("%02d", -fractionaloffset) } - return format + ":" + zonename + "]" + + if len(zonename) > 0 { + return format + ":" + zonename + "]" + } else { + return format + "]" + } } func (od *Date) MarshalXML(e *xml.Encoder, start xml.StartElement) error { @@ -117,6 +155,10 @@ func (os *String) UnmarshalXML(d *xml.Decoder, start xml.StartElement) error { return nil } +func (os *String) String() string { + return string(*os) +} + type Boolean bool func (ob *Boolean) UnmarshalXML(d *xml.Decoder, start xml.StartElement) error { @@ -144,12 +186,19 @@ func (ob *Boolean) MarshalXML(e *xml.Encoder, start xml.StartElement) error { return e.EncodeElement("N", start) } +func (ob *Boolean) String() string { + return fmt.Sprintf("%v", *ob) +} + type UID string -func (ou *UID) Valid() (bool, error) { - if len(*ou) != 36 { +func (ou UID) Valid() (bool, error) { + if len(ou) != 36 { return false, errors.New("UID not 36 characters long") } + if ou[8] != '-' || ou[13] != '-' || ou[18] != '-' || ou[23] != '-' { + return false, errors.New("UID missing hyphens at the appropriate places") + } return true, nil } diff --git a/types_test.go b/types_test.go new file mode 100644 index 0000000..2cda2d3 --- /dev/null +++ b/types_test.go @@ -0,0 +1,274 @@ +package ofxgo_test + +import ( + "fmt" + "github.com/aclindsa/ofxgo" + "github.com/golang/go/src/encoding/xml" + "math/big" + "reflect" + "testing" + "time" +) + +func getTypeName(i interface{}) string { + val := reflect.ValueOf(i) + + // Do the same thing that encoding/xml does to get the name + for val.Kind() == reflect.Interface || val.Kind() == reflect.Ptr { + if val.IsNil() { + return "" + } + val = val.Elem() + } + return val.Type().Name() +} + +func marshalHelper(t *testing.T, expected string, i interface{}) { + typename := getTypeName(i) + expectedstring := fmt.Sprintf("<%s>%s", typename, expected, typename) + b, err := xml.Marshal(i) + if err != nil { + t.Fatalf("Unexpected error on xml.Marshal(%T): %s\n", i, err) + } + if string(b) != expectedstring { + t.Fatalf("Expected '%s', got '%s'\n", expectedstring, string(b)) + } +} + +func unmarshalHelper2(t *testing.T, input string, expected interface{}, overwritten interface{}, eq func(a, b interface{}) bool) { + typename := getTypeName(expected) + inputstring := fmt.Sprintf("<%s>%s", typename, input, typename) + err := xml.Unmarshal([]byte(inputstring), &overwritten) + if err != nil { + t.Fatalf("Unexpected error on xml.Unmarshal(%T): %s\n", expected, err) + } + if !eq(overwritten, expected) { + t.Fatalf("Expected '%s', got '%s'\n", expected, overwritten) + } +} + +func unmarshalHelper(t *testing.T, input string, expected interface{}, overwritten interface{}) { + eq := func(a, b interface{}) bool { + return reflect.DeepEqual(a, b) + } + unmarshalHelper2(t, input, expected, overwritten, eq) +} + +func TestMarshalInt(t *testing.T) { + var i ofxgo.Int = 927 + marshalHelper(t, "927", &i) + i = 0 + marshalHelper(t, "0", &i) + i = -768276587425 + marshalHelper(t, "-768276587425", &i) +} + +func TestUnmarshalInt(t *testing.T) { + var i, overwritten ofxgo.Int = -48394, 0 + unmarshalHelper(t, "-48394", &i, &overwritten) + i = 0 + unmarshalHelper(t, "0", &i, &overwritten) + i = 198237198 + unmarshalHelper(t, "198237198", &i, &overwritten) +} + +func TestMarshalAmount(t *testing.T) { + var a ofxgo.Amount + var b *big.Rat = (*big.Rat)(&a) + + b.SetFrac64(8, 1) + marshalHelper(t, "8", &a) + b.SetFrac64(1, 8) + marshalHelper(t, "0.125", &a) + b.SetFrac64(-1, 200) + marshalHelper(t, "-0.005", &a) + b.SetInt64(0) + marshalHelper(t, "0", &a) + b.SetInt64(-768276587425) + marshalHelper(t, "-768276587425", &a) + b.SetFrac64(1, 12) + marshalHelper(t, "0.0833333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333", &a) +} + +func TestUnmarshalAmount(t *testing.T) { + var a, overwritten ofxgo.Amount + var b *big.Rat = (*big.Rat)(&a) + + // Amount/big.Rat needs a special equality test because reflect.DeepEqual + // doesn't always return equal for two values that big.Rat.Cmp() does + eq := func(a, b interface{}) bool { + if amountA, ok := a.(*ofxgo.Amount); ok { + if amountB, ok2 := b.(*ofxgo.Amount); ok2 { + ratA := (*big.Rat)(amountA) + return ratA.Cmp((*big.Rat)(amountB)) == 0 + } + } + return false + } + + b.SetFrac64(12, 1) + unmarshalHelper2(t, "12", &a, &overwritten, eq) + b.SetFrac64(-21309, 100) + unmarshalHelper2(t, "-213.09", &a, &overwritten, eq) + b.SetFrac64(8192, 1000) + unmarshalHelper2(t, "8.192", &a, &overwritten, eq) + unmarshalHelper2(t, "+8.192", &a, &overwritten, eq) + b.SetInt64(0) + unmarshalHelper2(t, "0", &a, &overwritten, eq) + unmarshalHelper2(t, "+0", &a, &overwritten, eq) + unmarshalHelper2(t, "-0", &a, &overwritten, eq) + b.SetInt64(-19487135) + unmarshalHelper2(t, "-19487135", &a, &overwritten, eq) +} + +func TestMarshalDate(t *testing.T) { + var d ofxgo.Date + UTC := time.FixedZone("UTC", 0) + GMT_nodesc := time.FixedZone("", 0) + EST := time.FixedZone("EST", -5*60*60) + NPT := time.FixedZone("NPT", (5*60+45)*60) + IST := time.FixedZone("IST", (5*60+30)*60) + NST := time.FixedZone("NST", -(3*60+30)*60) + + d = ofxgo.Date(time.Date(2017, 3, 14, 15, 9, 26, 53*1000*1000, NPT)) + marshalHelper(t, "20170314150926.053[5.75:NPT]", &d) + d = ofxgo.Date(time.Date(2017, 3, 14, 15, 9, 26, 53*1000*1000, EST)) + marshalHelper(t, "20170314150926.053[-5:EST]", &d) + d = ofxgo.Date(time.Date(2017, 3, 14, 15, 9, 26, 53*1000*1000, UTC)) + marshalHelper(t, "20170314150926.053[0:UTC]", &d) + d = ofxgo.Date(time.Date(2017, 3, 14, 15, 9, 26, 53*1000*1000, IST)) + marshalHelper(t, "20170314150926.053[5.50:IST]", &d) + d = ofxgo.Date(time.Date(9999, 11, 1, 23, 59, 59, 1000, EST)) + marshalHelper(t, "99991101235959.000[-5:EST]", &d) + d = ofxgo.Date(time.Date(0, 1, 1, 0, 0, 0, 0, IST)) + marshalHelper(t, "00000101000000.000[5.50:IST]", &d) + d = ofxgo.Date(time.Unix(0, 0).In(UTC)) + marshalHelper(t, "19700101000000.000[0:UTC]", &d) + d = ofxgo.Date(time.Date(2017, 3, 14, 0, 0, 26, 53*1000*1000, EST)) + marshalHelper(t, "20170314000026.053[-5:EST]", &d) + d = ofxgo.Date(time.Date(2017, 3, 14, 0, 0, 26, 53*1000*1000, NST)) + marshalHelper(t, "20170314000026.053[-3.50:NST]", &d) + + // Time zone without textual description + d = ofxgo.Date(time.Date(2017, 3, 14, 15, 9, 26, 53*1000*1000, GMT_nodesc)) + marshalHelper(t, "20170314150926.053[0]", &d) +} + +func TestUnmarshalDate(t *testing.T) { + var d, overwritten ofxgo.Date + GMT := time.FixedZone("GMT", 0) + EST := time.FixedZone("EST", -5*60*60) + NPT := time.FixedZone("NPT", (5*60+45)*60) + IST := time.FixedZone("IST", (5*60+30)*60) + NST := time.FixedZone("NST", -(3*60+30)*60) + NST_nodesc := time.FixedZone("", -(3*60+30)*60) + + // Ensure omitted fields default to the correct values + d = ofxgo.Date(time.Date(2017, 3, 14, 15, 9, 26, 53*1000*1000, GMT)) + unmarshalHelper(t, "20170314150926.053[0]", &d, &overwritten) + unmarshalHelper(t, "20170314150926.053", &d, &overwritten) + d = ofxgo.Date(time.Date(2017, 3, 14, 0, 0, 0, 0, GMT)) + unmarshalHelper(t, "20170314", &d, &overwritten) + + // Ensure all signs on time zone offsets are properly handled + d = ofxgo.Date(time.Date(2017, 3, 14, 15, 9, 26, 53*1000*1000, GMT)) + unmarshalHelper(t, "20170314150926.053[0:GMT]", &d, &overwritten) + unmarshalHelper(t, "20170314150926.053[+0:GMT]", &d, &overwritten) + unmarshalHelper(t, "20170314150926.053[-0:GMT]", &d, &overwritten) + unmarshalHelper(t, "20170314150926.053[0]", &d, &overwritten) + unmarshalHelper(t, "20170314150926.053[+0]", &d, &overwritten) + unmarshalHelper(t, "20170314150926.053[-0]", &d, &overwritten) + + d = ofxgo.Date(time.Date(2017, 3, 14, 15, 9, 26, 53*1000*1000, NPT)) + unmarshalHelper(t, "20170314150926.053[5.75:NPT]", &d, &overwritten) + d = ofxgo.Date(time.Date(2017, 3, 14, 15, 9, 26, 53*1000*1000, EST)) + unmarshalHelper(t, "20170314150926.053[-5:EST]", &d, &overwritten) + d = ofxgo.Date(time.Date(2017, 3, 14, 15, 9, 26, 53*1000*1000, GMT)) + unmarshalHelper(t, "20170314150926.053[0:GMT]", &d, &overwritten) + d = ofxgo.Date(time.Date(2017, 3, 14, 15, 9, 26, 53*1000*1000, IST)) + unmarshalHelper(t, "20170314150926.053[5.50:IST]", &d, &overwritten) + d = ofxgo.Date(time.Date(2018, 11, 1, 23, 59, 58, 0, EST)) + unmarshalHelper(t, "20181101235958.000[-5:EST]", &d, &overwritten) + d = ofxgo.Date(time.Date(0, 1, 1, 0, 0, 0, 0, IST)) + unmarshalHelper(t, "00000101000000.000[5.50:IST]", &d, &overwritten) + d = ofxgo.Date(time.Unix(0, 0).In(GMT)) + unmarshalHelper(t, "19700101000000.000[0:GMT]", &d, &overwritten) + d = ofxgo.Date(time.Date(2017, 3, 14, 0, 0, 26, 53*1000*1000, EST)) + unmarshalHelper(t, "20170314000026.053[-5:EST]", &d, &overwritten) + d = ofxgo.Date(time.Date(2017, 3, 14, 0, 0, 26, 53*1000*1000, NST)) + unmarshalHelper(t, "20170314000026.053[-3.50:NST]", &d, &overwritten) + + // Autopopulate zone without textual description for GMT + d = ofxgo.Date(time.Date(2017, 3, 14, 15, 9, 26, 53*1000*1000, GMT)) + unmarshalHelper(t, "20170314150926.053[0]", &d, &overwritten) + // but not for others: + d = ofxgo.Date(time.Date(2017, 3, 14, 0, 0, 26, 53*1000*1000, NST_nodesc)) + unmarshalHelper(t, "20170314000026.053[-3.50]", &d, &overwritten) +} + +func TestMarshalString(t *testing.T) { + var s ofxgo.String = "" + marshalHelper(t, "", &s) + s = "foo&bar" + marshalHelper(t, "foo&bar", &s) + s = "\n" + marshalHelper(t, " ", &s) + s = "Some Name" + marshalHelper(t, "Some Name", &s) +} + +func TestUnmarshalString(t *testing.T) { + var s, overwritten ofxgo.String = "", "" + unmarshalHelper(t, "", &s, &overwritten) + s = "foo&bar" + unmarshalHelper(t, "foo&bar", &s, &overwritten) + // whitespace intentionally stripped because some OFX servers add newlines + // inside tags + s = "new\nline" + unmarshalHelper(t, " new line ", &s, &overwritten) + s = "Some Name" + unmarshalHelper(t, "Some Name", &s, &overwritten) +} + +func TestMarshalBoolean(t *testing.T) { + var b ofxgo.Boolean = true + marshalHelper(t, "Y", &b) + b = false + marshalHelper(t, "N", &b) +} + +func TestUnmarshalBoolean(t *testing.T) { + var b, overwritten ofxgo.Boolean = true, false + unmarshalHelper(t, "Y", &b, &overwritten) + b = false + unmarshalHelper(t, "N", &b, &overwritten) +} + +func TestMarshalUID(t *testing.T) { + var u ofxgo.UID = "d1cf3d3d-9ef9-4a97-b180-81706829cb04" + marshalHelper(t, "d1cf3d3d-9ef9-4a97-b180-81706829cb04", &u) +} + +func TestUnmarshalUID(t *testing.T) { + var u, overwritten ofxgo.UID = "d1cf3d3d-9ef9-4a97-b180-81706829cb04", "" + unmarshalHelper(t, "d1cf3d3d-9ef9-4a97-b180-81706829cb04", &u, &overwritten) +} + +func TestUIDValid(t *testing.T) { + var u ofxgo.UID = "d1cf3d3d-9ef9-4a97-b180-81706829cb04" + if ok, err := u.Valid(); !ok || err != nil { + t.Fatalf("UID unexpectedly failed validation\n") + } + u = "d1cf3d3d-9ef9-4a97-b180-81706829cb0" + if ok, err := u.Valid(); ok || err == nil { + t.Fatalf("UID should have failed validation because it's too short\n") + } + u = "d1cf3d3d-9ef94a97-b180-81706829cb04" + if ok, err := u.Valid(); ok || err == nil { + t.Fatalf("UID should have failed validation because it's missing hyphens\n") + } + u = "d1cf3d3d-9ef9-4a97-b180981706829cb04" + if ok, err := u.Valid(); ok || err == nil { + t.Fatalf("UID should have failed validation because its hyphens have been replaced\n") + } +}