diff --git a/internal/handlers/users.go b/internal/handlers/users.go index 2925d51..99ea60a 100644 --- a/internal/handlers/users.go +++ b/internal/handlers/users.go @@ -180,6 +180,66 @@ func UpdateUser(db *DB, u *User) error { return nil } +func DeleteUser(db *DB, u *User) error { + transaction, err := db.Begin() + if err != nil { + return err + } + + count, err := transaction.Delete(u) + if err != nil { + transaction.Rollback() + return err + } + if count != 1 { + transaction.Rollback() + return fmt.Errorf("No user to delete") + } + _, err = transaction.Exec("DELETE FROM prices WHERE prices.PriceId IN (SELECT prices.PriceId FROM prices INNER JOIN securities ON prices.SecurityId=securities.SecurityId WHERE securities.UserId=?)", u.UserId) + if err != nil { + transaction.Rollback() + return err + } + _, err = transaction.Exec("DELETE FROM splits WHERE splits.SplitId IN (SELECT splits.SplitId FROM splits INNER JOIN transactions ON splits.TransactionId=transactions.TransactionId WHERE transactions.UserId=?)", u.UserId) + if err != nil { + transaction.Rollback() + return err + } + _, err = transaction.Exec("DELETE FROM transactions WHERE transactions.UserId=?", u.UserId) + if err != nil { + transaction.Rollback() + return err + } + _, err = transaction.Exec("DELETE FROM securities WHERE securities.UserId=?", u.UserId) + if err != nil { + transaction.Rollback() + return err + } + _, err = transaction.Exec("DELETE FROM accounts WHERE accounts.UserId=?", u.UserId) + if err != nil { + transaction.Rollback() + return err + } + _, err = transaction.Exec("DELETE FROM reports WHERE reports.UserId=?", u.UserId) + if err != nil { + transaction.Rollback() + return err + } + _, err = transaction.Exec("DELETE FROM sessions WHERE sessions.UserId=?", u.UserId) + if err != nil { + transaction.Rollback() + return err + } + + err = transaction.Commit() + if err != nil { + transaction.Rollback() + return err + } + + return nil +} + func UserHandler(w http.ResponseWriter, r *http.Request, db *DB) { if r.Method == "POST" { user_json := r.PostFormValue("user") @@ -278,13 +338,12 @@ func UserHandler(w http.ResponseWriter, r *http.Request, db *DB) { return } } else if r.Method == "DELETE" { - count, err := db.Delete(user) - if count != 1 || err != nil { + err := DeleteUser(db, user) + if err != nil { WriteError(w, 999 /*Internal Error*/) log.Print(err) return } - WriteSuccess(w) } }