Files
micromdm/internal/data/user/user_postgres.go
Victor Vrantchan 2f601dd30d Add FindUser method.
Returns the user by id.
2021-02-20 21:53:58 -05:00

125 lines
2.7 KiB
Go

package user
import (
"context"
"errors"
"fmt"
"strings"
"github.com/jackc/pgconn"
"github.com/jackc/pgx/v4"
"github.com/jackc/pgx/v4/pgxpool"
)
// Postgres provides methods for creating users in PostgreSQL.
type Postgres struct{ db *pgxpool.Pool }
// NewPostgres creates a Postgres client.
func NewPostgres(db *pgxpool.Pool) *Postgres {
return &Postgres{db: db}
}
// CreateUser creates a new user in the database.
func (d *Postgres) CreateUser(ctx context.Context, username, email, password string) (*User, error) {
u, err := create(username, email, password)
if err != nil {
return nil, fmt.Errorf("create postgres user: %w", err)
}
q :=
`INSERT INTO users (
id, username, email, password, salt, confirmation_hash
) VALUES (
$1, $2, $3, $4, $5, $6
) RETURNING "created_at", "updated_at";`
if err := d.db.QueryRow(ctx, q,
u.ID,
u.Username,
u.Email,
u.Password,
u.Salt,
u.ConfirmationHash,
).Scan(&u.CreatedAt, &u.UpdatedAt); err != nil {
return nil, fmt.Errorf("store created user in postgres: %w", checkPostgres(err))
}
return u, nil
}
func (d *Postgres) ConfirmUser(ctx context.Context, confirmation string) error {
q :=
`UPDATE users
SET confirmation_hash = NULL
WHERE confirmation_hash = $1;`
if tag, err := d.db.Exec(ctx, q, confirmation); err != nil {
return fmt.Errorf("set postgres confirmation_hash to NULL: %w", err)
} else if tag.RowsAffected() == 0 {
return Error{missingHash: confirmation}
}
return nil
}
func (d *Postgres) FindUser(ctx context.Context, id string) (*User, error) {
u := &User{}
q := fmt.Sprintf(`SELECT %s FROM users WHERE id = $1;`, strings.Join(columns(), `, `))
if err := d.db.QueryRow(ctx, q, id).Scan(
&u.ID,
&u.Username,
&u.Email,
&u.Password,
&u.Salt,
&u.ConfirmationHash,
&u.CreatedAt,
&u.UpdatedAt,
); err != nil {
return nil, err
}
return u, nil
}
func (d *Postgres) FindUserByEmail(ctx context.Context, email string) (*User, error) {
if email == "" {
return nil, Error{invalid: constraints["chk_email_not_empty"]}
}
u := &User{}
q := fmt.Sprintf(`SELECT %s FROM users WHERE email = $1;`, strings.Join(columns(), `, `))
if err := d.db.QueryRow(ctx, q, email).Scan(
&u.ID,
&u.Username,
&u.Email,
&u.Password,
&u.Salt,
&u.ConfirmationHash,
&u.CreatedAt,
&u.UpdatedAt,
); err == pgx.ErrNoRows {
return nil, Error{missingEmail: email}
} else if err != nil {
return nil, err
}
return u, nil
}
func checkPostgres(err error) error {
var dbErr *pgconn.PgError
if !errors.As(err, &dbErr) {
return err
}
switch dbErr.Code {
case "23514":
if kv, ok := constraints[dbErr.ConstraintName]; ok {
return Error{invalid: kv}
}
}
return err
}