diff --git a/mdm/checkin.go b/mdm/checkin.go index b11adbb0..7e9f945d 100644 --- a/mdm/checkin.go +++ b/mdm/checkin.go @@ -10,31 +10,31 @@ import ( "github.com/pkg/errors" ) -func (svc *MDMService) Checkin(ctx context.Context, event CheckinEvent) error { +func (svc *MDMService) Checkin(ctx context.Context, event CheckinEvent) ([]byte, error) { // reject user settings at the loginwindow. // https://github.com/micromdm/micromdm/pull/379 if event.Command.MessageType == "UserAuthenticate" { - return &rejectUserAuth{} + return nil, &rejectUserAuth{} } msg, err := MarshalCheckinEvent(&event) if err != nil { - return errors.Wrap(err, "marshal checkin event") + return nil, errors.Wrap(err, "marshal checkin event") } topic, err := topicFromMessage(event.Command.MessageType) if err != nil { - return errors.Wrap(err, "get checkin topic from message") + return nil, errors.Wrap(err, "get checkin topic from message") } if topic == AuthenticateTopic { if err := svc.queue.Clear(ctx, event); err != nil { - return errors.Wrap(err, "clearing queue on enrollment attempt") + return nil, errors.Wrap(err, "clearing queue on enrollment attempt") } } err = svc.pub.Publish(ctx, topic, msg) - return errors.Wrapf(err, "publish checkin on topic: %s", topic) + return nil, errors.Wrapf(err, "publish checkin on topic: %s", topic) } func topicFromMessage(messageType string) (string, error) { @@ -74,10 +74,12 @@ type checkinRequest struct { } type checkinResponse struct { - Err error `plist:"error,omitempty"` + Payload []byte + Err error `plist:"error,omitempty"` } -func (r checkinResponse) Failed() error { return r.Err } +func (r checkinResponse) Response() []byte { return r.Payload } +func (r checkinResponse) Failed() error { return r.Err } func decodeCheckinRequest(ctx context.Context, r *http.Request) (interface{}, error) { var cmd CheckinCommand @@ -106,7 +108,7 @@ func decodeCheckinRequest(ctx context.Context, r *http.Request) (interface{}, er func MakeCheckinEndpoint(svc Service) endpoint.Endpoint { return func(ctx context.Context, request interface{}) (interface{}, error) { req := request.(checkinRequest) - err := svc.Checkin(ctx, req.Event) - return checkinResponse{Err: err}, nil + payload, err := svc.Checkin(ctx, req.Event) + return checkinResponse{Payload: payload, Err: err}, nil } } diff --git a/mdm/service.go b/mdm/service.go index 64958ebd..c9576c73 100644 --- a/mdm/service.go +++ b/mdm/service.go @@ -6,8 +6,15 @@ import ( "github.com/micromdm/micromdm/platform/pubsub" ) +// Service describes the core MDM protocol interactions with clients. type Service interface { - Checkin(ctx context.Context, event CheckinEvent) error + // Checkin is called for all checkin messages (such as Authenticate + // and TokenUpdate). Checkin messages return a response to the + // client in payload. + Checkin(ctx context.Context, event CheckinEvent) (payload []byte, err error) + // Acknowledge is called when a client reports a command result and + // fetches the next command. Acknowledge calls return a command in + // payload. Acknowledge(ctx context.Context, event AcknowledgeEvent) (payload []byte, err error) } diff --git a/platform/device/udidauth.go b/platform/device/udidauth.go index a0dc82b4..bdc17aa6 100644 --- a/platform/device/udidauth.go +++ b/platform/device/udidauth.go @@ -85,32 +85,32 @@ func (mw *udidCertAuthMiddleware) Acknowledge(ctx context.Context, req mdm.Ackno return mw.next.Acknowledge(ctx, req) } -func (mw *udidCertAuthMiddleware) Checkin(ctx context.Context, req mdm.CheckinEvent) error { +func (mw *udidCertAuthMiddleware) Checkin(ctx context.Context, req mdm.CheckinEvent) ([]byte, error) { // only validate device enrollments, user enrollments should be separate. if req.Command.EnrollmentID != "" { return mw.next.Checkin(ctx, req) } devcert, err := mdm.DeviceCertificateFromContext(ctx) if err != nil { - return errors.Wrap(err, "error retrieving device certificate") + return nil, errors.Wrap(err, "error retrieving device certificate") } switch req.Command.MessageType { case "Authenticate": // unconditionally save the cert hash on Authenticate message if err := mw.store.SaveUDIDCertHash([]byte(req.Command.UDID), hashCertRaw(devcert.Raw)); err != nil { - return err + return nil, err } return mw.next.Checkin(ctx, req) case "TokenUpdate", "CheckOut": matched, err := mw.validateUDIDCertAuth([]byte(req.Command.UDID), hashCertRaw(devcert.Raw)) if err != nil { - return err + return nil, err } if !matched && !mw.warnOnly { - return errors.New("device certifcate UDID mismatch") + return nil, errors.New("device certifcate UDID mismatch") } return mw.next.Checkin(ctx, req) default: - return errors.Errorf("unknown checkin message type %s", req.Command.MessageType) + return nil, errors.Errorf("unknown checkin message type %s", req.Command.MessageType) } } diff --git a/platform/remove/remove.go b/platform/remove/remove.go index b2b6e78a..9775584a 100644 --- a/platform/remove/remove.go +++ b/platform/remove/remove.go @@ -58,7 +58,7 @@ func (mw removeMiddleware) Acknowledge(ctx context.Context, req mdm.AcknowledgeE return mw.next.Acknowledge(ctx, req) } -func (mw removeMiddleware) Checkin(ctx context.Context, req mdm.CheckinEvent) error { +func (mw removeMiddleware) Checkin(ctx context.Context, req mdm.CheckinEvent) ([]byte, error) { return mw.next.Checkin(ctx, req) } diff --git a/server/devicecert.go b/server/devicecert.go index 3e64b6ee..61a7ba62 100644 --- a/server/devicecert.go +++ b/server/devicecert.go @@ -96,25 +96,25 @@ func (mw *verifyCertificateMiddleware) Acknowledge(ctx context.Context, req mdm. return mw.next.Acknowledge(ctx, req) } -func (mw *verifyCertificateMiddleware) Checkin(ctx context.Context, req mdm.CheckinEvent) error { +func (mw *verifyCertificateMiddleware) Checkin(ctx context.Context, req mdm.CheckinEvent) ([]byte, error) { devcert, err := mdm.DeviceCertificateFromContext(ctx) if err != nil { - return errors.Wrap(err, "error retrieving device certificate") + return nil, errors.Wrap(err, "error retrieving device certificate") } hasCN, err := mw.store.HasCN(devcert.Subject.CommonName, 0, devcert, false) if err != nil { - return errors.Wrap(err, "error checking device certificate") + return nil, errors.Wrap(err, "error checking device certificate") } unauth_err := errors.New("unauthorized client") if !hasCN && !mw.validateSCEPIssuer { _ = level.Info(mw.logger).Log("err", unauth_err, "issuer", devcert.Issuer.String(), "expiration", devcert.NotAfter) - return unauth_err + return nil, unauth_err } if !hasCN && mw.validateSCEPIssuer { err := mw.verifyIssuer(devcert) if err != nil { _ = level.Info(mw.logger).Log("err", err, "issuer", devcert.Issuer.String(), "expiration", devcert.NotAfter) - return unauth_err + return nil, unauth_err } } return mw.next.Checkin(ctx, req)