mirror of
https://github.com/micromdm/micromdm/
synced 2026-06-14 14:55:35 +08:00
125 lines
2.7 KiB
Go
125 lines
2.7 KiB
Go
package user
|
|
|
|
import (
|
|
"context"
|
|
"errors"
|
|
"fmt"
|
|
"strings"
|
|
|
|
"github.com/jackc/pgconn"
|
|
"github.com/jackc/pgx/v4"
|
|
"github.com/jackc/pgx/v4/pgxpool"
|
|
)
|
|
|
|
// Postgres provides methods for creating users in PostgreSQL.
|
|
type Postgres struct{ db *pgxpool.Pool }
|
|
|
|
// NewPostgres creates a Postgres client.
|
|
func NewPostgres(db *pgxpool.Pool) *Postgres {
|
|
return &Postgres{db: db}
|
|
}
|
|
|
|
// CreateUser creates a new user in the database.
|
|
func (d *Postgres) CreateUser(ctx context.Context, username, email, password string) (*User, error) {
|
|
u, err := create(username, email, password)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("create postgres user: %w", err)
|
|
}
|
|
|
|
q :=
|
|
`INSERT INTO users (
|
|
id, username, email, password, salt, confirmation_hash
|
|
) VALUES (
|
|
$1, $2, $3, $4, $5, $6
|
|
) RETURNING "created_at", "updated_at";`
|
|
|
|
if err := d.db.QueryRow(ctx, q,
|
|
u.ID,
|
|
u.Username,
|
|
u.Email,
|
|
u.Password,
|
|
u.Salt,
|
|
u.ConfirmationHash,
|
|
).Scan(&u.CreatedAt, &u.UpdatedAt); err != nil {
|
|
return nil, fmt.Errorf("store created user in postgres: %w", checkPostgres(err))
|
|
}
|
|
|
|
return u, nil
|
|
}
|
|
|
|
func (d *Postgres) ConfirmUser(ctx context.Context, confirmation string) error {
|
|
q :=
|
|
`UPDATE users
|
|
SET confirmation_hash = NULL
|
|
WHERE confirmation_hash = $1;`
|
|
|
|
if tag, err := d.db.Exec(ctx, q, confirmation); err != nil {
|
|
return fmt.Errorf("set postgres confirmation_hash to NULL: %w", err)
|
|
} else if tag.RowsAffected() == 0 {
|
|
return Error{missingHash: confirmation}
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func (d *Postgres) FindUser(ctx context.Context, id string) (*User, error) {
|
|
|
|
u := &User{}
|
|
q := fmt.Sprintf(`SELECT %s FROM users WHERE id = $1;`, strings.Join(columns(), `, `))
|
|
if err := d.db.QueryRow(ctx, q, id).Scan(
|
|
&u.ID,
|
|
&u.Username,
|
|
&u.Email,
|
|
&u.Password,
|
|
&u.Salt,
|
|
&u.ConfirmationHash,
|
|
&u.CreatedAt,
|
|
&u.UpdatedAt,
|
|
); err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return u, nil
|
|
}
|
|
|
|
func (d *Postgres) FindUserByEmail(ctx context.Context, email string) (*User, error) {
|
|
if email == "" {
|
|
return nil, Error{invalid: constraints["chk_email_not_empty"]}
|
|
}
|
|
|
|
u := &User{}
|
|
q := fmt.Sprintf(`SELECT %s FROM users WHERE email = $1;`, strings.Join(columns(), `, `))
|
|
if err := d.db.QueryRow(ctx, q, email).Scan(
|
|
&u.ID,
|
|
&u.Username,
|
|
&u.Email,
|
|
&u.Password,
|
|
&u.Salt,
|
|
&u.ConfirmationHash,
|
|
&u.CreatedAt,
|
|
&u.UpdatedAt,
|
|
); err == pgx.ErrNoRows {
|
|
return nil, Error{missingEmail: email}
|
|
} else if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return u, nil
|
|
}
|
|
|
|
func checkPostgres(err error) error {
|
|
var dbErr *pgconn.PgError
|
|
if !errors.As(err, &dbErr) {
|
|
return err
|
|
}
|
|
|
|
switch dbErr.Code {
|
|
case "23514":
|
|
if kv, ok := constraints[dbErr.ConstraintName]; ok {
|
|
return Error{invalid: kv}
|
|
}
|
|
}
|
|
|
|
return err
|
|
}
|