package handlers

//go:generate make

import (
	"encoding/json"
	"errors"
	"log"
	"net/http"
	"net/url"
	"strconv"
	"strings"

	"gopkg.in/gorp.v1"
)

// Security type identifiers stored in Security.Type.
const (
	Currency int64 = 1
	Stock          = 2
)

// GetSecurityType maps a case-insensitive type name ("currency" or "stock")
// to its security type constant. It returns 0 for any other string.
func GetSecurityType(typestring string) int64 {
	switch {
	case strings.EqualFold(typestring, "currency"):
		return Currency
	case strings.EqualFold(typestring, "stock"):
		return Stock
	default:
		return 0
	}
}

// Security describes a currency or stock owned by a user.
type Security struct {
	SecurityId  int64
	UserId      int64
	Name        string
	Description string
	Symbol      string
	// Number of decimal digits (to the right of the decimal point) this
	// security is precise to
	Precision int
	Type      int64
	// AlternateId is CUSIP for Type=Stock, ISO4217 for Type=Currency
	// (FindCurrencyTemplate compares it against the decimal form of the
	// numeric ISO 4217 code).
	AlternateId string
}

// SecurityList is the JSON envelope for a list of securities.
type SecurityList struct {
	Securities *[]*Security `json:"securities"`
}

// Read decodes a JSON-encoded Security from json_str into s.
func (s *Security) Read(json_str string) error {
	dec := json.NewDecoder(strings.NewReader(json_str))
	return dec.Decode(s)
}

// Write encodes s as JSON to w.
func (s *Security) Write(w http.ResponseWriter) error {
	enc := json.NewEncoder(w)
	return enc.Encode(s)
}

// Read decodes a JSON-encoded SecurityList from json_str into sl.
func (sl *SecurityList) Read(json_str string) error {
	dec := json.NewDecoder(strings.NewReader(json_str))
	return dec.Decode(sl)
}

// Write encodes sl as JSON to w.
func (sl *SecurityList) Write(w http.ResponseWriter) error {
	enc := json.NewEncoder(w)
	return enc.Encode(sl)
}

// SearchSecurityTemplates returns the entries of SecurityTemplates whose Name,
// Description or Symbol contains search, compared case-insensitively via
// strings.ToUpper. If _type is non-zero, only templates of that type match.
//
// A negative limit (conventionally -1) means "no limit"; otherwise at most
// limit results are returned, so a limit of 0 returns no results.
//
// The returned pointers point into SecurityTemplates itself, so modifying
// them modifies the shared templates.
func SearchSecurityTemplates(search string, _type int64, limit int64) []*Security {
	upperSearch := strings.ToUpper(search)
	var results []*Security
	for i := range SecurityTemplates {
		// Check the limit before matching so that limit == 0 yields nothing.
		if limit >= 0 && int64(len(results)) >= limit {
			break
		}
		tmpl := &SecurityTemplates[i]
		if _type != 0 && _type != tmpl.Type {
			continue
		}
		if strings.Contains(strings.ToUpper(tmpl.Name), upperSearch) ||
			strings.Contains(strings.ToUpper(tmpl.Description), upperSearch) ||
			strings.Contains(strings.ToUpper(tmpl.Symbol), upperSearch) {
			results = append(results, tmpl)
		}
	}
	return results
}

// FindSecurityTemplate returns the first template whose Name equals name
// exactly and whose Type equals _type, or nil if there is none.
func FindSecurityTemplate(name string, _type int64) *Security {
	for _, security := range SecurityTemplates {
		if name == security.Name && _type == security.Type {
			// The range variable is a copy, so the caller gets its own value
			// rather than a pointer into the shared templates.
			return &security
		}
	}
	return nil
}

// FindCurrencyTemplate returns the currency template whose AlternateId is the
// decimal string form of iso4217, or nil if there is none.
func FindCurrencyTemplate(iso4217 int64) *Security {
	iso4217string := strconv.FormatInt(iso4217, 10)
	for _, security := range SecurityTemplates {
		if security.Type == Currency && security.AlternateId == iso4217string {
			// Copy of the template, not a pointer into the shared templates.
			return &security
		}
	}
	return nil
}

// GetSecurity loads the security with the given ID, provided it belongs to
// userid; otherwise it returns the error from SelectOne.
func GetSecurity(db *DB, securityid int64, userid int64) (*Security, error) {
	var s Security
	err := db.SelectOne(&s, "SELECT * from securities where UserId=? AND SecurityId=?", userid, securityid)
	if err != nil {
		return nil, err
	}
	return &s, nil
}

// GetSecurityTx is GetSecurity executed within transaction.
func GetSecurityTx(transaction *gorp.Transaction, securityid int64, userid int64) (*Security, error) {
	var s Security
	err := transaction.SelectOne(&s, "SELECT * from securities where UserId=? AND SecurityId=?", userid, securityid)
	if err != nil {
		return nil, err
	}
	return &s, nil
}

// GetSecurities returns all securities owned by userid.
func GetSecurities(db *DB, userid int64) (*[]*Security, error) {
	var securities []*Security
	_, err := db.Select(&securities, "SELECT * from securities where UserId=?", userid)
	if err != nil {
		return nil, err
	}
	return &securities, nil
}

// InsertSecurity inserts s into the database.
func InsertSecurity(db *DB, s *Security) error {
	return db.Insert(s)
}

// InsertSecurityTx inserts s within transaction.
func InsertSecurityTx(transaction *gorp.Transaction, s *Security) error {
	return transaction.Insert(s)
}

// UpdateSecurity stores s in a single transaction.
//
// The existing row with s.SecurityId must already belong to s.UserId. gorp's
// Update matches rows by the table's key column(s) (presumably SecurityId; the
// table mapping is defined elsewhere), not by UserId, so without this check a
// caller could overwrite another user's security. A security that is the
// user's default currency may not be changed to a non-currency type.
func UpdateSecurity(db *DB, s *Security) error {
	transaction, err := db.Begin()
	if err != nil {
		return err
	}

	// Refuse to touch a security that this user does not already own.
	if _, err := GetSecurityTx(transaction, s.SecurityId, s.UserId); err != nil {
		transaction.Rollback()
		return err
	}

	user, err := GetUserTx(transaction, s.UserId)
	if err != nil {
		transaction.Rollback()
		return err
	}
	if user.DefaultCurrency == s.SecurityId && s.Type != Currency {
		transaction.Rollback()
		return errors.New("Cannot change security which is user's default currency to be non-currency")
	}

	count, err := transaction.Update(s)
	if err != nil {
		transaction.Rollback()
		return err
	}
	if count != 1 {
		transaction.Rollback()
		return errors.New("Updated more than one security")
	}

	err = transaction.Commit()
	if err != nil {
		transaction.Rollback()
		return err
	}

	return nil
}

// SecurityInUseError is returned by DeleteSecurity when the security is still
// referenced and therefore cannot be deleted.
type SecurityInUseError struct {
	message string
}

func (e SecurityInUseError) Error() string {
	return e.message
}

// DeleteSecurity deletes s and every price of it or denominated in it, in a
// single transaction. It returns a SecurityInUseError if any account still
// uses s or if s is its owner's default currency.
func DeleteSecurity(db *DB, s *Security) error {
	transaction, err := db.Begin()
	if err != nil {
		return err
	}

	// First, ensure no accounts are using this security. A failed count must
	// abort the delete rather than be mistaken for "no accounts".
	accounts, err := transaction.SelectInt("SELECT count(*) from accounts where UserId=? and SecurityId=?", s.UserId, s.SecurityId)
	if err != nil {
		transaction.Rollback()
		return err
	}
	if accounts != 0 {
		transaction.Rollback()
		return SecurityInUseError{"One or more accounts still use this security"}
	}

	user, err := GetUserTx(transaction, s.UserId)
	if err != nil {
		transaction.Rollback()
		return err
	}
	if user.DefaultCurrency == s.SecurityId {
		transaction.Rollback()
		return SecurityInUseError{"Cannot delete security which is user's default currency"}
	}

	// Remove all prices involving this security (either of this security, or
	// using it as a currency)
	_, err = transaction.Exec("DELETE FROM prices WHERE SecurityId=? OR CurrencyId=?", s.SecurityId, s.SecurityId)
	if err != nil {
		transaction.Rollback()
		return err
	}

	count, err := transaction.Delete(s)
	if err != nil {
		transaction.Rollback()
		return err
	}
	if count != 1 {
		transaction.Rollback()
		return errors.New("Deleted more than one security")
	}

	err = transaction.Commit()
	if err != nil {
		transaction.Rollback()
		return err
	}

	return nil
}

// ImportGetCreateSecurity returns an existing security owned by userid that
// best matches security, or inserts security as a new one and returns it.
// security.UserId is always overwritten with userid.
//
// Candidates must match Type, AlternateId and Precision exactly. Among them,
// the preferences are, in order:
//   - a case-insensitive exact match on Name or Symbol;
//   - a case-insensitive substring match, in either direction, on Name or
//     Symbol;
//   - the first candidate the query returned.
//
// If security has no AlternateId, or there are no candidates, security is
// inserted.
func ImportGetCreateSecurity(transaction *gorp.Transaction, userid int64, security *Security) (*Security, error) {
	security.UserId = userid
	if len(security.AlternateId) == 0 {
		// Always create a new local security if we can't match on the AlternateId
		if err := InsertSecurityTx(transaction, security); err != nil {
			return nil, err
		}
		return security, nil
	}

	var securities []*Security
	_, err := transaction.Select(&securities, "SELECT * from securities where UserId=? AND Type=? AND AlternateId=? AND Precision=?", userid, security.Type, security.AlternateId, security.Precision)
	if err != nil {
		return nil, err
	}

	// First try to find a case insensitive match on the name or symbol
	upperName := strings.ToUpper(security.Name)
	upperSymbol := strings.ToUpper(security.Symbol)
	for _, s := range securities {
		if (len(s.Name) > 0 && strings.ToUpper(s.Name) == upperName) ||
			(len(s.Symbol) > 0 && strings.ToUpper(s.Symbol) == upperSymbol) {
			return s, nil
		}
	}

	// Try to find a partial string match on the name or symbol
	for _, s := range securities {
		sUpperName := strings.ToUpper(s.Name)
		sUpperSymbol := strings.ToUpper(s.Symbol)
		if (len(upperName) > 0 && len(s.Name) > 0 &&
			(strings.Contains(upperName, sUpperName) || strings.Contains(sUpperName, upperName))) ||
			(len(upperSymbol) > 0 && len(s.Symbol) > 0 &&
				(strings.Contains(upperSymbol, sUpperSymbol) || strings.Contains(sUpperSymbol, upperSymbol))) {
			return s, nil
		}
	}

	// Give up and return the first security in the list
	if len(securities) > 0 {
		return securities[0], nil
	}

	// If there wasn't even one security in the list, make a new one
	if err := InsertSecurityTx(transaction, security); err != nil {
		return nil, err
	}

	return security, nil
}

// SecurityHandler serves security requests for the signed-in user:
//   - POST creates a security from the JSON in form field "security" and
//     responds 201 with the stored security.
//   - GET with a "/security/%d" path returns that security; any other GET
//     returns all of the user's securities as a SecurityList.
//   - PUT, with the ID in the URL path (parsed by GetURLID), replaces that
//     security with the JSON in form field "security", whose SecurityId must
//     match the path ID.
//   - DELETE, with the ID in the URL path, deletes that security.
//
// Failures are reported via WriteError codes 1 (not signed in), 3 (invalid
// request), 7 (security in use) and 999 (internal error).
func SecurityHandler(w http.ResponseWriter, r *http.Request, db *DB) {
	user, err := GetUserFromSession(db, r)
	if err != nil {
		WriteError(w, 1 /*Not Signed In*/)
		return
	}

	switch r.Method {
	case "POST":
		securityJSON := r.PostFormValue("security")
		if securityJSON == "" {
			WriteError(w, 3 /*Invalid Request*/)
			return
		}

		var security Security
		if err := security.Read(securityJSON); err != nil {
			WriteError(w, 3 /*Invalid Request*/)
			return
		}
		// Ignore any client-supplied ID/owner; the insert assigns the ID.
		security.SecurityId = -1
		security.UserId = user.UserId

		if err := InsertSecurity(db, &security); err != nil {
			WriteError(w, 999 /*Internal Error*/)
			log.Print(err)
			return
		}

		w.WriteHeader(http.StatusCreated)
		if err := security.Write(w); err != nil {
			WriteError(w, 999 /*Internal Error*/)
			log.Print(err)
			return
		}

	case "GET":
		var securityid int64
		n, err := GetURLPieces(r.URL.Path, "/security/%d", &securityid)

		if err != nil || n != 1 {
			//Return all securities
			var sl SecurityList
			securities, err := GetSecurities(db, user.UserId)
			if err != nil {
				WriteError(w, 999 /*Internal Error*/)
				log.Print(err)
				return
			}
			sl.Securities = securities
			if err := (&sl).Write(w); err != nil {
				WriteError(w, 999 /*Internal Error*/)
				log.Print(err)
				return
			}
			return
		}

		security, err := GetSecurity(db, securityid, user.UserId)
		if err != nil {
			WriteError(w, 3 /*Invalid Request*/)
			return
		}
		if err := security.Write(w); err != nil {
			WriteError(w, 999 /*Internal Error*/)
			log.Print(err)
			return
		}

	default:
		securityid, err := GetURLID(r.URL.Path)
		if err != nil {
			WriteError(w, 3 /*Invalid Request*/)
			return
		}

		switch r.Method {
		case "PUT":
			securityJSON := r.PostFormValue("security")
			if securityJSON == "" {
				WriteError(w, 3 /*Invalid Request*/)
				return
			}
			var security Security
			if err := security.Read(securityJSON); err != nil || security.SecurityId != securityid {
				WriteError(w, 3 /*Invalid Request*/)
				return
			}
			// Only the owner may update a security; treat any other ID as an
			// invalid request, like the DELETE path does.
			if _, err := GetSecurity(db, securityid, user.UserId); err != nil {
				WriteError(w, 3 /*Invalid Request*/)
				return
			}
			security.UserId = user.UserId

			if err := UpdateSecurity(db, &security); err != nil {
				WriteError(w, 999 /*Internal Error*/)
				log.Print(err)
				return
			}

			if err := security.Write(w); err != nil {
				WriteError(w, 999 /*Internal Error*/)
				log.Print(err)
				return
			}

		case "DELETE":
			security, err := GetSecurity(db, securityid, user.UserId)
			if err != nil {
				WriteError(w, 3 /*Invalid Request*/)
				return
			}

			if err := DeleteSecurity(db, security); err != nil {
				if _, ok := err.(SecurityInUseError); ok {
					WriteError(w, 7 /*In Use Error*/)
				} else {
					WriteError(w, 999 /*Internal Error*/)
					log.Print(err)
				}
				// Do not also report success after an error response.
				return
			}

			WriteSuccess(w)

		default:
			WriteError(w, 3 /*Invalid Request*/)
		}
	}
}

// SecurityTemplateHandler answers GET requests by searching the built-in
// security templates. Optional query parameters:
//   - search: substring to match against Name, Description or Symbol;
//   - type: "currency" or "stock" (case-insensitive), anything else is an
//     invalid request;
//   - limit: maximum number of results (negative means unlimited).
//
// Any method other than GET is an invalid request.
func SecurityTemplateHandler(w http.ResponseWriter, r *http.Request) {
	if r.Method != "GET" {
		WriteError(w, 3 /*Invalid Request*/)
		return
	}

	var sl SecurityList

	// A parse error is ignored; whatever key/value pairs parsed are used.
	query, _ := url.ParseQuery(r.URL.RawQuery)

	search := query.Get("search")

	var _type int64 = 0
	if typestring := query.Get("type"); len(typestring) > 0 {
		_type = GetSecurityType(typestring)
		if _type == 0 {
			WriteError(w, 3 /*Invalid Request*/)
			return
		}
	}

	var limit int64 = -1
	if limitstring := query.Get("limit"); limitstring != "" {
		limitint, err := strconv.ParseInt(limitstring, 10, 64)
		if err != nil {
			WriteError(w, 3 /*Invalid Request*/)
			return
		}
		limit = limitint
	}

	securities := SearchSecurityTemplates(search, _type, limit)
	sl.Securities = &securities

	if err := (&sl).Write(w); err != nil {
		WriteError(w, 999 /*Internal Error*/)
		log.Print(err)
		return
	}
}