Files
micromdm/internal/data/session/session_sqlite.go
Victor Vrantchan 7d75f581ec Added users sessions (#696)
Added login/logout pages and respective database methods for handling users sessions.
2020-08-08 17:15:34 -04:00

120 lines
2.6 KiB
Go

package session
import (
"context"
"errors"
"fmt"
"time"
"crawshaw.io/sqlite/sqlitex"
"micromdm.io/v2/pkg/viewer"
)
// SQLite stores user sessions in the sessions table of a SQLite database,
// drawing connections from a shared sqlitex.Pool.
type SQLite struct{ db *sqlitex.Pool }
// NewSQLite returns a SQLite session store that takes its connections from db.
func NewSQLite(db *sqlitex.Pool) *SQLite {
	store := &SQLite{}
	store.db = db
	return store
}
// CreateSession builds a new session with create(ctx), stores its ID and
// UserID in the sessions table, and returns it with CreatedAt and AccessedAt
// filled in from the values SQLite recorded for the new row.
//
// It returns context.Canceled when the pool hands out no connection
// (for example because ctx is already done).
func (d *SQLite) CreateSession(ctx context.Context) (*Session, error) {
	// Text layout of the created_at/accessed_at columns ("YYYY-MM-DD HH:MM:SS").
	// NOTE(review): assumes the columns are filled by CURRENT_TIMESTAMP (UTC);
	// confirm against the schema.
	const timestampLayout = "2006-01-02 15:04:05"

	s, err := create(ctx)
	if err != nil {
		return nil, err
	}

	conn := d.db.Get(ctx)
	if conn == nil {
		return nil, context.Canceled
	}
	defer d.db.Put(conn)

	ins := conn.Prep(`INSERT INTO sessions ( id, user_id ) VALUES ( $id, $user_id );`)
	ins.SetText("$id", s.ID)
	ins.SetText("$user_id", s.UserID)
	if _, err := ins.Step(); err != nil {
		return nil, fmt.Errorf("store new session in sqlite: %w", err)
	}

	sel := conn.Prep(`SELECT created_at, accessed_at FROM sessions WHERE id = $id`)
	sel.SetText("$id", s.ID)
	hasRow, err := sel.Step()
	if err != nil {
		return nil, fmt.Errorf("read new session from sqlite: %w", err)
	}
	// Copy the columns out and reset right away so the statement's read
	// cursor is released on every path, including the parse failures below,
	// before the connection goes back to the pool.
	var createdAt, accessedAt string
	if hasRow {
		createdAt = sel.GetText("created_at")
		accessedAt = sel.GetText("accessed_at")
	}
	if err := sel.Reset(); err != nil {
		return nil, fmt.Errorf("reset new session query: %w", err)
	}
	if !hasRow {
		// Report the missing row directly instead of a time.Parse error on "".
		return nil, fmt.Errorf("new session %q not found in sqlite", s.ID)
	}

	if s.CreatedAt, err = time.Parse(timestampLayout, createdAt); err != nil {
		return nil, fmt.Errorf("parse session created_at: %w", err)
	}
	if s.AccessedAt, err = time.Parse(timestampLayout, accessedAt); err != nil {
		return nil, fmt.Errorf("parse session accessed_at: %w", err)
	}
	return s, nil
}
// DestroySession deletes the session whose ID is carried by the viewer in ctx.
//
// It fails when ctx has no viewer or the viewer has an empty SessionID.
// Deleting an ID that has no row is not treated as an error.
// It returns context.Canceled when the pool hands out no connection.
func (d *SQLite) DestroySession(ctx context.Context) error {
	v, ok := viewer.FromContext(ctx)
	if hasSession := ok && v.SessionID != ""; !hasSession {
		return fmt.Errorf("session: missing valid viewer %v", v)
	}

	conn := d.db.Get(ctx)
	if conn == nil {
		return context.Canceled
	}
	defer d.db.Put(conn)

	del := conn.Prep(`DELETE FROM sessions WHERE id = $id;`)
	del.SetText("$id", v.SessionID)
	_, err := del.Step()
	return err
}
// FindSession sets accessed_at to CURRENT_TIMESTAMP for the session with the
// given id, then returns that session's ID, UserID, CreatedAt and AccessedAt.
//
// It returns an error when no session row has that id, and context.Canceled
// when the pool hands out no connection.
func (d *SQLite) FindSession(ctx context.Context, id string) (*Session, error) {
	// Text layout of CURRENT_TIMESTAMP values ("YYYY-MM-DD HH:MM:SS", UTC).
	// NOTE(review): assumes created_at is also stored in this format; confirm
	// against the schema.
	const timestampLayout = "2006-01-02 15:04:05"

	conn := d.db.Get(ctx)
	if conn == nil {
		return nil, context.Canceled
	}
	defer d.db.Put(conn)

	upd := conn.Prep(
		`UPDATE sessions
SET accessed_at = CURRENT_TIMESTAMP
WHERE id = $id;`)
	upd.SetText("$id", id)
	if _, err := upd.Step(); err != nil {
		return nil, fmt.Errorf("touch session in sqlite: %w", err)
	}
	if conn.Changes() == 0 {
		return nil, errors.New("unknown session.id in sqlite")
	}

	sel := conn.Prep(`SELECT created_at, accessed_at, user_id FROM sessions WHERE id = $id`)
	sel.SetText("$id", id)
	hasRow, err := sel.Step()
	if err != nil {
		return nil, fmt.Errorf("read session from sqlite: %w", err)
	}
	// Copy the columns out and reset right away so the statement's read
	// cursor is released on every path, including the parse failures below,
	// before the connection goes back to the pool.
	var createdAt, accessedAt, userID string
	if hasRow {
		createdAt = sel.GetText("created_at")
		accessedAt = sel.GetText("accessed_at")
		userID = sel.GetText("user_id")
	}
	if err := sel.Reset(); err != nil {
		return nil, fmt.Errorf("reset session query: %w", err)
	}
	if !hasRow {
		// The UPDATE and SELECT are not in one transaction, so the row may be
		// deleted through another pooled connection in between.
		return nil, errors.New("unknown session.id in sqlite")
	}

	s := &Session{ID: id, UserID: userID}
	if s.CreatedAt, err = time.Parse(timestampLayout, createdAt); err != nil {
		return nil, fmt.Errorf("parse session created_at: %w", err)
	}
	if s.AccessedAt, err = time.Parse(timestampLayout, accessedAt); err != nil {
		return nil, fmt.Errorf("parse session accessed_at: %w", err)
	}
	return s, nil
}