Refactor MDM service to return a payload from check-ins (#764)

* Refactor MDM service to return a payload from check-ins

* Documented public methods, add documentation to main service interface

* Remove interface-implementing comments
This commit is contained in:
Jesse Peterson
2021-07-01 16:13:39 -07:00
committed by GitHub
parent 1dfa0a8fd0
commit 4e684a40d2
5 changed files with 32 additions and 23 deletions

View File

@@ -10,31 +10,31 @@ import (
"github.com/pkg/errors"
)
func (svc *MDMService) Checkin(ctx context.Context, event CheckinEvent) error {
func (svc *MDMService) Checkin(ctx context.Context, event CheckinEvent) ([]byte, error) {
// reject user settings at the loginwindow.
// https://github.com/micromdm/micromdm/pull/379
if event.Command.MessageType == "UserAuthenticate" {
return &rejectUserAuth{}
return nil, &rejectUserAuth{}
}
msg, err := MarshalCheckinEvent(&event)
if err != nil {
return errors.Wrap(err, "marshal checkin event")
return nil, errors.Wrap(err, "marshal checkin event")
}
topic, err := topicFromMessage(event.Command.MessageType)
if err != nil {
return errors.Wrap(err, "get checkin topic from message")
return nil, errors.Wrap(err, "get checkin topic from message")
}
if topic == AuthenticateTopic {
if err := svc.queue.Clear(ctx, event); err != nil {
return errors.Wrap(err, "clearing queue on enrollment attempt")
return nil, errors.Wrap(err, "clearing queue on enrollment attempt")
}
}
err = svc.pub.Publish(ctx, topic, msg)
return errors.Wrapf(err, "publish checkin on topic: %s", topic)
return nil, errors.Wrapf(err, "publish checkin on topic: %s", topic)
}
func topicFromMessage(messageType string) (string, error) {
@@ -74,10 +74,12 @@ type checkinRequest struct {
}
type checkinResponse struct {
Err error `plist:"error,omitempty"`
Payload []byte
Err error `plist:"error,omitempty"`
}
func (r checkinResponse) Failed() error { return r.Err }
func (r checkinResponse) Response() []byte { return r.Payload }
func (r checkinResponse) Failed() error { return r.Err }
func decodeCheckinRequest(ctx context.Context, r *http.Request) (interface{}, error) {
var cmd CheckinCommand
@@ -106,7 +108,7 @@ func decodeCheckinRequest(ctx context.Context, r *http.Request) (interface{}, er
func MakeCheckinEndpoint(svc Service) endpoint.Endpoint {
return func(ctx context.Context, request interface{}) (interface{}, error) {
req := request.(checkinRequest)
err := svc.Checkin(ctx, req.Event)
return checkinResponse{Err: err}, nil
payload, err := svc.Checkin(ctx, req.Event)
return checkinResponse{Payload: payload, Err: err}, nil
}
}

View File

@@ -6,8 +6,15 @@ import (
"github.com/micromdm/micromdm/platform/pubsub"
)
// Service describes the core MDM protocol interactions with clients.
type Service interface {
Checkin(ctx context.Context, event CheckinEvent) error
// Checkin is called for all checkin messages (such as Authenticate
// and TokenUpdate). Checkin messages return a response to the
// client in payload.
Checkin(ctx context.Context, event CheckinEvent) (payload []byte, err error)
// Acknowledge is called when a client reports a command result and
// fetches the next command. Acknowledge calls return a command in
// payload.
Acknowledge(ctx context.Context, event AcknowledgeEvent) (payload []byte, err error)
}

View File

@@ -85,32 +85,32 @@ func (mw *udidCertAuthMiddleware) Acknowledge(ctx context.Context, req mdm.Ackno
return mw.next.Acknowledge(ctx, req)
}
func (mw *udidCertAuthMiddleware) Checkin(ctx context.Context, req mdm.CheckinEvent) error {
func (mw *udidCertAuthMiddleware) Checkin(ctx context.Context, req mdm.CheckinEvent) ([]byte, error) {
// only validate device enrollments, user enrollments should be separate.
if req.Command.EnrollmentID != "" {
return mw.next.Checkin(ctx, req)
}
devcert, err := mdm.DeviceCertificateFromContext(ctx)
if err != nil {
return errors.Wrap(err, "error retrieving device certificate")
return nil, errors.Wrap(err, "error retrieving device certificate")
}
switch req.Command.MessageType {
case "Authenticate":
// unconditionally save the cert hash on Authenticate message
if err := mw.store.SaveUDIDCertHash([]byte(req.Command.UDID), hashCertRaw(devcert.Raw)); err != nil {
return err
return nil, err
}
return mw.next.Checkin(ctx, req)
case "TokenUpdate", "CheckOut":
matched, err := mw.validateUDIDCertAuth([]byte(req.Command.UDID), hashCertRaw(devcert.Raw))
if err != nil {
return err
return nil, err
}
if !matched && !mw.warnOnly {
return errors.New("device certifcate UDID mismatch")
return nil, errors.New("device certifcate UDID mismatch")
}
return mw.next.Checkin(ctx, req)
default:
return errors.Errorf("unknown checkin message type %s", req.Command.MessageType)
return nil, errors.Errorf("unknown checkin message type %s", req.Command.MessageType)
}
}

View File

@@ -58,7 +58,7 @@ func (mw removeMiddleware) Acknowledge(ctx context.Context, req mdm.AcknowledgeE
return mw.next.Acknowledge(ctx, req)
}
func (mw removeMiddleware) Checkin(ctx context.Context, req mdm.CheckinEvent) error {
func (mw removeMiddleware) Checkin(ctx context.Context, req mdm.CheckinEvent) ([]byte, error) {
return mw.next.Checkin(ctx, req)
}

View File

@@ -96,25 +96,25 @@ func (mw *verifyCertificateMiddleware) Acknowledge(ctx context.Context, req mdm.
return mw.next.Acknowledge(ctx, req)
}
func (mw *verifyCertificateMiddleware) Checkin(ctx context.Context, req mdm.CheckinEvent) error {
func (mw *verifyCertificateMiddleware) Checkin(ctx context.Context, req mdm.CheckinEvent) ([]byte, error) {
devcert, err := mdm.DeviceCertificateFromContext(ctx)
if err != nil {
return errors.Wrap(err, "error retrieving device certificate")
return nil, errors.Wrap(err, "error retrieving device certificate")
}
hasCN, err := mw.store.HasCN(devcert.Subject.CommonName, 0, devcert, false)
if err != nil {
return errors.Wrap(err, "error checking device certificate")
return nil, errors.Wrap(err, "error checking device certificate")
}
unauth_err := errors.New("unauthorized client")
if !hasCN && !mw.validateSCEPIssuer {
_ = level.Info(mw.logger).Log("err", unauth_err, "issuer", devcert.Issuer.String(), "expiration", devcert.NotAfter)
return unauth_err
return nil, unauth_err
}
if !hasCN && mw.validateSCEPIssuer {
err := mw.verifyIssuer(devcert)
if err != nil {
_ = level.Info(mw.logger).Log("err", err, "issuer", devcert.Issuer.String(), "expiration", devcert.NotAfter)
return unauth_err
return nil, unauth_err
}
}
return mw.next.Checkin(ctx, req)