commit 05ed593a4bc13111aa242851316d523046918500 Author: Reese Norris Date: Thu Apr 4 19:40:43 2024 -0700 initial commit diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..7986524 --- /dev/null +++ b/.gitignore @@ -0,0 +1,5 @@ +.idea/* +.DS_Store +*.db +.envrc +*.sqlite \ No newline at end of file diff --git a/Dockerfile b/Dockerfile new file mode 100644 index 0000000..ca09da7 --- /dev/null +++ b/Dockerfile @@ -0,0 +1,20 @@ +FROM golang:1.22.1-bullseye as build + +ENV CGO_ENABLED=1 +ENV GOOS=linux + +WORKDIR /build +COPY . . + +RUN go mod download +RUN go mod verify + +RUN go build -ldflags='-s -w -extldflags "-static"' -o main . + +FROM gcr.io/distroless/static-debian11 + +WORKDIR /openfsd + +COPY --from=build /build/main . + +CMD ["./main"] \ No newline at end of file diff --git a/README.md b/README.md new file mode 100644 index 0000000..d68a672 --- /dev/null +++ b/README.md @@ -0,0 +1,140 @@ +# openfsd + +**openfsd** is an experimental FSD server, implementing the `Vatsim2022` protocol revision ([Velocity](https://forums.vatsim.net/topic/32619-vatsim-announces-velocity-release-date-and-rollout-plan/)) for pilot clients. + +## About + +FSD (Flight Sim Daemon) is the software/protocol responsible for connecting home flight simulator clients to a single, shared multiplayer world on hobbyist networks such as [VATSIM](https://vatsim.net/docs/about/about-vatsim). +FSD was originally written in the late 90's by [Marty Bochane](https://github.com/kuroneko/fsd) for [SATCO](https://web.archive.org/web/20000619145015/http://www.satco.org/), later to be forked and taken closed-source by VATSIM in 2001. As of April 2024, FSD is still used to facilitate over 140,000 active members connecting their flight simulators to the [network](https://map.vatsim.net). + +## Building +``` +go mod download +go build -o fsd . +``` + +A default admin user will be printed to stdout on first startup. A simple web interface can be accessed using these credentials to add more users at `/dashboard` + +The server is configured via environment variables: + +| Variable Name | Default Value | Description | +|-----------------|---------------|------------------------------------------------------------| +| `FSD_ADDR` | 0.0.0.0:6809 | FSD listen address | +| `HTTP_ADDR` | 0.0.0.0:9086 | HTTP listen address | +| `HTTPS_ENABLED` | false | Enable HTTPS | +| `TLS_CERT_FILE` | | TLS certificate file path | +| `TLS_KEY_FILE` | | TLS key file path | +| `DATABASE_FILE` | ./fsd.db | SQLite database file path | +| `MOTD` | openfsd | Message to send on FSD client login (line feeds supported) | + +## Overview + +Various clients such as [vPilot](https://vpilot.rosscarlson.dev/), [xPilot](https://docs.xpilot-project.org/) and [swift](https://swift-project.org/) are used to connect to VATSIM FSD servers. + +This project does not currently support air traffic control clients. + +At its core, FSD is a message forwarder. Other than a few direct client/server transactions, the main purpose of the FSD server is to facilitate the passing of messages between flight simulator clients. Albeit, none of this is P2P—all messages are forwarded via a centralized server. + +### Protocol + +The protocol is entirely plaintext, and can be easily sniffed using packet capture tools such as Wireshark. FSD conventionally listens on TCP port 6809. You can use telnet to interact with an FSD server, try it out: + +``` +telnet fsd.connect.vatsim.net 6809 +``` + +The few mainstream implementations of FSD have their own nuances within their respective protocols. This project attempts to replicate VATSIM-specific behavior (which differs from other implementations such as [IVAO](https://www.ivao.aero/)). + +Throughout this project, I often refer to FSD messages as "packets" or "lines." The term "packet" is referring to the application-layer implementation of FSD, and has nothing to do with any transport or IP layers. + +FSD packets are plaintext MS-DOS-style lines that end with (CR/LF) characters. +Each line starts with a packet identifier, which is either 1 or 3 characters in length. Each field within the packet is delimited by a colon `:` character. All numerical values are represented as base-10 (or rarely base-16) encoded ASCII strings. There are no "binary" encodings in the protocol. + +Clients are addressed by their plaintext aviation callsigns, e.g. N7938C. FSD packets generally have "From" and "To" fields, where the "From" field can be thought of as a source address (or source callsign, in this case) and the "To" field being a recipient address/callsign. Depending on the packet type, the "To" field can be either a single client address (representing a "Direct Message" of sorts), or another special identifier representing several clients. + +For example, the following is a verbatim "Server Identification" message represented as a go string: + +```go +"$DISERVER:CLIENT:VATSIM FSD V3.43:d95f57db664f\r\n" +``` + +Hex representation: + +``` + 00000000 24 44 49 53 45 52 56 45 52 3a 43 4c 49 45 4e 54 $DISERVE R:CLIENT + 00000010 3a 56 41 54 53 49 4d 20 46 53 44 20 56 33 2e 34 :VATSIM FSD V3.4 + 00000020 33 3a 64 39 35 66 35 37 64 62 36 36 34 66 0d 0a 3:d95f57 db664f.. +``` + +> See the `protocol` package for packet parsers and serializers. + +### Connection Flow + +1. **Server Identification packet:** Once the TCP connection has been established, the server identifies itself with a "Server Identification" message, as seen in the example above. The packet identifier is `$DI` (Server Identification). The first field is the "From" field: `SERVER`. The second field is the "To" field: `CLIENT` (`CLIENT` is used here as a generic recipient callsign because the client hasn't announced itself yet.) The third field is the server version identifier. The fourth field is a random hexadecimal string used later for 'VatsimAuth' client verification (see fsd/vatsimauth) + + +2. **Login Token request:** Up until 2022, user's passwords were sent in plaintext over FSD (yikes.) Now, login tokens are obtained over HTTPS via `auth.vatsim.net/api/fsd-jwt` (implemented in http_server.go.) Now, logging into FSD, this token is sent in place where the old password used to be... in plaintext (still yikes.) + + +3. **Client Identification packet:** The client identifies itself in its first message: + +```go +"$IDN12345:SERVER:88e4:vPilot:3:8:1000000:5816673295:35df255c\r\n" +``` + +This includes their callsign (N12345 in this case), client information, "cert ID" aka user ID, an identifier based on the client's MAC address, and another random hexadecimal string used for 'VatsimAuth' + +4. **Add Pilot packet:** The client sends its second message containing more login information: + +```go +"#APN12345:SERVER:1000000::1:101:2:John Doe\r\n" +``` + +This includes the source callsign again, the cert ID again, the login token previously obtained, the user's requested network rating, the protocol revision they're using, their simulator type, and their real name. + + +5. **Server MOTD packet:** If all goes well for the login, the server sends a Text Message packet with the message of the day. This serves as the "login successful" message. If the login does not go well, the server will send an error packet and close the connection. + + +```go +"#TMserver:N12345:Welcome to VATSIM! Need help getting started? Visit https://vats.im/plc for excellent resources.\r\n" +``` + +VATSIM's FSD implementation sends a lowercase '`server`' identifier for this specific message. I have no idea why this is the case. Inconsistencies like this are common across FSD. + +### Post-Login + +After the client has logged in, it's generally free to send whatever it wants. This project implements logic for the following packets: + +### Login-only packets + +| Name | Identifier | Direction | Description | +|-----------------------|------------|-----------|-----------------------------------------------------------------------------------------------------------------------------------| +| Server Identification | `$DI` | S → C | Server identification | +| Client Identification | `$ID` | C → S | Client identification | +| Add Pilot | `#AP` | C ↔ S | Client login information. Broadcasted to all other clients on server with the "Token" field omitted, assuming a successful login. | + + +### Post-login packets + +| Name | Identifier | Description | +|------------------------------------|------------|-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------| +| Auth Challenge | `$ZC` | 'VatsimAuth' challenge | +| Auth Challenge Response | `$ZR` | Response for an Auth Challenge. | +| Client Query | `$CQ` | Query another client for some information. In some cases, `SERVER` can be the recipient. There are cases when this packet type can also be used to broadcast unsolicited information to several other clients. | +| Client Query Response | `$CR` | Respond to a client query. | +| Delete Pilot | `#DP` | Sent when a client would like to disconnect. Broadcasted to all other clients on server. | +| Fast Pilot Position (Fast Type) | `^` | Contains position and velocity information for the client's airplane. Broadcasted to all clients within reasonable geographical range. Sent 5 times per second. This packet is only sent when the server signals the client with a "Send Fast" packet marked as `true`, and the velocity of the airplane is > 0. | +| Fast Pilot Position (Slow Type) | `#SL` | Contains position and velocity information for the client's airplane. Broadcasted to all clients within reasonable geographical range. Sent once every 5 seconds. This packet is only sent when the server signals the client with a "Send Fast" packet marked as `false`. | +| Fast Pilot Position (Stopped Type) | `#ST` | Contains position and velocity information for the client's airplane. Broadcasted to all clients within reasonable geographical range. Sent 5 times per second. This packet is only sent when the server signals the client with a "Send Fast" packet marked as `true`, and the velocity of the airplane is zero. | +| Normal Pilot Position | `@` | Contains position and other miscellaneous airplane-related information. Broadcasted to all clients within reasonable geographical range. Sent once every 5 seconds. This packet is repeatedly broadcasted from the client throughout the entire lifetime of the connection. | +| Send Fast | `$SF` | Sent by the server to signal whether the client should send Fast Pilot Position packets 5 times per second, rather than just the standard pilot positions once per 5 seconds. | +| Error | `$ER` | Sent by the server when it encounters an error | +| Kill | `$!!` | When sent by a privileged user, it is forwarded to the "victim" client, then the victim client is kicked from the server. | +| Ping | `$PI` | Ping | +| Pong | `$PO` | Respond to a Ping | +| Plane Info Request | `#SB` | Query another client for model matching information | +| Plane Info Response | `#SB` | Respond to a Plane Info Request | +| Text Message | `#TM` | Used to send text messages. Depending on the value of the recipient field, a text message can represent any of the following:
— Radio message: a message to be broadcasted to all airplanes on a simulated VHF frequency
— Direct message: addressed to a single client
— "Wallop" message: a message that is forwarded to all users with "Supervisor" or higher privilege)
— Broadcast message: forwarded to all connected clients. | + + diff --git a/dashboard.html b/dashboard.html new file mode 100644 index 0000000..63fc831 --- /dev/null +++ b/dashboard.html @@ -0,0 +1,197 @@ + + + + + FSD + + + + + + + +

FSD

+ +
+ +
+

Get user

+
+ + + +
+ +
+
+
+

Create user

+
+ + + + + + + +
+ +
+
+ + + \ No newline at end of file diff --git a/database_util.go b/database_util.go new file mode 100644 index 0000000..97e5151 --- /dev/null +++ b/database_util.go @@ -0,0 +1,128 @@ +package main + +import ( + "database/sql" + "errors" + "golang.org/x/crypto/bcrypt" + "time" +) + +type FSDUserRecord struct { + CID int `json:"cid"` + Password string `json:"password,omitempty"` + Rating int `json:"rating"` + RealName string `json:"real_name"` + CreationTime time.Time `json:"creation_time"` +} + +func GetUserRecord(db *sql.DB, cid int) (*FSDUserRecord, error) { + row := db.QueryRow("SELECT * FROM users WHERE cid=? LIMIT 1", cid) + + var cidRecord int + var pwd string + var rating int + var realName string + var creationTime time.Time + + if err := row.Scan(&cidRecord, &pwd, &rating, &realName, &creationTime); errors.Is(err, sql.ErrNoRows) { + return nil, err + } + + return &FSDUserRecord{ + CID: cidRecord, + Password: pwd, + Rating: rating, + RealName: realName, + CreationTime: creationTime, + }, nil +} + +func AddUserRecord(db *sql.DB, cid int, pwdHash string, rating int, realName string) (*FSDUserRecord, error) { + t := time.Now() + + res, err := db.Exec("INSERT INTO users (cid, password, rating, real_name, creation_time) VALUES(?,?,?,?,?);", cid, pwdHash, rating, realName, t) + if err != nil { + return nil, err + } + + resCID, err := res.LastInsertId() + if err != nil { + return nil, err + } + + record := FSDUserRecord{ + CID: int(resCID), + Password: "", + Rating: rating, + RealName: realName, + CreationTime: t, + } + + return &record, nil +} + +func AddUserRecordSequential(db *sql.DB, pwdHash string, rating int, realName string) (*FSDUserRecord, error) { + t := time.Now() + + res, err := db.Exec("INSERT INTO users (password, rating, real_name, creation_time) VALUES(?,?,?,?);", pwdHash, rating, realName, t) + if err != nil { + return nil, err + } + + resCID, err := res.LastInsertId() + if err != nil { + return nil, err + } + + record := FSDUserRecord{ + CID: int(resCID), + Password: "", + Rating: rating, + RealName: realName, + CreationTime: t, + } + + return &record, nil +} + +// UpdateUserRecord updates a user record. It only updates the password field if it is non-empty. +func UpdateUserRecord(db *sql.DB, record *FSDUserRecord) error { + var pwdHash []byte + var err error + if record.Password != "" { + pwdHash, err = bcrypt.GenerateFromPassword([]byte(record.Password), 10) + if err != nil { + return err + } + } + + if pwdHash == nil { + _, err := db.Exec("UPDATE users SET rating=?, real_name=? WHERE cid=?;", record.Rating, record.RealName, record.CID) + if err != nil { + return err + } + } else { + _, err := db.Exec("UPDATE users SET password=?, rating=?, real_name=? WHERE cid=?;", string(pwdHash), record.Rating, record.RealName, record.CID) + if err != nil { + return err + } + } + + return nil +} + +func IsUsersTableEmpty(db *sql.DB) (bool, error) { + row := db.QueryRow("SELECT * FROM users LIMIT 1;") + + var cid int + var pwd string + var rating int + var realName string + var creationTime time.Time + + if err := row.Scan(&cid, &pwd, &rating, &realName, &creationTime); errors.Is(err, sql.ErrNoRows) { + return true, nil + } else { + return false, err + } +} diff --git a/docker-compose.yml b/docker-compose.yml new file mode 100644 index 0000000..633fd52 --- /dev/null +++ b/docker-compose.yml @@ -0,0 +1,16 @@ +services: + fsd: + build: + context: "." + ports: + - "6809:6809/tcp" + - "443:443/tcp" + volumes: + - /opt/fsd/fsd.db:/fsd/fsd.db + environment: + FSD_ADDR: "0.0.0.0:6809" + HTTP_ADDR: "0.0.0.0:9086" + HTTPS_ENABLED: false + DATABASE_FILE: "/fsd/fsd.db" + MOTD: "openfsd" + restart: always diff --git a/fsd_client.go b/fsd_client.go new file mode 100644 index 0000000..ced5b6a --- /dev/null +++ b/fsd_client.go @@ -0,0 +1,418 @@ +package main + +import ( + "bufio" + "context" + "errors" + "github.com/golang-jwt/jwt/v5" + "github.com/renorris/openfsd/protocol" + "github.com/renorris/openfsd/vatsimauth" + "log" + "net" + "strconv" + "strings" + "time" +) + +type FSDClient struct { + Conn *net.TCPConn + Reader *bufio.Reader + + AuthVerify *vatsimauth.VatsimAuth // Auth state to verify client's auth responses + PendingAuthVerifyChallenge string // Store the pending challenge sent to the client + AuthSelf *vatsimauth.VatsimAuth // Auth state for interrogating client + + Callsign string + CID int + NetworkRating int + SimulatorType int + RealName string + CurrentGeohash string + SendFastEnabled bool + + Kill chan string // Signal to disconnect this client + Mailbox chan string // Incoming messages +} + +// EventLoop runs the main event loop for an FSD client. +// All clients that reach this stage are logged in +func EventLoop(client *FSDClient) { + + // Reader goroutine + packetsRead := make(chan string) + readerCtx, cancelReader := context.WithCancel(context.Background()) + go func(ctx context.Context) { + for { + // Reset the deadline + err := client.Conn.SetReadDeadline(time.Now().Add(10 * time.Second)) + if err != nil { + close(packetsRead) + return + } + + var buf []byte + buf, err = client.Reader.ReadSlice('\n') + if err != nil { + close(packetsRead) + return + } + + packet := string(buf) + + // Validate delimiter + if len(packet) < 2 || string(packet[len(packet)-2:]) != "\r\n" { + close(packetsRead) + return + } + + // Send the packet over the channel + // (also watch for context.Done) + select { + case <-ctx.Done(): + close(packetsRead) + return + case packetsRead <- packet: + } + } + }(readerCtx) + + // Defer reader cancellation + defer cancelReader() + + // Writer goroutine + packetsToWrite := make(chan string, 16) + writerClosed := make(chan struct{}) + go func() { + for { + var packet string + var ok bool + // Wait for a packet + select { + case packet, ok = <-packetsToWrite: + if !ok { + close(writerClosed) + return + } + } + + // Reset the deadline + err := client.Conn.SetWriteDeadline(time.Now().Add(5 * time.Second)) + if err != nil { + close(writerClosed) + // Exhaust packetsToWrite + for { + select { + case _, ok := <-packetsToWrite: + if !ok { + return + } + } + } + } + + // Attempt to write the packet + _, err = client.Conn.Write([]byte(packet)) + if err != nil { + close(writerClosed) + // Exhaust packetsToWrite + for { + select { + case _, ok := <-packetsToWrite: + if !ok { + return + } + } + } + } + } + }() + + // Wait max 100 milliseconds for writer to flush before continuing the connection shutdown + defer func() { + timer := time.NewTimer(100 * time.Millisecond) + select { + case <-writerClosed: + case <-timer.C: + } + }() + + // Defer writer close + defer close(packetsToWrite) + + // Defer "delete pilot" broadcast + defer func() { + deletePilotPDU := protocol.DeletePilotPDU{ + From: client.Callsign, + CID: client.CID, + } + + mail := NewMail(client) + mail.SetType(MailTypeBroadcastAll) + mail.AddPacket(deletePilotPDU.Serialize()) + PO.SendMail([]Mail{*mail}) + }() + + // Main loop + for { + select { + case packet, ok := <-packetsRead: + // Check if the reader closed + if !ok { + return + } + + // Find the processor for this packet + processor, err := GetProcessor(packet) + if err != nil { + packetsToWrite <- protocol.NewGenericFSDError(protocol.SyntaxError).Serialize() + return + } + + result := processor(client, packet) + + // Send replies to the client + for _, replyPacket := range result.Replies { + packetsToWrite <- replyPacket + } + + // Send mail + PO.SendMail(result.Mail) + + // Disconnect the client if flagged + if result.ShouldDisconnect { + return + } + case <-writerClosed: + return + case mailPacket := <-client.Mailbox: + packetsToWrite <- mailPacket + case s, ok := <-client.Kill: + if ok { + select { + case packetsToWrite <- s: + default: + } + } + return + } + } +} + +func sendSyntaxError(conn *net.TCPConn) { + conn.Write([]byte(protocol.NewGenericFSDError(protocol.SyntaxError).Serialize())) +} + +func HandleConnection(conn *net.TCPConn) { + defer func() { + err := conn.Close() + if err != nil { + log.Printf("Error closing connection for %s\n%s", conn.RemoteAddr().String(), err.Error()) + } + }() + + initChallenge, err := vatsimauth.GenerateChallenge() + if err != nil { + log.Printf("Error generating challenge string:\n%s", err.Error()) + return + } + + serverIdentPDU := protocol.ServerIdentificationPDU{ + From: protocol.ServerCallsign, + To: "CLIENT", + Version: "openfsd", + InitialChallenge: initChallenge, + } + serverIdentPacket := serverIdentPDU.Serialize() + + // The client has 2 seconds to log in + if err = conn.SetReadDeadline(time.Now().Add(2 * time.Second)); err != nil { + return + } + if err = conn.SetWriteDeadline(time.Now().Add(2 * time.Second)); err != nil { + return + } + + _, err = conn.Write([]byte(serverIdentPacket)) + if err != nil { + return + } + + reader := bufio.NewReaderSize(conn, 1024) + var buf []byte + + buf, err = reader.ReadSlice('\n') + if err != nil { + sendSyntaxError(conn) + return + } + packet := string(buf) + + // Validate delimiter + if len(packet) < 2 || string(packet[len(packet)-2:]) != "\r\n" { + sendSyntaxError(conn) + return + } + + clientIdentPDU, err := protocol.ParseClientIdentificationPDU(packet) + if err != nil { + var fsdError *protocol.FSDError + if errors.As(err, &fsdError) { + conn.Write([]byte(fsdError.Serialize())) + } + return + } + + buf, err = reader.ReadSlice('\n') + if err != nil { + sendSyntaxError(conn) + return + } + packet = string(buf) + + // Validate delimiter + if len(packet) < 2 || string(packet[len(packet)-2:]) != "\r\n" { + sendSyntaxError(conn) + return + } + + addPilotPDU, err := protocol.ParseAddPilotPDU(packet) + if err != nil { + var fsdError *protocol.FSDError + if errors.As(err, &fsdError) { + conn.Write([]byte(fsdError.Serialize())) + } + return + } + + // Verify protocol revision + if addPilotPDU.ProtocolRevision != protocol.ProtoRevisionVatsim2022 { + conn.Write([]byte(protocol.NewGenericFSDError(protocol.InvalidProtocolRevisionError).Serialize())) + return + } + + // Verify if we support this client + _, ok := vatsimauth.Keys[clientIdentPDU.ClientID] + if !ok { + conn.Write([]byte(protocol.NewGenericFSDError(protocol.UnauthorizedSoftwareError).Serialize())) + return + } + + // Verify fields in login PDUs + if clientIdentPDU.From != addPilotPDU.From || + clientIdentPDU.CID != addPilotPDU.CID { + sendSyntaxError(conn) + return + } + + // Validate token signature + claims := jwt.MapClaims{} + token, err := jwt.ParseWithClaims(addPilotPDU.Token, &claims, func(token *jwt.Token) (interface{}, error) { + return JWTKey, nil + }) + if err != nil { + conn.Write([]byte(protocol.NewGenericFSDError(protocol.InvalidLogonError).Serialize())) + return + } + + // Check for expiry + exp, err := token.Claims.GetExpirationTime() + if err != nil { + conn.Write([]byte(protocol.NewGenericFSDError(protocol.InvalidLogonError).Serialize())) + return + } + if time.Now().After(exp.Time) { + conn.Write([]byte(protocol.NewGenericFSDError(protocol.InvalidLogonError).Serialize())) + return + } + + // Verify claimed CID + claimedCID, err := claims.GetSubject() + if err != nil { + conn.Write([]byte(protocol.NewGenericFSDError(protocol.InvalidLogonError).Serialize())) + return + } + + cidInt, err := strconv.Atoi(claimedCID) + if err != nil { + sendSyntaxError(conn) + return + } + + if cidInt != clientIdentPDU.CID { + conn.Write([]byte(protocol.NewGenericFSDError(protocol.InvalidLogonError).Serialize())) + return + } + + // Verify controller rating + claimedRating, ok := claims["controller_rating"].(float64) + if !ok { + sendSyntaxError(conn) + return + } + + if addPilotPDU.NetworkRating > int(claimedRating) { + conn.Write([]byte(protocol.NewGenericFSDError(protocol.RequestedLevelTooHighError).Serialize())) + return + } + + // Configure client + fsdClient := FSDClient{ + Conn: conn, + Reader: reader, + AuthVerify: &vatsimauth.VatsimAuth{}, + PendingAuthVerifyChallenge: "", + AuthSelf: &vatsimauth.VatsimAuth{}, + Callsign: clientIdentPDU.From, + CID: clientIdentPDU.CID, + NetworkRating: int(claimedRating), + SimulatorType: addPilotPDU.SimulatorType, + RealName: addPilotPDU.RealName, + CurrentGeohash: "", + SendFastEnabled: false, + Kill: make(chan string, 1), + Mailbox: make(chan string, 16), + } + + // Register callsign to the post office. End the connection if callsign already exists + { + err := PO.RegisterCallsign(clientIdentPDU.From, &fsdClient) + if err != nil { + if errors.Is(err, CallsignAlreadyRegisteredError) { + pdu := protocol.NewGenericFSDError(protocol.CallsignInUseError) + conn.Write([]byte(pdu.Serialize())) + } + return + } + } + defer PO.DeregisterCallsign(clientIdentPDU.From) + + // Configure vatsim auth states + fsdClient.AuthSelf = vatsimauth.NewVatsimAuth(clientIdentPDU.ClientID, vatsimauth.Keys[clientIdentPDU.ClientID]) + fsdClient.AuthSelf.SetInitialChallenge(clientIdentPDU.InitialChallenge) + fsdClient.AuthVerify = vatsimauth.NewVatsimAuth(clientIdentPDU.ClientID, vatsimauth.Keys[clientIdentPDU.ClientID]) + fsdClient.AuthVerify.SetInitialChallenge(initChallenge) + + // Broadcast AddPilot packet to network + addPilotPDU.Token = "" + mail := NewMail(&fsdClient) + mail.SetType(MailTypeBroadcastAll) + mail.AddPacket(addPilotPDU.Serialize()) + PO.SendMail([]Mail{*mail}) + + // Send MOTD + lines := strings.Split(SC.MOTD, "\n") + for _, line := range lines { + pdu := protocol.TextMessagePDU{ + From: protocol.ServerCallsign, + To: clientIdentPDU.From, + Message: line, + } + _, err := conn.Write([]byte(pdu.Serialize())) + if err != nil { + return + } + } + + // Start the event loop + EventLoop(&fsdClient) +} diff --git a/fsd_client_logic_test.go b/fsd_client_logic_test.go new file mode 100644 index 0000000..5b1e836 --- /dev/null +++ b/fsd_client_logic_test.go @@ -0,0 +1,984 @@ +package main + +import ( + "bufio" + "context" + "github.com/renorris/openfsd/protocol" + "github.com/stretchr/testify/assert" + "net" + "os" + "strings" + "sync" + "testing" + "time" +) + +type clientStruct struct { + callsign string + cid int + password string + clientID uint16 + clientName string + majorVersion int + minorVersion int + sysUID int + initialChallenge string + networkRating int + protocolRevsion int + simulatorType int + realName string + + preliminaryTestPackets []protocol.PDU // Packets to send after logging in, but before the next client logs in + preliminaryWantPackets []protocol.PDU // Expected packets to receive before the next client logs in + testPackets []protocol.PDU // Packets to send in normal post-login state + wantPackets []protocol.PDU // Expected packets to receive +} + +// TestFSDClientLogic focuses on post-login logic +func TestFSDClientLogic(t *testing.T) { + SC = &ServerConfig{ + FsdListenAddr: "localhost:6809", + HttpListenAddr: "localhost:9086", + HttpsEnabled: false, + DatabaseFile: "./test.db", + MOTD: "openfsd", + } + + // Delete any existing database file so a new one is created + os.Remove(SC.DatabaseFile) + defer os.Remove(SC.DatabaseFile) + + // Run configuration helpers and start the FSD server + configureJwt() + configurePostOffice() + configureProtocolValidator() + configureDatabase() + + // Add test users + addUserToDatabase(t, 1000000, "12345", protocol.NetworkRatingOBS) + addUserToDatabase(t, 1000001, "12345", protocol.NetworkRatingOBS) + addUserToDatabase(t, 1000002, "12345", protocol.NetworkRatingSUP) + + // Start FSD server + fsdCtx, cancelFsd := context.WithCancel(context.Background()) + go StartFSDServer(fsdCtx) + defer cancelFsd() + + // Start http server + httpCtx, cancelHttp := context.WithCancel(context.Background()) + go StartHttpServer(httpCtx) + defer cancelHttp() + time.Sleep(50 * time.Millisecond) + + tests := []struct { + testName string + clients []clientStruct + }{ + { + testName: "Ping ($PI) request", + clients: []clientStruct{ + { + callsign: "N123", + cid: 1000000, + password: "12345", + clientID: 35044, + clientName: "vPilot", + majorVersion: 3, + minorVersion: 8, + sysUID: -99999, + initialChallenge: "abcdef", + networkRating: 1, + protocolRevsion: protocol.ProtoRevisionVatsim2022, + simulatorType: 2, + realName: "John Doe", + + preliminaryTestPackets: []protocol.PDU{}, + preliminaryWantPackets: []protocol.PDU{}, + testPackets: []protocol.PDU{ + &protocol.PingPDU{ + From: "N123", + To: protocol.ServerCallsign, + Timestamp: "1234567890", + }, + }, + wantPackets: []protocol.PDU{ + &protocol.PongPDU{ + From: protocol.ServerCallsign, + To: "N123", + Timestamp: "1234567890", + }, + }, + }, + }, + }, + { + testName: "IP request ($CQCLIENT:SERVER:IP)", + clients: []clientStruct{ + { + callsign: "N123", + cid: 1000000, + password: "12345", + clientID: 35044, + clientName: "vPilot", + majorVersion: 3, + minorVersion: 8, + sysUID: -99999, + initialChallenge: "abcdef", + networkRating: 1, + protocolRevsion: protocol.ProtoRevisionVatsim2022, + simulatorType: 2, + realName: "John Doe", + + preliminaryTestPackets: []protocol.PDU{}, + preliminaryWantPackets: []protocol.PDU{}, + testPackets: []protocol.PDU{ + &protocol.ClientQueryPDU{ + From: "N123", + To: protocol.ServerCallsign, + QueryType: protocol.ClientQueryPublicIP, + Payload: "", + }, + }, + wantPackets: []protocol.PDU{ + &protocol.ClientQueryResponsePDU{ + From: protocol.ServerCallsign, + To: "N123", + QueryType: protocol.ClientQueryPublicIP, + Payload: "127.0.0.1", + }, + }, + }, + }, + }, + { + testName: "Fast pilot position broadcast", + clients: []clientStruct{ + { + callsign: "N123", + cid: 1000000, + password: "12345", + clientID: 35044, + clientName: "vPilot", + majorVersion: 3, + minorVersion: 8, + sysUID: -99999, + initialChallenge: "abcdef", + networkRating: 1, + protocolRevsion: protocol.ProtoRevisionVatsim2022, + simulatorType: 2, + realName: "John Doe", + + preliminaryTestPackets: []protocol.PDU{ + &protocol.PilotPositionPDU{ + SquawkingModeC: false, + Identing: false, + From: "N123", + SquawkCode: "1200", + NetworkRating: 1, + Lat: 45.0, + Lng: 45.0, + TrueAltitude: 50, + PressureAltitude: 50, + GroundSpeed: 0, + Pitch: 0, + Heading: 0, + Bank: 0, + }, + &protocol.PingPDU{ + From: "N123", + To: protocol.ServerCallsign, + Timestamp: "1234567890", + }, + }, + preliminaryWantPackets: []protocol.PDU{ + &protocol.PongPDU{ + From: protocol.ServerCallsign, + To: "N123", + Timestamp: "1234567890", + }, + }, + testPackets: []protocol.PDU{}, + wantPackets: []protocol.PDU{ + &protocol.AddPilotPDU{ + From: "N124", + To: protocol.ServerCallsign, + CID: 1000001, + Token: "", + NetworkRating: 1, + ProtocolRevision: protocol.ProtoRevisionVatsim2022, + SimulatorType: 2, + RealName: "Foo Bar", + }, + &protocol.PilotPositionPDU{ + SquawkingModeC: false, + Identing: false, + From: "N124", + SquawkCode: "1201", + NetworkRating: 1, + Lat: 45.0, + Lng: 45.0, + TrueAltitude: 50, + PressureAltitude: 50, + GroundSpeed: 0, + Pitch: 0, + Heading: 0, + Bank: 0, + }, + &protocol.FastPilotPositionPDU{ + Type: protocol.FastPilotPositionTypeFast, + From: "N124", + Lat: 45, + Lng: 45, + AltitudeTrue: 10, + AltitudeAgl: 10, + Pitch: 0, + Heading: 0, + Bank: 0, + PositionalVelocityVector: protocol.VelocityVector{ + X: 10, + Y: 10, + Z: 10, + }, + RotationalVelocityVector: protocol.VelocityVector{ + X: 10, + Y: 10, + Z: 10, + }, + NoseGearAngle: 15, + }, + &protocol.PilotPositionPDU{ + SquawkingModeC: false, + Identing: false, + From: "N124", + SquawkCode: "1201", + NetworkRating: 1, + Lat: 45.0, + Lng: 45.0, + TrueAltitude: 50, + PressureAltitude: 50, + GroundSpeed: 0, + Pitch: 0, + Heading: 0, + Bank: 0, + }, + }, + }, + { + callsign: "N124", + cid: 1000001, + password: "12345", + clientID: 35044, + clientName: "vPilot", + majorVersion: 3, + minorVersion: 8, + sysUID: -99999, + initialChallenge: "abcdef", + networkRating: 1, + protocolRevsion: protocol.ProtoRevisionVatsim2022, + simulatorType: 2, + realName: "Foo Bar", + + preliminaryTestPackets: []protocol.PDU{}, + preliminaryWantPackets: []protocol.PDU{}, + testPackets: []protocol.PDU{ + &protocol.PilotPositionPDU{ + SquawkingModeC: false, + Identing: false, + From: "N124", + SquawkCode: "1201", + NetworkRating: 1, + Lat: 45.0, + Lng: 45.0, + TrueAltitude: 50, + PressureAltitude: 50, + GroundSpeed: 0, + Pitch: 0, + Heading: 0, + Bank: 0, + }, + &protocol.FastPilotPositionPDU{ + Type: protocol.FastPilotPositionTypeFast, + From: "N124", + Lat: 45, + Lng: 45, + AltitudeTrue: 10, + AltitudeAgl: 10, + Pitch: 0, + Heading: 0, + Bank: 0, + PositionalVelocityVector: protocol.VelocityVector{ + X: 10, + Y: 10, + Z: 10, + }, + RotationalVelocityVector: protocol.VelocityVector{ + X: 10, + Y: 10, + Z: 10, + }, + NoseGearAngle: 15, + }, + &protocol.PilotPositionPDU{ + SquawkingModeC: false, + Identing: false, + From: "N124", + SquawkCode: "1201", + NetworkRating: 1, + Lat: -45.0, + Lng: -45.0, + TrueAltitude: 50, + PressureAltitude: 50, + GroundSpeed: 0, + Pitch: 0, + Heading: 0, + Bank: 0, + }, + &protocol.FastPilotPositionPDU{ + Type: protocol.FastPilotPositionTypeFast, + From: "N124", + Lat: -45, + Lng: -45, + AltitudeTrue: 10, + AltitudeAgl: 10, + Pitch: 0, + Heading: 0, + Bank: 0, + PositionalVelocityVector: protocol.VelocityVector{ + X: 10, + Y: 10, + Z: 10, + }, + RotationalVelocityVector: protocol.VelocityVector{ + X: 10, + Y: 10, + Z: 10, + }, + NoseGearAngle: 15, + }, + &protocol.PilotPositionPDU{ + SquawkingModeC: false, + Identing: false, + From: "N124", + SquawkCode: "1201", + NetworkRating: 1, + Lat: 45.0, + Lng: 45.0, + TrueAltitude: 50, + PressureAltitude: 50, + GroundSpeed: 0, + Pitch: 0, + Heading: 0, + Bank: 0, + }, + }, + wantPackets: []protocol.PDU{}, + }, + }, + }, + { + testName: "Pilot position broadcast", + clients: []clientStruct{ + { + callsign: "N123", + cid: 1000000, + password: "12345", + clientID: 35044, + clientName: "vPilot", + majorVersion: 3, + minorVersion: 8, + sysUID: -99999, + initialChallenge: "abcdef", + networkRating: 1, + protocolRevsion: protocol.ProtoRevisionVatsim2022, + simulatorType: 2, + realName: "John Doe", + + preliminaryTestPackets: []protocol.PDU{ + &protocol.PilotPositionPDU{ + SquawkingModeC: false, + Identing: false, + From: "N123", + SquawkCode: "1200", + NetworkRating: 1, + Lat: 45.0, + Lng: 45.0, + TrueAltitude: 50, + PressureAltitude: 50, + GroundSpeed: 0, + Pitch: 0, + Heading: 0, + Bank: 0, + }, + }, + preliminaryWantPackets: []protocol.PDU{}, + testPackets: []protocol.PDU{}, + wantPackets: []protocol.PDU{ + &protocol.AddPilotPDU{ + From: "N124", + To: protocol.ServerCallsign, + CID: 1000001, + Token: "", + NetworkRating: 1, + ProtocolRevision: protocol.ProtoRevisionVatsim2022, + SimulatorType: 2, + RealName: "Foo Bar", + }, + &protocol.PilotPositionPDU{ + SquawkingModeC: false, + Identing: false, + From: "N124", + SquawkCode: "1200", + NetworkRating: 1, + Lat: 45.0, + Lng: 45.0, + TrueAltitude: 50, + PressureAltitude: 50, + GroundSpeed: 0, + Pitch: 0, + Heading: 0, + Bank: 0, + }, + }, + }, + { + callsign: "N124", + cid: 1000001, + password: "12345", + clientID: 35044, + clientName: "vPilot", + majorVersion: 3, + minorVersion: 8, + sysUID: -99999, + initialChallenge: "abcdef", + networkRating: 1, + protocolRevsion: protocol.ProtoRevisionVatsim2022, + simulatorType: 2, + realName: "Foo Bar", + + preliminaryTestPackets: []protocol.PDU{}, + testPackets: []protocol.PDU{ + &protocol.PilotPositionPDU{ + SquawkingModeC: false, + Identing: false, + From: "N124", + SquawkCode: "1200", + NetworkRating: 1, + Lat: 45.0, + Lng: 45.0, + TrueAltitude: 50, + PressureAltitude: 50, + GroundSpeed: 0, + Pitch: 0, + Heading: 0, + Bank: 0, + }, + }, + wantPackets: []protocol.PDU{}, + }, + }, + }, + { + testName: "Client Query Real Name ($CQN123:N124:RN)", + clients: []clientStruct{ + { + callsign: "N123", + cid: 1000000, + password: "12345", + clientID: 35044, + clientName: "vPilot", + majorVersion: 3, + minorVersion: 8, + sysUID: -99999, + initialChallenge: "abcdef", + networkRating: 1, + protocolRevsion: protocol.ProtoRevisionVatsim2022, + simulatorType: 2, + realName: "John Doe", + + preliminaryTestPackets: []protocol.PDU{}, + preliminaryWantPackets: []protocol.PDU{}, + testPackets: []protocol.PDU{}, + wantPackets: []protocol.PDU{ + &protocol.AddPilotPDU{ + From: "N124", + To: protocol.ServerCallsign, + CID: 1000001, + Token: "", + NetworkRating: 1, + ProtocolRevision: protocol.ProtoRevisionVatsim2022, + SimulatorType: 2, + RealName: "Foo Bar", + }, + &protocol.ClientQueryPDU{ + From: "N124", + To: "N123", + QueryType: protocol.ClientQueryRealName, + Payload: "", + }, + }, + }, + { + callsign: "N124", + cid: 1000001, + password: "12345", + clientID: 35044, + clientName: "vPilot", + majorVersion: 3, + minorVersion: 8, + sysUID: -99999, + initialChallenge: "abcdef", + networkRating: 1, + protocolRevsion: protocol.ProtoRevisionVatsim2022, + simulatorType: 2, + realName: "Foo Bar", + + preliminaryTestPackets: []protocol.PDU{}, + testPackets: []protocol.PDU{ + &protocol.ClientQueryPDU{ + From: "N124", + To: "N123", + QueryType: protocol.ClientQueryRealName, + Payload: "", + }, + }, + wantPackets: []protocol.PDU{}, + }, + }, + }, + { + testName: "Client Query Real Name Response ($CRN124:N123:RN:Foo Bar)", + clients: []clientStruct{ + { + callsign: "N123", + cid: 1000000, + password: "12345", + clientID: 35044, + clientName: "vPilot", + majorVersion: 3, + minorVersion: 8, + sysUID: -99999, + initialChallenge: "abcdef", + networkRating: 1, + protocolRevsion: protocol.ProtoRevisionVatsim2022, + simulatorType: 2, + realName: "John Doe", + + preliminaryTestPackets: []protocol.PDU{}, + testPackets: []protocol.PDU{}, + wantPackets: []protocol.PDU{ + &protocol.AddPilotPDU{ + From: "N124", + To: protocol.ServerCallsign, + CID: 1000001, + Token: "", + NetworkRating: 1, + ProtocolRevision: protocol.ProtoRevisionVatsim2022, + SimulatorType: 2, + RealName: "Foo Bar", + }, + &protocol.ClientQueryResponsePDU{ + From: "N124", + To: "N123", + QueryType: protocol.ClientQueryRealName, + Payload: "Foo Bar", + }, + }, + }, + { + callsign: "N124", + cid: 1000001, + password: "12345", + clientID: 35044, + clientName: "vPilot", + majorVersion: 3, + minorVersion: 8, + sysUID: -99999, + initialChallenge: "abcdef", + networkRating: 1, + protocolRevsion: protocol.ProtoRevisionVatsim2022, + simulatorType: 2, + realName: "Foo Bar", + + preliminaryTestPackets: []protocol.PDU{}, + preliminaryWantPackets: []protocol.PDU{}, + testPackets: []protocol.PDU{ + &protocol.ClientQueryResponsePDU{ + From: "N124", + To: "N123", + QueryType: protocol.ClientQueryRealName, + Payload: "Foo Bar", + }, + }, + wantPackets: []protocol.PDU{}, + }, + }, + }, + { + testName: "Authentication challenge ($ZC)", + clients: []clientStruct{ + { + callsign: "N123", + cid: 1000000, + password: "12345", + clientID: 35044, + clientName: "vPilot", + majorVersion: 3, + minorVersion: 8, + sysUID: -99999, + initialChallenge: "30984979d8caed23", + networkRating: 1, + protocolRevsion: protocol.ProtoRevisionVatsim2022, + simulatorType: 2, + realName: "John Doe", + + preliminaryTestPackets: []protocol.PDU{}, + preliminaryWantPackets: []protocol.PDU{}, + testPackets: []protocol.PDU{ + &protocol.AuthChallengePDU{ + From: "N123", + To: protocol.ServerCallsign, + Challenge: "de6acb8e", + }, + }, + wantPackets: []protocol.PDU{ + &protocol.AuthChallengeResponsePDU{ + From: protocol.ServerCallsign, + To: "N123", + ChallengeResponse: "f8ee97157f66455ed6108fccef6ccf5f", + }, + }, + }, + }, + }, + { + testName: "Kill ($!!)", + clients: []clientStruct{ + { + callsign: "N123", + cid: 1000000, + password: "12345", + clientID: 35044, + clientName: "vPilot", + majorVersion: 3, + minorVersion: 8, + sysUID: -99999, + initialChallenge: "abcdef", + networkRating: 1, + protocolRevsion: protocol.ProtoRevisionVatsim2022, + simulatorType: 2, + realName: "John Doe", + + preliminaryTestPackets: []protocol.PDU{}, + preliminaryWantPackets: []protocol.PDU{}, + testPackets: []protocol.PDU{}, + wantPackets: []protocol.PDU{ + &protocol.AddPilotPDU{ + From: "SUP", + To: protocol.ServerCallsign, + CID: 1000002, + Token: "", + NetworkRating: protocol.NetworkRatingSUP, + ProtocolRevision: protocol.ProtoRevisionVatsim2022, + SimulatorType: 2, + RealName: "Supervisor", + }, + &protocol.KillRequestPDU{ + From: "SUP", + To: "N123", + Reason: "ur banned", + }, + }, + }, + { + callsign: "SUP", + cid: 1000002, + password: "12345", + clientID: 35044, + clientName: "vPilot", + majorVersion: 3, + minorVersion: 8, + sysUID: -99999, + initialChallenge: "abcdef", + networkRating: protocol.NetworkRatingSUP, + protocolRevsion: protocol.ProtoRevisionVatsim2022, + simulatorType: 2, + realName: "Supervisor", + + preliminaryTestPackets: []protocol.PDU{}, + preliminaryWantPackets: []protocol.PDU{}, + testPackets: []protocol.PDU{ + &protocol.KillRequestPDU{ + From: "SUP", + To: "N123", + Reason: "ur banned", + }, + }, + wantPackets: []protocol.PDU{}, + }, + }, + }, + { + testName: "Kill not allowed", + clients: []clientStruct{ + { + callsign: "N123", + cid: 1000000, + password: "12345", + clientID: 35044, + clientName: "vPilot", + majorVersion: 3, + minorVersion: 8, + sysUID: -99999, + initialChallenge: "abcdef", + networkRating: 1, + protocolRevsion: protocol.ProtoRevisionVatsim2022, + simulatorType: 2, + realName: "John Doe", + + preliminaryTestPackets: []protocol.PDU{}, + preliminaryWantPackets: []protocol.PDU{}, + testPackets: []protocol.PDU{}, + wantPackets: []protocol.PDU{ + &protocol.AddPilotPDU{ + From: "N124", + To: protocol.ServerCallsign, + CID: 1000001, + Token: "", + NetworkRating: protocol.NetworkRatingOBS, + ProtocolRevision: protocol.ProtoRevisionVatsim2022, + SimulatorType: 2, + RealName: "Foo Bar", + }, + &protocol.TextMessagePDU{ + From: "N124", + To: "N123", + Message: "hello", + }, + }, + }, + { + callsign: "N124", + cid: 1000001, + password: "12345", + clientID: 35044, + clientName: "vPilot", + majorVersion: 3, + minorVersion: 8, + sysUID: -99999, + initialChallenge: "abcdef", + networkRating: protocol.NetworkRatingOBS, + protocolRevsion: protocol.ProtoRevisionVatsim2022, + simulatorType: 2, + realName: "Foo Bar", + + preliminaryTestPackets: []protocol.PDU{}, + preliminaryWantPackets: []protocol.PDU{}, + testPackets: []protocol.PDU{ + &protocol.KillRequestPDU{ + From: "N124", + To: "N123", + Reason: "ur banned", + }, + &protocol.TextMessagePDU{ + From: "N124", + To: "N123", + Message: "hello", + }, + }, + wantPackets: []protocol.PDU{}, + }, + }, + }, + { + testName: "Delete pilot broadcast", + clients: []clientStruct{ + { + callsign: "N123", + cid: 1000000, + password: "12345", + clientID: 35044, + clientName: "vPilot", + majorVersion: 3, + minorVersion: 8, + sysUID: -99999, + initialChallenge: "abcdef", + networkRating: 1, + protocolRevsion: protocol.ProtoRevisionVatsim2022, + simulatorType: 2, + realName: "John Doe", + + preliminaryTestPackets: []protocol.PDU{}, + preliminaryWantPackets: []protocol.PDU{}, + testPackets: []protocol.PDU{}, + wantPackets: []protocol.PDU{ + &protocol.AddPilotPDU{ + From: "N124", + To: protocol.ServerCallsign, + CID: 1000001, + Token: "", + NetworkRating: protocol.NetworkRatingOBS, + ProtocolRevision: protocol.ProtoRevisionVatsim2022, + SimulatorType: 2, + RealName: "Foo Bar", + }, + &protocol.DeletePilotPDU{ + From: "N124", + CID: 1000001, + }, + }, + }, + { + callsign: "N124", + cid: 1000001, + password: "12345", + clientID: 35044, + clientName: "vPilot", + majorVersion: 3, + minorVersion: 8, + sysUID: -99999, + initialChallenge: "abcdef", + networkRating: protocol.NetworkRatingOBS, + protocolRevsion: protocol.ProtoRevisionVatsim2022, + simulatorType: 2, + realName: "Foo Bar", + + preliminaryTestPackets: []protocol.PDU{}, + preliminaryWantPackets: []protocol.PDU{}, + testPackets: []protocol.PDU{ + &protocol.DeletePilotPDU{ + From: "N124", + CID: 1000001, + }, + }, + wantPackets: []protocol.PDU{}, + }, + }, + }, + } + + // Run tests + for _, tc := range tests { + t.Run(tc.testName, func(t *testing.T) { + var doneWg sync.WaitGroup + // Spawn each client + for _, client := range tc.clients { + c := client + doneWg.Add(1) + + var loggedIn sync.WaitGroup + loggedIn.Add(1) + + go func() { + defer doneWg.Done() + + // Log in the client. + // Test cases are meant to be executed after the client has logged in. + // Get a JWT token first + jwtResponse := doJwtRequest(t, "http://localhost:9086/api/fsd-jwt", c.cid, c.password) + + conn, err := net.Dial("tcp4", SC.FsdListenAddr) + assert.Nil(t, err) + + err = conn.SetReadDeadline(time.Now().Add(2 * time.Second)) + assert.Nil(t, err) + defer func() { + closeErr := conn.Close() + assert.Nil(t, closeErr) + }() + + reader := bufio.NewReader(conn) + serverIdent, err := reader.ReadString('\n') + assert.Nil(t, err) + assert.NotEmpty(t, serverIdent) + assert.True(t, strings.HasPrefix(serverIdent, "$DISERVER:CLIENT:")) + + clientIdentPDU := protocol.ClientIdentificationPDU{ + From: c.callsign, + To: protocol.ServerCallsign, + ClientID: c.clientID, + ClientName: c.clientName, + MajorVersion: c.majorVersion, + MinorVersion: c.minorVersion, + CID: c.cid, + SysUID: c.sysUID, + InitialChallenge: c.initialChallenge, + } + + addPilotPDU := protocol.AddPilotPDU{ + From: c.callsign, + To: protocol.ServerCallsign, + CID: c.cid, + Token: jwtResponse.Token, + NetworkRating: c.networkRating, + ProtocolRevision: c.protocolRevsion, + SimulatorType: c.simulatorType, + RealName: c.realName, + } + + _, err = conn.Write([]byte(clientIdentPDU.Serialize())) + assert.Nil(t, err) + _, err = conn.Write([]byte(addPilotPDU.Serialize())) + assert.Nil(t, err) + + motdMsg, err := reader.ReadString('\n') + assert.Nil(t, err) + assert.NotEmpty(t, motdMsg) + + motdReceivedPDU, err := protocol.ParseTextMessagePDU(motdMsg) + assert.Nil(t, err) + + expectedMOTD := protocol.TextMessagePDU{ + From: protocol.ServerCallsign, + To: c.callsign, + Message: SC.MOTD, + } + + assert.Equal(t, expectedMOTD.Serialize(), motdReceivedPDU.Serialize()) + + // Send post-login packets before signalling that we're done logging in + for _, packet := range c.preliminaryTestPackets { + _, writeErr := conn.Write([]byte(packet.Serialize())) + assert.Nil(t, writeErr) + } + + // Verify post-login returned packets are correct + for _, packet := range c.preliminaryWantPackets { + deadlineErr := conn.SetReadDeadline(time.Now().Add(2 * time.Second)) + assert.Nil(t, deadlineErr) + + recvPacket, recvErr := reader.ReadString('\n') + assert.Nil(t, recvErr) + + assert.Equal(t, packet.Serialize(), recvPacket) + } + + // Signal we're done logging in + loggedIn.Done() + + // Send test packets + for _, packet := range c.testPackets { + _, writeErr := conn.Write([]byte(packet.Serialize())) + assert.Nil(t, writeErr) + } + + // Verify returned packets are correct + for _, packet := range c.wantPackets { + deadlineErr := conn.SetReadDeadline(time.Now().Add(2 * time.Second)) + assert.Nil(t, deadlineErr) + + recvPacket, recvErr := reader.ReadString('\n') + assert.Nil(t, recvErr) + + assert.Equal(t, packet.Serialize(), recvPacket) + } + }() + + // Wait for the preceding client to finish logging in before spawning another + loggedIn.Wait() + } + + // Wait for all goroutines to return + doneWg.Wait() + }) + } +} diff --git a/fsd_client_login_test.go b/fsd_client_login_test.go new file mode 100644 index 0000000..a57053c --- /dev/null +++ b/fsd_client_login_test.go @@ -0,0 +1,702 @@ +package main + +import ( + "bufio" + "context" + "errors" + "github.com/golang-jwt/jwt/v5" + "github.com/renorris/openfsd/protocol" + "github.com/stretchr/testify/assert" + "golang.org/x/crypto/bcrypt" + "net" + "os" + "strings" + "syscall" + "testing" + "time" +) + +func addUserToDatabase(t *testing.T, cid int, password string, rating int) { + bcryptBytes, err := bcrypt.GenerateFromPassword([]byte(password), 4) + assert.Nil(t, err) + passwordHash := string(bcryptBytes) + + _, err = AddUserRecord(DB, cid, passwordHash, rating, "Test User") + assert.Nil(t, err) +} + +func TestFSDClientLogin(t *testing.T) { + // Setup config for testing environment + SC = &ServerConfig{ + FsdListenAddr: "localhost:6809", + HttpListenAddr: "localhost:9086", + HttpsEnabled: false, + DatabaseFile: "./test.db", + MOTD: "motd line 1\nmotd line 2", + } + + // Delete any existing database file so a new one is created + os.Remove(SC.DatabaseFile) + defer os.Remove(SC.DatabaseFile) + + // Run configuration helpers + configureDatabase() + configureJwt() + configurePostOffice() + configureProtocolValidator() + + // Start FSD listener + fsdCtx, cancelFsd := context.WithCancel(context.Background()) + defer cancelFsd() + go StartFSDServer(fsdCtx) + + // Start http server + httpCtx, cancelHttp := context.WithCancel(context.Background()) + defer cancelHttp() + go StartHttpServer(httpCtx) + time.Sleep(50 * time.Millisecond) + + addUserToDatabase(t, 1000000, "12345", 1) + + // Simulate a jwt login token request + jwtResponse := doJwtRequest(t, "http://localhost:9086/api/fsd-jwt", 1000000, "12345") + assert.True(t, jwtResponse.Success) + assert.NotEmpty(t, jwtResponse.Token) + assert.Empty(t, jwtResponse.ErrorMsg) + + // Test successful FSD login process + { + conn, err := net.Dial("tcp", "localhost:6809") + assert.Nil(t, err) + + err = conn.SetReadDeadline(time.Now().Add(2 * time.Second)) + assert.Nil(t, err) + + reader := bufio.NewReader(conn) + serverIdent, err := reader.ReadString('\n') + assert.Nil(t, err) + assert.NotEmpty(t, serverIdent) + assert.True(t, strings.HasPrefix(serverIdent, "$DISERVER:CLIENT:")) + + clientIdentPDU := protocol.ClientIdentificationPDU{ + From: "N123", + To: "SERVER", + ClientID: 35044, + ClientName: "vPilot", + MajorVersion: 3, + MinorVersion: 8, + CID: 1000000, + SysUID: -99999, + InitialChallenge: "0123456789abcdef", + } + + addPilotPDU := protocol.AddPilotPDU{ + From: "N123", + To: "SERVER", + CID: 1000000, + Token: jwtResponse.Token, + NetworkRating: 1, + ProtocolRevision: 101, + SimulatorType: 2, + RealName: "real name", + } + + _, err = conn.Write([]byte(clientIdentPDU.Serialize())) + assert.Nil(t, err) + _, err = conn.Write([]byte(addPilotPDU.Serialize())) + assert.Nil(t, err) + + responseMsg, err := reader.ReadString('\n') + assert.Nil(t, err) + assert.NotEmpty(t, serverIdent) + + expectedPDU := protocol.TextMessagePDU{ + From: protocol.ServerCallsign, + To: "N123", + Message: strings.Split(SC.MOTD, "\n")[0], + } + + assert.Equal(t, expectedPDU.Serialize(), responseMsg) + conn.Close() + } + + // Network rating too high + { + conn, err := net.Dial("tcp", "localhost:6809") + assert.Nil(t, err) + + err = conn.SetReadDeadline(time.Now().Add(2 * time.Second)) + assert.Nil(t, err) + + reader := bufio.NewReader(conn) + serverIdent, err := reader.ReadString('\n') + assert.Nil(t, err) + assert.NotEmpty(t, serverIdent) + assert.True(t, strings.HasPrefix(serverIdent, "$DISERVER:CLIENT:")) + + clientIdentPDU := protocol.ClientIdentificationPDU{ + From: "N123", + To: "SERVER", + ClientID: 35044, + ClientName: "vPilot", + MajorVersion: 3, + MinorVersion: 8, + CID: 1000000, + SysUID: -99999, + InitialChallenge: "0123456789abcdef", + } + + addPilotPDU := protocol.AddPilotPDU{ + From: "N123", + To: "SERVER", + CID: 1000000, + Token: jwtResponse.Token, + NetworkRating: 2, + ProtocolRevision: 101, + SimulatorType: 2, + RealName: "real name", + } + + _, err = conn.Write([]byte(clientIdentPDU.Serialize())) + assert.Nil(t, err) + _, err = conn.Write([]byte(addPilotPDU.Serialize())) + assert.Nil(t, err) + + responseMsg, err := reader.ReadString('\n') + assert.Nil(t, err) + assert.NotEmpty(t, serverIdent) + + assert.Equal(t, protocol.NewGenericFSDError(protocol.RequestedLevelTooHighError).Serialize(), responseMsg) + conn.Close() + } + + // Invalid protocol revision + { + conn, err := net.Dial("tcp", "localhost:6809") + assert.Nil(t, err) + + err = conn.SetReadDeadline(time.Now().Add(2 * time.Second)) + assert.Nil(t, err) + + reader := bufio.NewReader(conn) + serverIdent, err := reader.ReadString('\n') + assert.Nil(t, err) + assert.NotEmpty(t, serverIdent) + assert.True(t, strings.HasPrefix(serverIdent, "$DISERVER:CLIENT:")) + + clientIdentPDU := protocol.ClientIdentificationPDU{ + From: "N123", + To: "SERVER", + ClientID: 35044, + ClientName: "vPilot", + MajorVersion: 3, + MinorVersion: 8, + CID: 1000000, + SysUID: -99999, + InitialChallenge: "0123456789abcdef", + } + + addPilotPDU := protocol.AddPilotPDU{ + From: "N123", + To: "SERVER", + CID: 1000000, + Token: jwtResponse.Token, + NetworkRating: 1, + ProtocolRevision: 100, + SimulatorType: 2, + RealName: "real name", + } + + _, err = conn.Write([]byte(clientIdentPDU.Serialize())) + assert.Nil(t, err) + _, err = conn.Write([]byte(addPilotPDU.Serialize())) + assert.Nil(t, err) + + responseMsg, err := reader.ReadString('\n') + assert.Nil(t, err) + assert.NotEmpty(t, serverIdent) + + assert.Equal(t, protocol.NewGenericFSDError(protocol.InvalidProtocolRevisionError).Serialize(), responseMsg) + conn.Close() + } + + // Test callsign already in use + { + conn1, err := net.Dial("tcp", "localhost:6809") + assert.Nil(t, err) + conn2, err := net.Dial("tcp", "localhost:6809") + assert.Nil(t, err) + + err = conn1.SetReadDeadline(time.Now().Add(2 * time.Second)) + assert.Nil(t, err) + err = conn2.SetReadDeadline(time.Now().Add(2 * time.Second)) + assert.Nil(t, err) + + reader := bufio.NewReader(conn1) + serverIdent, err := reader.ReadString('\n') + assert.Nil(t, err) + assert.NotEmpty(t, serverIdent) + assert.True(t, strings.HasPrefix(serverIdent, "$DISERVER:CLIENT:")) + + reader2 := bufio.NewReader(conn2) + serverIdent2, err := reader2.ReadString('\n') + assert.Nil(t, err) + assert.NotEmpty(t, serverIdent2) + assert.True(t, strings.HasPrefix(serverIdent2, "$DISERVER:CLIENT:")) + + clientIdentPDU := protocol.ClientIdentificationPDU{ + From: "N123", + To: "SERVER", + ClientID: 35044, + ClientName: "vPilot", + MajorVersion: 3, + MinorVersion: 8, + CID: 1000000, + SysUID: -99999, + InitialChallenge: "0123456789abcdef", + } + + addPilotPDU := protocol.AddPilotPDU{ + From: "N123", + To: "SERVER", + CID: 1000000, + Token: jwtResponse.Token, + NetworkRating: 1, + ProtocolRevision: 101, + SimulatorType: 2, + RealName: "real name", + } + + _, err = conn1.Write([]byte(clientIdentPDU.Serialize())) + assert.Nil(t, err) + _, err = conn1.Write([]byte(addPilotPDU.Serialize())) + assert.Nil(t, err) + + responseMsg1, err := reader.ReadString('\n') + assert.Nil(t, err) + assert.NotEmpty(t, serverIdent) + + time.Sleep(50 * time.Millisecond) + + _, err = conn2.Write([]byte(clientIdentPDU.Serialize())) + assert.Nil(t, err) + _, err = conn2.Write([]byte(addPilotPDU.Serialize())) + assert.Nil(t, err) + + responseMsg2, err := reader2.ReadString('\n') + assert.Nil(t, err) + assert.NotEmpty(t, serverIdent) + + expectedPDU1 := protocol.TextMessagePDU{ + From: protocol.ServerCallsign, + To: "N123", + Message: strings.Split(SC.MOTD, "\n")[0], + } + + assert.Equal(t, expectedPDU1.Serialize(), responseMsg1) + assert.Equal(t, protocol.NewGenericFSDError(protocol.CallsignInUseError).Serialize(), responseMsg2) + + conn1.Close() + conn2.Close() + } + + // Test packet oversize (> 1024bytes) + { + conn, err := net.Dial("tcp", "localhost:6809") + assert.Nil(t, err) + + err = conn.SetReadDeadline(time.Now().Add(2 * time.Second)) + assert.Nil(t, err) + + reader := bufio.NewReader(conn) + serverIdent, err := reader.ReadString('\n') + assert.Nil(t, err) + assert.NotEmpty(t, serverIdent) + assert.True(t, strings.HasPrefix(serverIdent, "$DISERVER:CLIENT:")) + + lotsOfNines := make([]byte, 4096) + for i := 0; i < 4096; i++ { + lotsOfNines[i] = '9' + } + + clientIdentPDU := protocol.ClientIdentificationPDU{ + From: "N123", + To: "SERVER", + ClientID: 35044, + ClientName: "vPilot", + MajorVersion: 3, + MinorVersion: 8, + CID: 1000000, + SysUID: -99999, + InitialChallenge: string(lotsOfNines), // <-- Make this field 4096 bytes long + } + + _, err = conn.Write([]byte(clientIdentPDU.Serialize())) + assert.Nil(t, err) + + responseMsg, err := reader.ReadString('\n') + assert.Nil(t, err) + expectedPDU := protocol.NewGenericFSDError(protocol.SyntaxError) + assert.Equal(t, expectedPDU.Serialize(), responseMsg) + + _, err = reader.ReadString('\n') + assert.NotNil(t, err) + var opError *net.OpError + if errors.As(err, &opError) { + assert.ErrorIs(t, opError.Err, syscall.ECONNRESET) + } else { + assert.Fail(t, "wrong error") + } + } + + // Test incorrect packet sequence + { + conn, err := net.Dial("tcp", "localhost:6809") + assert.Nil(t, err) + + err = conn.SetReadDeadline(time.Now().Add(2 * time.Second)) + assert.Nil(t, err) + + reader := bufio.NewReader(conn) + serverIdent, err := reader.ReadString('\n') + assert.Nil(t, err) + assert.NotEmpty(t, serverIdent) + assert.True(t, strings.HasPrefix(serverIdent, "$DISERVER:CLIENT:")) + + addPilotPDU := protocol.AddPilotPDU{ + From: "N123", + To: "SERVER", + CID: 1000000, + Token: jwtResponse.Token, + NetworkRating: 1, + ProtocolRevision: 101, + SimulatorType: 2, + RealName: "real name", + } + + _, err = conn.Write([]byte(addPilotPDU.Serialize())) + assert.Nil(t, err) + + responseMsg, err := reader.ReadString('\n') + assert.Nil(t, err) + assert.NotEmpty(t, serverIdent) + + expectedPDU := protocol.NewGenericFSDError(protocol.SyntaxError) + + assert.Equal(t, expectedPDU.Serialize(), responseMsg) + conn.Close() + } + + // Test invalid CID + { + conn, err := net.Dial("tcp", "localhost:6809") + assert.Nil(t, err) + + err = conn.SetReadDeadline(time.Now().Add(2 * time.Second)) + assert.Nil(t, err) + + reader := bufio.NewReader(conn) + serverIdent, err := reader.ReadString('\n') + assert.Nil(t, err) + assert.NotEmpty(t, serverIdent) + assert.True(t, strings.HasPrefix(serverIdent, "$DISERVER:CLIENT:")) + + clientIdentPDU := protocol.ClientIdentificationPDU{ + From: "N123", + To: "SERVER", + ClientID: 35044, + ClientName: "vPilot", + MajorVersion: 3, + MinorVersion: 8, + CID: 9999999, + SysUID: -99999, + InitialChallenge: "0123456789abcdef", + } + + addPilotPDU := protocol.AddPilotPDU{ + From: "N123", + To: "SERVER", + CID: 9999999, + Token: jwtResponse.Token, + NetworkRating: 1, + ProtocolRevision: 101, + SimulatorType: 2, + RealName: "real name", + } + + _, err = conn.Write([]byte(clientIdentPDU.Serialize())) + assert.Nil(t, err) + _, err = conn.Write([]byte(addPilotPDU.Serialize())) + assert.Nil(t, err) + + responseMsg, err := reader.ReadString('\n') + assert.Nil(t, err) + assert.NotEmpty(t, serverIdent) + + expectedPDU := protocol.NewGenericFSDError(protocol.InvalidLogonError) + + assert.Equal(t, expectedPDU.Serialize(), responseMsg) + conn.Close() + } + + // Test jibberish token + { + conn, err := net.Dial("tcp", "localhost:6809") + assert.Nil(t, err) + + err = conn.SetReadDeadline(time.Now().Add(2 * time.Second)) + assert.Nil(t, err) + + reader := bufio.NewReader(conn) + serverIdent, err := reader.ReadString('\n') + assert.Nil(t, err) + assert.NotEmpty(t, serverIdent) + assert.True(t, strings.HasPrefix(serverIdent, "$DISERVER:CLIENT:")) + + clientIdentPDU := protocol.ClientIdentificationPDU{ + From: "N123", + To: "SERVER", + ClientID: 35044, + ClientName: "vPilot", + MajorVersion: 3, + MinorVersion: 8, + CID: 1000000, + SysUID: -99999, + InitialChallenge: "0123456789abcdef", + } + + addPilotPDU := protocol.AddPilotPDU{ + From: "N123", + To: "SERVER", + CID: 1000000, + Token: "garbage", + NetworkRating: 1, + ProtocolRevision: 101, + SimulatorType: 2, + RealName: "real name", + } + + _, err = conn.Write([]byte(clientIdentPDU.Serialize())) + assert.Nil(t, err) + _, err = conn.Write([]byte(addPilotPDU.Serialize())) + assert.Nil(t, err) + + responseMsg, err := reader.ReadString('\n') + assert.Nil(t, err) + assert.NotEmpty(t, serverIdent) + + expectedPDU := protocol.NewGenericFSDError(protocol.SyntaxError) + + assert.Equal(t, expectedPDU.Serialize(), responseMsg) + conn.Close() + } + + // Test token with invalid signature + { + // Fake jwt + token := jwt.NewWithClaims(jwt.SigningMethodHS256, CustomClaims{ + RegisteredClaims: jwt.RegisteredClaims{ + Issuer: "https://auth.vatsim.net/api/fsd-jwt", + Subject: "1000000", + Audience: []string{"fsd-live"}, + ExpiresAt: jwt.NewNumericDate(time.Now().Add(420 * time.Second)), + NotBefore: jwt.NewNumericDate(time.Now().Add(-120 * time.Second)), + IssuedAt: jwt.NewNumericDate(time.Now()), + ID: "bogus", + }, + ControllerRating: 0, + PilotRating: 0, + }) + + tokenString, err := token.SignedString([]byte("garbage key")) + assert.Nil(t, err) + + conn, err := net.Dial("tcp", "localhost:6809") + assert.Nil(t, err) + + err = conn.SetReadDeadline(time.Now().Add(2 * time.Second)) + assert.Nil(t, err) + + reader := bufio.NewReader(conn) + serverIdent, err := reader.ReadString('\n') + assert.Nil(t, err) + assert.NotEmpty(t, serverIdent) + assert.True(t, strings.HasPrefix(serverIdent, "$DISERVER:CLIENT:")) + + clientIdentPDU := protocol.ClientIdentificationPDU{ + From: "N123", + To: "SERVER", + ClientID: 35044, + ClientName: "vPilot", + MajorVersion: 3, + MinorVersion: 8, + CID: 1000000, + SysUID: -99999, + InitialChallenge: "0123456789abcdef", + } + + addPilotPDU := protocol.AddPilotPDU{ + From: "N123", + To: "SERVER", + CID: 1000000, + Token: tokenString, + NetworkRating: 1, + ProtocolRevision: 101, + SimulatorType: 2, + RealName: "real name", + } + + _, err = conn.Write([]byte(clientIdentPDU.Serialize())) + assert.Nil(t, err) + _, err = conn.Write([]byte(addPilotPDU.Serialize())) + assert.Nil(t, err) + + responseMsg, err := reader.ReadString('\n') + assert.Nil(t, err) + assert.NotEmpty(t, serverIdent) + + expectedPDU := protocol.NewGenericFSDError(protocol.InvalidLogonError) + + assert.Equal(t, expectedPDU.Serialize(), responseMsg) + conn.Close() + } + + // Test expired token + { + // Expired jwt + tokenStr, err := jwt.NewWithClaims(jwt.SigningMethodHS256, CustomClaims{ + RegisteredClaims: jwt.RegisteredClaims{ + Issuer: "https://auth.vatsim.net/api/fsd-jwt", + Subject: "1000000", + Audience: []string{"fsd-live"}, + ExpiresAt: jwt.NewNumericDate(time.Now().Add(-60 * time.Second)), + NotBefore: jwt.NewNumericDate(time.Now().Add(-120 * time.Second)), + IssuedAt: jwt.NewNumericDate(time.Now()), + ID: "id", + }, + ControllerRating: 0, + PilotRating: 0, + }).SignedString(JWTKey) + + assert.Nil(t, err) + + conn, err := net.Dial("tcp", "localhost:6809") + assert.Nil(t, err) + + err = conn.SetReadDeadline(time.Now().Add(2 * time.Second)) + assert.Nil(t, err) + + reader := bufio.NewReader(conn) + serverIdent, err := reader.ReadString('\n') + assert.Nil(t, err) + assert.NotEmpty(t, serverIdent) + assert.True(t, strings.HasPrefix(serverIdent, "$DISERVER:CLIENT:")) + + clientIdentPDU := protocol.ClientIdentificationPDU{ + From: "N123", + To: "SERVER", + ClientID: 35044, + ClientName: "vPilot", + MajorVersion: 3, + MinorVersion: 8, + CID: 1000000, + SysUID: -99999, + InitialChallenge: "0123456789abcdef", + } + + addPilotPDU := protocol.AddPilotPDU{ + From: "N123", + To: "SERVER", + CID: 1000000, + Token: tokenStr, + NetworkRating: 1, + ProtocolRevision: 101, + SimulatorType: 2, + RealName: "real name", + } + + _, err = conn.Write([]byte(clientIdentPDU.Serialize())) + assert.Nil(t, err) + _, err = conn.Write([]byte(addPilotPDU.Serialize())) + assert.Nil(t, err) + + responseMsg, err := reader.ReadString('\n') + assert.Nil(t, err) + assert.NotEmpty(t, serverIdent) + + expectedPDU := protocol.NewGenericFSDError(protocol.InvalidLogonError) + + assert.Equal(t, expectedPDU.Serialize(), responseMsg) + conn.Close() + } + + // Test token with invalid CID + { + // jwt with invalid CID + token := jwt.NewWithClaims(jwt.SigningMethodHS256, CustomClaims{ + RegisteredClaims: jwt.RegisteredClaims{ + Issuer: "https://auth.vatsim.net/api/fsd-jwt", + Subject: "9999999", + Audience: []string{"fsd-live"}, + ExpiresAt: jwt.NewNumericDate(time.Now().Add(420 * time.Second)), + NotBefore: jwt.NewNumericDate(time.Now().Add(-120 * time.Second)), + IssuedAt: jwt.NewNumericDate(time.Now()), + ID: "id", + }, + ControllerRating: 0, + PilotRating: 0, + }) + + tokenString, err := token.SignedString(JWTKey) + assert.Nil(t, err) + + conn, err := net.Dial("tcp", "localhost:6809") + assert.Nil(t, err) + + err = conn.SetReadDeadline(time.Now().Add(2 * time.Second)) + assert.Nil(t, err) + + reader := bufio.NewReader(conn) + serverIdent, err := reader.ReadString('\n') + assert.Nil(t, err) + assert.NotEmpty(t, serverIdent) + assert.True(t, strings.HasPrefix(serverIdent, "$DISERVER:CLIENT:")) + + clientIdentPDU := protocol.ClientIdentificationPDU{ + From: "N123", + To: "SERVER", + ClientID: 35044, + ClientName: "vPilot", + MajorVersion: 3, + MinorVersion: 8, + CID: 1000000, + SysUID: -99999, + InitialChallenge: "0123456789abcdef", + } + + addPilotPDU := protocol.AddPilotPDU{ + From: "N123", + To: "SERVER", + CID: 1000000, + Token: tokenString, + NetworkRating: 1, + ProtocolRevision: 101, + SimulatorType: 2, + RealName: "real name", + } + + _, err = conn.Write([]byte(clientIdentPDU.Serialize())) + assert.Nil(t, err) + _, err = conn.Write([]byte(addPilotPDU.Serialize())) + assert.Nil(t, err) + + responseMsg, err := reader.ReadString('\n') + assert.Nil(t, err) + assert.NotEmpty(t, serverIdent) + + expectedPDU := protocol.NewGenericFSDError(protocol.InvalidLogonError) + + assert.Equal(t, expectedPDU.Serialize(), responseMsg) + conn.Close() + } +} diff --git a/fsd_jwt_auth_test.go b/fsd_jwt_auth_test.go new file mode 100644 index 0000000..99aa2ff --- /dev/null +++ b/fsd_jwt_auth_test.go @@ -0,0 +1,99 @@ +package main + +import ( + "bytes" + "context" + "encoding/json" + "fmt" + "github.com/golang-jwt/jwt/v5" + "github.com/stretchr/testify/assert" + "net/http" + "os" + "testing" + "time" +) + +func doJwtRequest(t *testing.T, url string, cid int, password string) *JwtResponse { + jwtRequest := JwtRequest{ + CID: fmt.Sprintf("%d", cid), + Password: password, + } + jsonData, err := json.Marshal(jwtRequest) + assert.Nil(t, err) + + client := http.Client{} + resp, err := client.Post(url, "application/json", bytes.NewReader(jsonData)) + assert.Nil(t, err) + + buf := new(bytes.Buffer) + _, err = buf.ReadFrom(resp.Body) + assert.Nil(t, err) + var jwtResponse JwtResponse + err = json.Unmarshal(buf.Bytes(), &jwtResponse) + assert.Nil(t, err) + + return &jwtResponse +} + +func TestServeJwtAuthTokens(t *testing.T) { + SC = &ServerConfig{ + FsdListenAddr: "localhost:6809", + HttpListenAddr: "localhost:9086", + HttpsEnabled: false, + DatabaseFile: "./test.db", + MOTD: "", + } + os.Remove(SC.DatabaseFile) + defer os.Remove(SC.DatabaseFile) + + configureDatabase() + configureJwt() + configurePostOffice() + + addUserToDatabase(t, 1000000, "12345", 1) + + // Start http server + httpCtx, cancelHttp := context.WithCancel(context.Background()) + go StartHttpServer(httpCtx) + defer cancelHttp() + time.Sleep(50 * time.Millisecond) + + // Test successful request + { + jwtResponse := doJwtRequest(t, "http://localhost:9086/api/fsd-jwt", 1000000, "12345") + + assert.True(t, jwtResponse.Success) + assert.NotEmpty(t, jwtResponse.Token) + assert.Empty(t, jwtResponse.ErrorMsg) + + claims := jwt.MapClaims{} + token, err := jwt.ParseWithClaims(jwtResponse.Token, &claims, func(token *jwt.Token) (interface{}, error) { + return JWTKey, nil + }) + assert.Nil(t, err) + assert.NotNil(t, claims) + exp, err := token.Claims.GetExpirationTime() + assert.Nil(t, err) + iat, err := token.Claims.GetIssuedAt() + assert.Nil(t, err) + assert.True(t, exp.Sub(iat.Time) == 420*time.Second) + } + + // Test invalid CID + { + jwtResponse := doJwtRequest(t, "http://localhost:9086/api/fsd-jwt", 9999999, "12345") + + assert.False(t, jwtResponse.Success) + assert.Empty(t, jwtResponse.Token) + assert.Equal(t, jwtResponse.ErrorMsg, "User not found") + } + + // Test invalid password + { + jwtResponse := doJwtRequest(t, "http://localhost:9086/api/fsd-jwt", 1000000, "54321") + + assert.False(t, jwtResponse.Success) + assert.Empty(t, jwtResponse.Token) + assert.Equal(t, jwtResponse.ErrorMsg, "Password is incorrect") + } +} diff --git a/go.mod b/go.mod new file mode 100644 index 0000000..6174aae --- /dev/null +++ b/go.mod @@ -0,0 +1,26 @@ +module github.com/renorris/openfsd + +go 1.22.1 + +require ( + github.com/go-playground/validator/v10 v10.19.0 + github.com/golang-jwt/jwt/v5 v5.2.1 + github.com/mattn/go-sqlite3 v1.14.22 + github.com/mmcloughlin/geohash v0.10.0 + github.com/sethvargo/go-envconfig v1.0.1 + github.com/stretchr/testify v1.9.0 + golang.org/x/crypto v0.21.0 +) + +require ( + github.com/davecgh/go-spew v1.1.1 // indirect + github.com/gabriel-vasile/mimetype v1.4.3 // indirect + github.com/go-playground/locales v0.14.1 // indirect + github.com/go-playground/universal-translator v0.18.1 // indirect + github.com/leodido/go-urn v1.4.0 // indirect + github.com/pmezard/go-difflib v1.0.0 // indirect + golang.org/x/net v0.21.0 // indirect + golang.org/x/sys v0.18.0 // indirect + golang.org/x/text v0.14.0 // indirect + gopkg.in/yaml.v3 v3.0.1 // indirect +) diff --git a/go.sum b/go.sum new file mode 100644 index 0000000..889c4d1 --- /dev/null +++ b/go.sum @@ -0,0 +1,40 @@ +github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/gabriel-vasile/mimetype v1.4.3 h1:in2uUcidCuFcDKtdcBxlR0rJ1+fsokWf+uqxgUFjbI0= +github.com/gabriel-vasile/mimetype v1.4.3/go.mod h1:d8uq/6HKRL6CGdk+aubisF/M5GcPfT7nKyLpA0lbSSk= +github.com/go-playground/assert/v2 v2.2.0 h1:JvknZsQTYeFEAhQwI4qEt9cyV5ONwRHC+lYKSsYSR8s= +github.com/go-playground/assert/v2 v2.2.0/go.mod h1:VDjEfimB/XKnb+ZQfWdccd7VUvScMdVu0Titje2rxJ4= +github.com/go-playground/locales v0.14.1 h1:EWaQ/wswjilfKLTECiXz7Rh+3BjFhfDFKv/oXslEjJA= +github.com/go-playground/locales v0.14.1/go.mod h1:hxrqLVvrK65+Rwrd5Fc6F2O76J/NuW9t0sjnWqG1slY= +github.com/go-playground/universal-translator v0.18.1 h1:Bcnm0ZwsGyWbCzImXv+pAJnYK9S473LQFuzCbDbfSFY= +github.com/go-playground/universal-translator v0.18.1/go.mod h1:xekY+UJKNuX9WP91TpwSH2VMlDf28Uj24BCp08ZFTUY= +github.com/go-playground/validator/v10 v10.19.0 h1:ol+5Fu+cSq9JD7SoSqe04GMI92cbn0+wvQ3bZ8b/AU4= +github.com/go-playground/validator/v10 v10.19.0/go.mod h1:dbuPbCMFw/DrkbEynArYaCwl3amGuJotoKCe95atGMM= +github.com/golang-jwt/jwt/v5 v5.2.1 h1:OuVbFODueb089Lh128TAcimifWaLhJwVflnrgM17wHk= +github.com/golang-jwt/jwt/v5 v5.2.1/go.mod h1:pqrtFR0X4osieyHYxtmOUWsAWrfe1Q5UVIyoH402zdk= +github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI= +github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= +github.com/leodido/go-urn v1.4.0 h1:WT9HwE9SGECu3lg4d/dIA+jxlljEa1/ffXKmRjqdmIQ= +github.com/leodido/go-urn v1.4.0/go.mod h1:bvxc+MVxLKB4z00jd1z+Dvzr47oO32F/QSNjSBOlFxI= +github.com/mattn/go-sqlite3 v1.14.22 h1:2gZY6PC6kBnID23Tichd1K+Z0oS6nE/XwU+Vz/5o4kU= +github.com/mattn/go-sqlite3 v1.14.22/go.mod h1:Uh1q+B4BYcTPb+yiD3kU8Ct7aC0hY9fxUwlHK0RXw+Y= +github.com/mmcloughlin/geohash v0.10.0 h1:9w1HchfDfdeLc+jFEf/04D27KP7E2QmpDu52wPbJWRE= +github.com/mmcloughlin/geohash v0.10.0/go.mod h1:oNZxQo5yWJh0eMQEP/8hwQuVx9Z9tjwFUqcTB1SmG0c= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/sethvargo/go-envconfig v1.0.1 h1:9wglip/5fUfaH0lQecLM8AyOClMw0gT0A9K2c2wozao= +github.com/sethvargo/go-envconfig v1.0.1/go.mod h1:OKZ02xFaD3MvWBBmEW45fQr08sJEsonGrrOdicvQmQA= +github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg= +github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= +golang.org/x/crypto v0.21.0 h1:X31++rzVUdKhX5sWmSOFZxx8UW/ldWx55cbf08iNAMA= +golang.org/x/crypto v0.21.0/go.mod h1:0BP7YvVV9gBbVKyeTG0Gyn+gZm94bibOW5BjDEYAOMs= +golang.org/x/net v0.21.0 h1:AQyQV4dYCvJ7vGmJyKki9+PBdyvhkSd8EIx/qb0AYv4= +golang.org/x/net v0.21.0/go.mod h1:bIjVDfnllIU7BJ2DNgfnXvpSvtn8VRwhlsaeUTyUS44= +golang.org/x/sys v0.18.0 h1:DBdB3niSjOA/O0blCZBqDefyWNYveAYMNF1Wum0DYQ4= +golang.org/x/sys v0.18.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= +golang.org/x/text v0.14.0 h1:ScX5w1eTa3QqT8oi6+ziP7dTV1S2+ALU0bI+0zXKWiQ= +golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= +gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/http_server.go b/http_server.go new file mode 100644 index 0000000..42a45d3 --- /dev/null +++ b/http_server.go @@ -0,0 +1,509 @@ +package main + +import ( + "bytes" + "context" + "crypto/rand" + "database/sql" + _ "embed" + "encoding/base64" + "encoding/json" + "errors" + "fmt" + "github.com/golang-jwt/jwt/v5" + _ "github.com/mattn/go-sqlite3" + "github.com/renorris/openfsd/protocol" + "golang.org/x/crypto/bcrypt" + "log" + "net/http" + "strconv" + "time" +) + +//go:embed dashboard.html +var dashboardHtml []byte + +type JwtRequest struct { + CID string `json:"cid"` + Password string `json:"password"` +} + +type JwtResponse struct { + Success bool `json:"success"` + Token string `json:"token,omitempty"` + ErrorMsg string `json:"error_msg,omitempty"` +} + +type CustomClaims struct { + jwt.RegisteredClaims + ControllerRating int `json:"controller_rating"` + PilotRating int `json:"pilot_rating"` +} + +type UserApiResponse struct { + Success bool `json:"success"` + Message string `json:"msg"` + User *FSDUserRecord `json:"user_record,omitempty"` +} + +func fsdJwtApiHandler(w http.ResponseWriter, r *http.Request) { + buf := new(bytes.Buffer) + _, err := buf.ReadFrom(r.Body) + if err != nil { + w.WriteHeader(http.StatusInternalServerError) + return + } + + var jwtRequest JwtRequest + if err = json.Unmarshal(buf.Bytes(), &jwtRequest); err != nil { + w.WriteHeader(http.StatusBadRequest) + return + } + + cid, err := strconv.Atoi(jwtRequest.CID) + if err != nil { + w.WriteHeader(http.StatusBadRequest) + return + } + + userRecord, userRecordErr := GetUserRecord(DB, cid) + if userRecordErr != nil && !errors.Is(userRecordErr, sql.ErrNoRows) { + w.WriteHeader(http.StatusInternalServerError) + return + } + + // If user not found + if errors.Is(userRecordErr, sql.ErrNoRows) { + jwtResponse := JwtResponse{ + Success: false, + Token: "", + ErrorMsg: "User not found", + } + + body, marshalErr := json.Marshal(jwtResponse) + if marshalErr != nil { + w.WriteHeader(http.StatusInternalServerError) + return + } + + w.Header().Add("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + + _, writeErr := w.Write(body) + if writeErr != nil { + w.WriteHeader(http.StatusInternalServerError) + return + } + + return + } + + // Verify password + userRecordErr = bcrypt.CompareHashAndPassword([]byte(userRecord.Password), []byte(jwtRequest.Password)) + if userRecordErr != nil { // Password didn't match + jwtResponse := JwtResponse{ + Success: false, + Token: "", + ErrorMsg: "Password is incorrect", + } + + body, err := json.Marshal(jwtResponse) + if err != nil { + w.WriteHeader(http.StatusInternalServerError) + return + } + + w.Header().Add("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + + _, writeErr := w.Write(body) + if writeErr != nil { + w.WriteHeader(http.StatusInternalServerError) + return + } + + return + } + + // Else send a login token + idBytes := make([]byte, 16) + _, userRecordErr = rand.Read(idBytes) + if userRecordErr != nil { + w.WriteHeader(http.StatusInternalServerError) + return + } + idStr := base64.StdEncoding.EncodeToString(idBytes) + + token := jwt.NewWithClaims(jwt.SigningMethodHS256, CustomClaims{ + RegisteredClaims: jwt.RegisteredClaims{ + Issuer: "openfsd", + Subject: jwtRequest.CID, + Audience: []string{"fsd-live"}, + ExpiresAt: jwt.NewNumericDate(time.Now().Add(420 * time.Second)), + NotBefore: jwt.NewNumericDate(time.Now().Add(-120 * time.Second)), + IssuedAt: jwt.NewNumericDate(time.Now()), + ID: idStr, + }, + ControllerRating: userRecord.Rating, + PilotRating: 0, + }) + + tokenString, userRecordErr := token.SignedString(JWTKey) + if userRecordErr != nil { + w.WriteHeader(http.StatusInternalServerError) + return + } + + jwtResponse := JwtResponse{ + Success: true, + Token: tokenString, + ErrorMsg: "", + } + + body, userRecordErr := json.Marshal(jwtResponse) + if userRecordErr != nil { + w.WriteHeader(http.StatusInternalServerError) + return + } + + w.Header().Add("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + + _, writeErr := w.Write(body) + if writeErr != nil { + w.WriteHeader(http.StatusInternalServerError) + return + } +} + +func userAPIHandler(w http.ResponseWriter, r *http.Request) { + + // Check for a valid token cookie + cookie, err := r.Cookie("token") + if err != nil { + w.WriteHeader(http.StatusForbidden) + return + } + + // Validate token + claims := jwt.RegisteredClaims{} + token, err := jwt.ParseWithClaims(cookie.Value, &claims, func(token *jwt.Token) (interface{}, error) { + return JWTKey, nil + }) + + if err != nil { + w.WriteHeader(http.StatusForbidden) + return + } + + audience, err := token.Claims.GetAudience() + if err != nil { + w.WriteHeader(http.StatusForbidden) + return + } + + audienceValid := false + for _, audClaim := range audience { + if audClaim == "administrator-dashboard" { + audienceValid = true + break + } + } + + if !audienceValid { + w.WriteHeader(http.StatusForbidden) + return + } + + // Handle the request + switch r.Method { + case "GET": + cidStr := r.FormValue("cid") + + cid, err := strconv.Atoi(cidStr) + if err != nil { + w.WriteHeader(http.StatusInternalServerError) + return + } + + userRecord, err := GetUserRecord(DB, cid) + if err != nil { + res := UserApiResponse{ + Success: false, + Message: "Error: user not found", + } + + resBytes, err := json.Marshal(res) + if err != nil { + w.WriteHeader(http.StatusInternalServerError) + return + } + + w.Header().Add("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + w.Write(resBytes) + return + } + + // omit password + userRecord.Password = "" + + res := UserApiResponse{ + Success: true, + Message: "Success", + User: userRecord, + } + + resBytes, err := json.Marshal(res) + if err != nil { + w.WriteHeader(http.StatusInternalServerError) + return + } + + w.Header().Add("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + w.Write(resBytes) + return + case "POST": + buf := new(bytes.Buffer) + _, err := buf.ReadFrom(r.Body) + if err != nil { + w.WriteHeader(http.StatusInternalServerError) + return + } + + req := FSDUserRecord{} + err = json.Unmarshal(buf.Bytes(), &req) + if err != nil { + w.WriteHeader(http.StatusInternalServerError) + return + } + + bcryptBytes, err := bcrypt.GenerateFromPassword([]byte(req.Password), 10) + passwordHash := string(bcryptBytes) + + record, err := AddUserRecordSequential(DB, passwordHash, req.Rating, req.RealName) + if err != nil { + w.WriteHeader(http.StatusInternalServerError) + return + } + + record.Password = "" + + res := UserApiResponse{ + Success: true, + Message: "Success: added user with CID " + fmt.Sprintf("%d", record.CID), + User: record, + } + + resBytes, err := json.Marshal(res) + if err != nil { + w.WriteHeader(http.StatusInternalServerError) + return + } + + w.Header().Add("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + w.Write(resBytes) + return + case "PATCH": + buf := new(bytes.Buffer) + _, err := buf.ReadFrom(r.Body) + if err != nil { + w.WriteHeader(http.StatusInternalServerError) + return + } + + req := FSDUserRecord{} + err = json.Unmarshal(buf.Bytes(), &req) + if err != nil { + w.WriteHeader(http.StatusInternalServerError) + return + } + + if len(req.Password) > 0 { + bcryptBytes, err := bcrypt.GenerateFromPassword([]byte(req.Password), 10) + if err != nil { + w.WriteHeader(http.StatusInternalServerError) + return + } + req.Password = string(bcryptBytes) + } + + err = UpdateUserRecord(DB, &req) + if err != nil { + w.WriteHeader(http.StatusInternalServerError) + return + } + + req.Password = "" + + res := UserApiResponse{ + Success: true, + Message: "Success: updated user with CID " + fmt.Sprintf("%d", req.CID), + User: &req, + } + + resBytes, err := json.Marshal(res) + if err != nil { + w.WriteHeader(http.StatusInternalServerError) + return + } + + w.Header().Add("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + w.Write(resBytes) + return + } +} + +func dashboardHandler(w http.ResponseWriter, r *http.Request) { + + // If we have a valid cookie containing a valid JWT, send the dashboard page + cookie, err := r.Cookie("token") + if err == nil { + claims := jwt.RegisteredClaims{} + token, err := jwt.ParseWithClaims(cookie.Value, &claims, func(token *jwt.Token) (interface{}, error) { + return JWTKey, nil + }) + + // If the token is invalid, remove the cookie and send back basic auth + if err != nil { + cookie.Expires = time.Unix(0, 0) + cookie.Value = "" + http.SetCookie(w, cookie) + + w.Header().Add("WWW-Authenticate", `Basic realm="dashboard", charset="UTF-8"`) + w.WriteHeader(http.StatusUnauthorized) + return + } + + audience, err := token.Claims.GetAudience() + if err != nil { + w.WriteHeader(http.StatusForbidden) + return + } + + audienceValid := false + for _, audClaim := range audience { + if audClaim == "administrator-dashboard" { + audienceValid = true + break + } + } + + if !audienceValid { + w.WriteHeader(http.StatusForbidden) + return + } + + w.WriteHeader(http.StatusOK) + w.Write(dashboardHtml) + return + } + + // If we don't have a cookie, check if we were sent basic auth information + username, pwd, ok := r.BasicAuth() + + // If not, send the basic auth header + if !ok { + w.Header().Add("WWW-Authenticate", `Basic realm="dashboard", charset="UTF-8"`) + w.WriteHeader(http.StatusUnauthorized) + return + } + + cid, err := strconv.Atoi(username) + if err != nil { + w.Header().Add("WWW-Authenticate", `Basic realm="dashboard", charset="UTF-8"`) + w.WriteHeader(http.StatusUnauthorized) + return + } + + userRecord, err := GetUserRecord(DB, cid) + if err != nil { + w.Header().Add("Content-Type", "text/plain") + w.Header().Add("WWW-Authenticate", `Basic realm="dashboard", charset="UTF-8"`) + w.WriteHeader(http.StatusUnauthorized) + w.Write([]byte("user not found")) + return + } + + if userRecord.Rating < protocol.NetworkRatingSUP { + w.Header().Add("Content-Type", "text/plain") + w.Header().Add("WWW-Authenticate", `Basic realm="dashboard", charset="UTF-8"`) + w.WriteHeader(http.StatusUnauthorized) + w.Write([]byte("rating too low")) + return + } + + err = bcrypt.CompareHashAndPassword([]byte(userRecord.Password), []byte(pwd)) + if err != nil { + w.Header().Add("Content-Type", "text/plain") + w.Header().Add("WWW-Authenticate", `Basic realm="dashboard", charset="UTF-8"`) + w.WriteHeader(http.StatusUnauthorized) + w.Write([]byte("user not found")) + return + } + + token := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.RegisteredClaims{ + Issuer: "openfsd", + Subject: string(rune(userRecord.CID)), + Audience: []string{"administrator-dashboard"}, + ExpiresAt: jwt.NewNumericDate(time.Now().Add(12 * time.Hour)), + IssuedAt: jwt.NewNumericDate(time.Now()), + }) + + tokenStr, err := token.SignedString(JWTKey) + if err != nil { + w.WriteHeader(http.StatusInternalServerError) + return + } + + http.SetCookie(w, &http.Cookie{ + Name: "token", + Value: tokenStr, + Expires: time.Now().Add(12 * time.Hour), + }) + + w.WriteHeader(http.StatusOK) + w.Write(dashboardHtml) +} + +func defaultHandler(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusForbidden) +} + +func StartHttpServer(ctx context.Context) { + mux := http.NewServeMux() + mux.HandleFunc("POST /api/fsd-jwt", fsdJwtApiHandler) + mux.HandleFunc("GET /dashboard", dashboardHandler) + mux.HandleFunc("/user", userAPIHandler) + mux.HandleFunc("/", defaultHandler) + server := &http.Server{Addr: SC.HttpListenAddr, Handler: mux} + go func() { + if SC.HttpsEnabled { + if err := server.ListenAndServeTLS(SC.TLSCertFile, SC.TLSKeyFile); err != nil { + if !errors.Is(err, http.ErrServerClosed) { + log.Fatal("https server error:\n" + err.Error()) + } + } + } else { + if err := server.ListenAndServe(); err != nil { + if !errors.Is(err, http.ErrServerClosed) { + log.Fatal("http server error:\n" + err.Error()) + } + } + } + + }() + + log.Println("HTTP listening") + + // Wait for context done signal + <-ctx.Done() + + // Shutdown server + if err := server.Shutdown(context.Background()); err != nil { + log.Fatal("http server shutdown error:\n" + err.Error()) + } +} diff --git a/main.go b/main.go new file mode 100644 index 0000000..6a8b191 --- /dev/null +++ b/main.go @@ -0,0 +1,199 @@ +package main + +import ( + "context" + "crypto/rand" + "database/sql" + "encoding/hex" + "fmt" + "github.com/go-playground/validator/v10" + _ "github.com/mattn/go-sqlite3" + "github.com/renorris/openfsd/protocol" + "github.com/sethvargo/go-envconfig" + "golang.org/x/crypto/bcrypt" + "io" + "log" + "net" + "os" + "os/signal" + "syscall" +) + +type ServerConfig struct { + FsdListenAddr string `env:"FSD_ADDR, default=0.0.0.0:6809"` + HttpListenAddr string `env:"HTTP_ADDR, default=0.0.0.0:9086"` + HttpsEnabled bool `env:"HTTPS_ENABLED, default=false"` + TLSCertFile string `env:"TLS_CERT_FILE"` + TLSKeyFile string `env:"TLS_KEY_FILE"` + DatabaseFile string `env:"DATABASE_FILE, default=./fsd.db"` + MOTD string `env:"MOTD, default=openfsd"` +} + +var SC *ServerConfig +var DB *sql.DB +var PO *PostOffice +var JWTKey []byte + +const TableCreateStatement = ` + CREATE TABLE IF NOT EXISTS users ( + cid INTEGER NOT NULL PRIMARY KEY AUTOINCREMENT, + password CHAR(60) NOT NULL, + rating TINYINT NOT NULL, + real_name VARCHAR(64) NOT NULL, + creation_time DATETIME NOT NULL + ); + ` + +func StartFSDServer(fsdCtx context.Context) { + addr, err := net.ResolveTCPAddr("tcp4", SC.FsdListenAddr) + if err != nil { + log.Fatal("Error resolving address: " + err.Error()) + } + + listener, err := net.ListenTCP("tcp4", addr) + if err != nil { + log.Fatal("Error listening: " + err.Error()) + } + defer func() { + closeErr := listener.Close() + if closeErr != nil { + log.Println("Error closing listener: " + closeErr.Error()) + } + }() + + // Listener goroutine + inConnections := make(chan *net.TCPConn) + go func(outConnections chan<- *net.TCPConn) { + defer close(outConnections) + for { + conn, listenerErr := listener.AcceptTCP() + if listenerErr != nil { + return + } + + select { + case outConnections <- conn: + case <-fsdCtx.Done(): + return + } + } + }(inConnections) + + log.Println("FSD listening") + + for { + select { + case conn, ok := <-inConnections: + if !ok { + return + } + go HandleConnection(conn) + case <-fsdCtx.Done(): + return + } + } +} + +func configureProtocolValidator() { + protocol.V = validator.New(validator.WithRequiredStructEnabled()) +} + +func configurePostOffice() { + PO = &PostOffice{ + clientRegistry: make(map[string]*FSDClient), + supervisorRegistry: make(map[string]*FSDClient), + geohashRegistry: make(map[string][]*FSDClient), + } +} + +func configureJwt() { + idBytes := make([]byte, 64) + _, err := rand.Read(idBytes) + if err != nil { + log.Fatal(err) + } + JWTKey = idBytes +} + +func configureDatabase() { + db, err := sql.Open("sqlite3", SC.DatabaseFile) + if err != nil { + log.Panic(err) + } + + _, err = db.Exec(TableCreateStatement) + if err != nil { + log.Panic(err) + } + + // Check if the users table is empty + // (is this the first time we've started up using this database?) + usersEmpty, err := IsUsersTableEmpty(db) + if err != nil { + log.Panic(err) + } + + // If it's empty, add a default admin user + if usersEmpty { + buf := make([]byte, 8) // since each byte is 2 hex characters + if _, err = io.ReadFull(rand.Reader, buf); err != nil { + log.Panic(err) + } + + pwd := hex.EncodeToString(buf) + + cid := 100000 + pwdHash, err := bcrypt.GenerateFromPassword([]byte(pwd), 10) + if err != nil { + log.Panic(err) + } + rating := protocol.NetworkRatingADM + realName := "Default Administrator" + + record, err := AddUserRecord(db, cid, string(pwdHash), rating, realName) + if err != nil { + log.Panic(err) + } + + fmt.Printf("Added user: %s CID: %d Password: %s Rating: Administrator\n", record.RealName, record.CID, pwd) + } + + // Set global DB variable + DB = db +} + +func setupServerConfig() { + ctx := context.Background() + + var c ServerConfig + if err := envconfig.Process(ctx, &c); err != nil { + log.Fatal(err) + } + + SC = &c +} + +func main() { + setupServerConfig() + configureProtocolValidator() + + configureDatabase() + defer DB.Close() + + configureJwt() + configurePostOffice() + + httpCtx, cancelHttp := context.WithCancel(context.Background()) + go StartHttpServer(httpCtx) + defer cancelHttp() + + fsdCtx, cancelFsd := context.WithCancel(context.Background()) + go StartFSDServer(fsdCtx) + defer cancelFsd() + + done := make(chan os.Signal, 1) + signal.Notify(done, syscall.SIGINT, syscall.SIGTERM) + + // Wait for OS done signal + <-done +} diff --git a/post_office.go b/post_office.go new file mode 100644 index 0000000..a75d3b9 --- /dev/null +++ b/post_office.go @@ -0,0 +1,259 @@ +package main + +import ( + "errors" + "github.com/mmcloughlin/geohash" + "github.com/renorris/openfsd/protocol" + "sync" +) + +// PostOffice handles the routing of messages between clients +type PostOffice struct { + clientRegistry map[string]*FSDClient + supervisorRegistry map[string]*FSDClient + geohashRegistry map[string][]*FSDClient + lock sync.RWMutex +} + +const ( + MailTypeDirect = iota + MailTypeBroadcastRanged + MailTypeBroadcastAll + MailTypeBroadcastSupervisors +) + +// Mail holds messages to be passed between clients +type Mail struct { + Type int + Source *FSDClient + Recipients []string + Packets []string +} + +func NewMail(source *FSDClient) *Mail { + return &Mail{ + Type: 0, + Source: source, + Recipients: nil, + Packets: nil, + } +} + +func (m *Mail) SetType(mailType int) { + m.Type = mailType +} + +func (m *Mail) AddRecipient(callsign string) { + if m.Recipients == nil { + m.Recipients = make([]string, 0) + } + m.Recipients = append(m.Recipients, callsign) +} + +func (m *Mail) AddPacket(packet string) { + if m.Packets == nil { + m.Packets = make([]string, 0) + } + m.Packets = append(m.Packets, packet) +} + +type FSDClientNode struct { + Client *FSDClient + Next *FSDClientNode +} + +var ( + CallsignAlreadyRegisteredError = callsignAlreadyRegisteredError() + CallsignNotRegisteredError = callsignNotRegisteredError() +) + +func callsignAlreadyRegisteredError() error { return errors.New("callsign already registered") } +func callsignNotRegisteredError() error { return errors.New("callsign not registered") } + +// RegisterCallsign registers a callsign to the post office, making it a valid recipient for other clients +func (p *PostOffice) RegisterCallsign(callsign string, client *FSDClient) error { + p.lock.Lock() + defer p.lock.Unlock() + + // Check if callsign already exists in the registry + _, ok := p.clientRegistry[callsign] + if ok { + return CallsignAlreadyRegisteredError + } + + // Otherwise, add the client to the registry + p.clientRegistry[callsign] = client + + // If the client is a supervisor, add them to the supervisor registry + if client.NetworkRating == protocol.NetworkRatingSUP { + p.supervisorRegistry[callsign] = client + } + + return nil +} + +// DeregisterCallsign removes a callsign from the post office. +// Returns CallsignNotRegisteredError if the callsign is not registered. +func (p *PostOffice) DeregisterCallsign(callsign string) error { + p.lock.Lock() + defer p.lock.Unlock() + + // Check if callsign exists in registry + client, ok := p.clientRegistry[callsign] + if !ok { + return CallsignNotRegisteredError + } + + // Delete the entry + delete(p.clientRegistry, callsign) + + // If the client is a supervisor, delete from supervisor registry + if client.NetworkRating == protocol.NetworkRatingSUP { + delete(p.supervisorRegistry, callsign) + } + + // Remove client from geohash registry + p.removeClientFromGeohashRegistry(client, client.CurrentGeohash) + + return nil +} + +// SendMail forwards Mail to its recipients +func (p *PostOffice) SendMail(messages []Mail) { + p.lock.RLock() + defer p.lock.RUnlock() + + for _, msg := range messages { + switch msg.Type { + case MailTypeDirect: + for _, recipient := range msg.Recipients { + fsdClient, ok := p.clientRegistry[recipient] + // If the callsign doesn't exist, drop the message + if !ok { + continue + } + for _, packet := range msg.Packets { + // Do not block writing to mailbox, avoiding potential deadlock + select { + case fsdClient.Mailbox <- packet: + default: + } + } + } + case MailTypeBroadcastRanged: + neighbors := geohash.Neighbors(msg.Source.CurrentGeohash) + searchHashes := append(neighbors, msg.Source.CurrentGeohash) + for _, hash := range searchHashes { + clients, ok := p.geohashRegistry[hash] + if !ok { + continue + } + for _, client := range clients { + if msg.Source != client { + for _, packet := range msg.Packets { + select { + case client.Mailbox <- packet: + default: + } + } + } + } + } + case MailTypeBroadcastAll: + for _, fsdClient := range p.clientRegistry { + if msg.Source != fsdClient { + for _, packet := range msg.Packets { + select { + case fsdClient.Mailbox <- packet: + default: + } + } + } + } + case MailTypeBroadcastSupervisors: + for _, fsdClient := range p.supervisorRegistry { + for _, packet := range msg.Packets { + if msg.Source != fsdClient { + select { + case fsdClient.Mailbox <- packet: + default: + } + } + } + } + } + + } +} + +func (p *PostOffice) removeClientFromGeohashRegistry(client *FSDClient, hash string) { + clientList, ok := p.geohashRegistry[client.CurrentGeohash] + if !ok { + return + } + + // Attempt to find the client in the list + for i := 0; i < len(clientList); i++ { + if clientList[i] == client { + // Remove the client from the slice + clientList = append(clientList[:i], clientList[i+1:]...) + break + } + } + + // If that was the last client in the list, delete it from the map and return + if len(clientList) == 0 { + delete(p.geohashRegistry, client.CurrentGeohash) + return + } + + // Re-allocate the slice and rewrite the map entry + newClientList := make([]*FSDClient, len(clientList)) + copy(newClientList, clientList) + + p.geohashRegistry[client.CurrentGeohash] = newClientList +} + +// SetLocation updates the internal geohash tracking state for a client, if necessary +func (p *PostOffice) SetLocation(client *FSDClient, lat, lng float64) { + // Check if the geohash has changed since we last updated + hash := geohash.EncodeWithPrecision(lat, lng, 3) + if hash == client.CurrentGeohash { + return + } + + p.lock.Lock() + defer p.lock.Unlock() + + // Remove client from old geohash bucket + p.removeClientFromGeohashRegistry(client, client.CurrentGeohash) + + // Find the new client list + newClientList, ok := p.geohashRegistry[hash] + // Create the slice if necessary + if !ok { + newClientList = make([]*FSDClient, 0) + } + + // Add client to the list + newClientList = append(newClientList, client) + + // Put it back on the registry + p.geohashRegistry[hash] = newClientList + + // Set our new current geohash + client.CurrentGeohash = hash +} + +// GetClient finds an *FSDClient for a callsign string. +func (p *PostOffice) GetClient(callsign string) (*FSDClient, error) { + p.lock.RLock() + defer p.lock.RUnlock() + + client, ok := p.clientRegistry[callsign] + if !ok { + return nil, CallsignNotRegisteredError + } + + return client, nil +} diff --git a/post_office_test.go b/post_office_test.go new file mode 100644 index 0000000..e8c605d --- /dev/null +++ b/post_office_test.go @@ -0,0 +1,239 @@ +package main + +import ( + "github.com/renorris/openfsd/protocol" + "github.com/stretchr/testify/assert" + "testing" + "time" +) + +func TestPostOffice(t *testing.T) { + configurePostOffice() + + c1 := FSDClient{ + Conn: nil, + Reader: nil, + AuthVerify: nil, + AuthSelf: nil, + NetworkRating: protocol.NetworkRatingOBS, + Mailbox: make(chan string, 16), + } + + c2 := FSDClient{ + Conn: nil, + Reader: nil, + AuthVerify: nil, + AuthSelf: nil, + NetworkRating: protocol.NetworkRatingSUP, + Mailbox: make(chan string, 16), + } + + // Test RegisterCallsign + + // Register "1" + { + err := PO.RegisterCallsign("1", &c1) + assert.Nil(t, err) + } + + // Register "1" twice + { + err := PO.RegisterCallsign("1", &c1) + assert.NotNil(t, err) + assert.ErrorIs(t, CallsignAlreadyRegisteredError, err) + } + + // Register "2" + { + err := PO.RegisterCallsign("2", &c2) + assert.Nil(t, err) + } + + // Register "2" twice + { + err := PO.RegisterCallsign("2", &c2) + assert.NotNil(t, err) + assert.ErrorIs(t, CallsignAlreadyRegisteredError, err) + } + + // Test SendMail + { + m := Mail{ + Type: MailTypeDirect, + Source: &c1, + Recipients: []string{"2"}, + Packets: []string{"hello"}, + } + + PO.SendMail([]Mail{m}) + + assert.Empty(t, c1.Mailbox) + assert.NotEmpty(t, c2.Mailbox) + + msg := <-c2.Mailbox + assert.Equal(t, "hello", msg) + + assert.Empty(t, c1.Mailbox) + assert.Empty(t, c2.Mailbox) + } + + { + m := Mail{ + Type: MailTypeBroadcastAll, + Source: &c1, + Recipients: []string{"2"}, + Packets: []string{"hello"}, + } + + PO.SendMail([]Mail{m}) + + assert.Empty(t, c1.Mailbox) + assert.NotEmpty(t, c2.Mailbox) + + { + msg := <-c2.Mailbox + assert.Equal(t, "hello", msg) + } + } + + // Test supervisor message + { + m := Mail{ + Type: MailTypeBroadcastSupervisors, + Source: &c1, + Recipients: nil, + Packets: []string{"hello"}, + } + + PO.SendMail([]Mail{m}) + assert.Empty(t, c1.Mailbox) + assert.NotEmpty(t, c2.Mailbox) + + { + msg := <-c2.Mailbox + assert.Equal(t, "hello", msg) + } + } + + // Test ranged broadcast mail + { + // c1 and c2 in range: + PO.SetLocation(&c1, 45.0, 45.0) + PO.SetLocation(&c2, 45.0, 45.0) + + m := Mail{ + Type: MailTypeBroadcastRanged, + Source: &c1, + Recipients: nil, + Packets: []string{"hello"}, + } + + PO.SendMail([]Mail{m}) + + // c2 should receive the broadcast + assert.Empty(t, c1.Mailbox) + assert.NotEmpty(t, c2.Mailbox) + + { + msg := <-c2.Mailbox + assert.Equal(t, "hello", msg) + } + + // Move c2 far away from c1 + PO.SetLocation(&c2, -45.0, -45.0) + + // Try again, nobody should get anything now + PO.SendMail([]Mail{m}) + + assert.Empty(t, c1.Mailbox) + assert.Empty(t, c2.Mailbox) + + // Move back + PO.SetLocation(&c2, 45.0, 45.0) + + // Should get the message again + PO.SendMail([]Mail{m}) + assert.Empty(t, c1.Mailbox) + assert.NotEmpty(t, c2.Mailbox) + + { + msg := <-c2.Mailbox + assert.Equal(t, "hello", msg) + } + + // Move c2 to a "neighbor" geohash cell + PO.SetLocation(&c2, 45.7, 47.1) + + // c1 geohash: v00 + // c2 geohash: v01 (neighbor) + + PO.SendMail([]Mail{m}) + assert.Empty(t, c1.Mailbox) + assert.NotEmpty(t, c2.Mailbox) + + { + msg := <-c2.Mailbox + assert.Equal(t, "hello", msg) + } + } + + // Test DeregisterCallsign + + // "3" is not registered + { + err := PO.DeregisterCallsign("3") + assert.NotNil(t, err) + assert.ErrorIs(t, CallsignNotRegisteredError, err) + } + + // Unavailable client mailbox should not block SendMail + { + m := Mail{ + Type: MailTypeDirect, + Source: &c1, + Recipients: []string{"2"}, + Packets: []string{"hello"}, + } + + c2.Mailbox = nil + + timer := time.NewTimer(100 * time.Millisecond) + done := make(chan interface{}) + go func() { + PO.SendMail([]Mail{m}) + close(done) + }() + + select { + case <-timer.C: + assert.Fail(t, "SendMail timeout") + case <-done: + } + } + + // Deregister "2" + { + err := PO.DeregisterCallsign("2") + assert.Nil(t, err) + } + + // Deregister "2" twice + { + err := PO.DeregisterCallsign("2") + assert.NotNil(t, err) + assert.ErrorIs(t, CallsignNotRegisteredError, err) + } + + // Deregister "1" + { + err := PO.DeregisterCallsign("1") + assert.Nil(t, err) + } + + // Deregister "1" twice + { + err := PO.DeregisterCallsign("1") + assert.NotNil(t, err) + assert.ErrorIs(t, CallsignNotRegisteredError, err) + } +} diff --git a/processor.go b/processor.go new file mode 100644 index 0000000..79e519b --- /dev/null +++ b/processor.go @@ -0,0 +1,123 @@ +package main + +import ( + "errors" + "github.com/renorris/openfsd/protocol" + "strings" +) + +// ProcessorResult represents the result of a handler function +type ProcessorResult struct { + Replies []string + Mail []Mail + ShouldDisconnect bool +} + +func NewProcessorResult() *ProcessorResult { + return &ProcessorResult{ + Replies: nil, + Mail: nil, + ShouldDisconnect: false, + } +} + +// Disconnect sets the flag to disconnect the client +func (r *ProcessorResult) Disconnect(flag bool) { + r.ShouldDisconnect = flag +} + +// AddReply adds a packet to send back to the client +func (r *ProcessorResult) AddReply(packet string) { + if r.Replies == nil { + r.Replies = make([]string, 0) + } + r.Replies = append(r.Replies, packet) +} + +// AddMail adds mail to be sent to other clients +func (r *ProcessorResult) AddMail(mail Mail) { + if r.Mail == nil { + r.Mail = make([]Mail, 0) + } + r.Mail = append(r.Mail, mail) +} + +// Processor represents a function to process an incoming FSD packet +type Processor func(client *FSDClient, rawPacket string) *ProcessorResult + +// InvalidPacketError means the packet was not recognized by the parser +func newInvalidPacketError() error { + return errors.New("Invalid packet") +} + +var ( + InvalidPacketError = newInvalidPacketError() +) + +func GetProcessor(rawPacket string) (Processor, error) { + rawPacket = strings.TrimSuffix(rawPacket, "\r\n") + if len(rawPacket) < 3 { + return nil, InvalidPacketError + } + + fields := strings.Split(rawPacket, protocol.Delimeter) + + switch rawPacket[0] { + case '^': + return FastPilotPositionProcessor, nil + case '@': + return PilotPositionProcessor, nil + case '#', '$': + pduID := rawPacket[0:3] + switch pduID { + case "$CQ": + return ClientQueryProcessor, nil + case "$CR": + return ClientQueryResponseProcessor, nil + case "#SL": + return FastPilotPositionProcessor, nil + case "#ST": + return FastPilotPositionProcessor, nil + case "$AX": + // TODO: implement METAR request + case "$PI": + return PingProcessor, nil + case "$ZC": + return AuthChallengeProcessor, nil + case "$ZR": + return AuthChallengeResponseProcessor, nil + case "$!!": + return KillRequestProcessor, nil + case "#DP": + return DeletePilotProcessor, nil + case "#SB": + if len(fields) < 3 { + return nil, InvalidPacketError + } + if fields[2] == "PIR" { + return PlaneInfoRequestProcessor, nil + } + if fields[2] == "PI" && len(fields) > 3 && fields[3] == "GEN" { + return PlaneInfoResponseProcessor, nil + } + case "#TM": + if len(fields) > 3 { + return nil, InvalidPacketError + } + switch fields[1] { + case "*": + return BroadcastMessageProcessor, nil + case "*S": + return WallopMessageProcessor, nil + default: + if len(fields[1]) > 0 && fields[1][0:1] == "@" { + return RadioMessageProcessor, nil + } else { + return DirectMessageProcessor, nil + } + } + } + } + + return nil, InvalidPacketError +} diff --git a/processor_defs.go b/processor_defs.go new file mode 100644 index 0000000..cfe2772 --- /dev/null +++ b/processor_defs.go @@ -0,0 +1,624 @@ +package main + +import ( + "errors" + "github.com/renorris/openfsd/protocol" + "github.com/renorris/openfsd/vatsimauth" + "log" + "strings" +) + +func PilotPositionProcessor(client *FSDClient, rawPacket string) *ProcessorResult { + // Parse & validate packet + pdu, err := protocol.ParsePilotPositionPDU(rawPacket) + if err != nil { + var fsdError *protocol.FSDError + result := NewProcessorResult() + if errors.As(err, &fsdError) { + result.AddReply(fsdError.Serialize()) + } + result.Disconnect(true) + return result + } + + // Check for valid source callsign + if pdu.From != client.Callsign { + result := NewProcessorResult() + result.AddReply(protocol.NewGenericFSDError(protocol.PDUSourceInvalidError).Serialize()) + result.Disconnect(true) + return result + } + + // Update location for post office + PO.SetLocation(client, pdu.Lat, pdu.Lng) + + result := NewProcessorResult() + + // Check if we should update SendFastEnabled + if client.SendFastEnabled && pdu.GroundSpeed == 0 { + client.SendFastEnabled = false + disableSendFastPDU := protocol.SendFastPDU{ + From: protocol.ServerCallsign, + To: client.Callsign, + DoSendFast: false, + } + result.AddReply(disableSendFastPDU.Serialize()) + } else if !client.SendFastEnabled && pdu.GroundSpeed > 0 { + client.SendFastEnabled = true + enableSendFastPDU := protocol.SendFastPDU{ + From: protocol.ServerCallsign, + To: client.Callsign, + DoSendFast: true, + } + result.AddReply(enableSendFastPDU.Serialize()) + } + + // Broadcast position update + mail := NewMail(client) + mail.SetType(MailTypeBroadcastRanged) + mail.AddPacket(rawPacket) + + result.AddMail(*mail) + return result +} + +func ClientQueryProcessor(client *FSDClient, rawPacket string) *ProcessorResult { + // Parse & validate packet + pdu, err := protocol.ParseClientQueryPDU(rawPacket) + if err != nil { + var fsdError *protocol.FSDError + result := NewProcessorResult() + if errors.As(err, &fsdError) { + result.AddReply(fsdError.Serialize()) + } + result.Disconnect(true) + return result + } + + // Check for valid source callsign + if pdu.From != client.Callsign { + result := NewProcessorResult() + result.AddReply(protocol.NewGenericFSDError(protocol.PDUSourceInvalidError).Serialize()) + result.Disconnect(true) + return result + } + + switch pdu.To { + case protocol.ServerCallsign: + if pdu.QueryType == protocol.ClientQueryPublicIP { + ip := strings.Split(client.Conn.RemoteAddr().String(), ":")[0] + responsePDU := protocol.ClientQueryResponsePDU{ + From: protocol.ServerCallsign, + To: client.Callsign, + QueryType: protocol.ClientQueryPublicIP, + Payload: ip, + } + result := NewProcessorResult() + result.AddReply(responsePDU.Serialize()) + return result + } + + return NewProcessorResult() + + case protocol.ClientQueryBroadcastRecipient, protocol.ClientQueryBroadcastRecipientPilots: + mail := NewMail(client) + mail.SetType(MailTypeBroadcastRanged) + mail.AddRecipient(pdu.To) + mail.AddPacket(rawPacket) + + result := NewProcessorResult() + result.AddMail(*mail) + + return result + } + + // Assume direct message client query + mail := NewMail(client) + mail.SetType(MailTypeDirect) + mail.AddRecipient(pdu.To) + mail.AddPacket(rawPacket) + + result := NewProcessorResult() + result.AddMail(*mail) + + return result + +} + +func ClientQueryResponseProcessor(client *FSDClient, rawPacket string) *ProcessorResult { + // Parse & validate packet + pdu, err := protocol.ParseClientQueryResponsePDU(rawPacket) + if err != nil { + var fsdError *protocol.FSDError + result := NewProcessorResult() + if errors.As(err, &fsdError) { + result.AddReply(fsdError.Serialize()) + } + result.Disconnect(true) + return result + } + + // Check for valid source callsign + if pdu.From != client.Callsign { + result := NewProcessorResult() + result.AddReply(protocol.NewGenericFSDError(protocol.PDUSourceInvalidError).Serialize()) + result.Disconnect(true) + return result + } + + if pdu.To == protocol.ServerCallsign { + return NewProcessorResult() + } + + mail := NewMail(client) + mail.SetType(MailTypeDirect) + mail.AddRecipient(pdu.To) + mail.AddPacket(rawPacket) + + result := NewProcessorResult() + result.AddMail(*mail) + + return result +} + +func PlaneInfoRequestProcessor(client *FSDClient, rawPacket string) *ProcessorResult { + // Parse & validate packet + pdu, err := protocol.ParsePlaneInfoRequestPDU(rawPacket) + if err != nil { + var fsdError *protocol.FSDError + result := NewProcessorResult() + if errors.As(err, &fsdError) { + result.AddReply(fsdError.Serialize()) + } + result.Disconnect(true) + return result + } + + // Check for valid source callsign + if pdu.From != client.Callsign { + result := NewProcessorResult() + result.AddReply(protocol.NewGenericFSDError(protocol.PDUSourceInvalidError).Serialize()) + result.Disconnect(true) + return result + } + + mail := NewMail(client) + mail.SetType(MailTypeDirect) + mail.AddRecipient(pdu.To) + mail.AddPacket(rawPacket) + + result := NewProcessorResult() + result.AddMail(*mail) + + return result +} + +func PlaneInfoResponseProcessor(client *FSDClient, rawPacket string) *ProcessorResult { + // Parse & validate packet + pdu, err := protocol.ParsePlaneInfoResponsePDU(rawPacket) + if err != nil { + var fsdError *protocol.FSDError + result := NewProcessorResult() + if errors.As(err, &fsdError) { + result.AddReply(fsdError.Serialize()) + } + result.Disconnect(true) + return result + } + + // Check for valid source callsign + if pdu.From != client.Callsign { + result := NewProcessorResult() + result.AddReply(protocol.NewGenericFSDError(protocol.PDUSourceInvalidError).Serialize()) + result.Disconnect(true) + return result + } + + mail := NewMail(client) + mail.SetType(MailTypeDirect) + mail.AddRecipient(pdu.To) + mail.AddPacket(rawPacket) + + result := NewProcessorResult() + result.AddMail(*mail) + + return result +} + +func FastPilotPositionProcessor(client *FSDClient, rawPacket string) *ProcessorResult { + // Parse & validate packet + pdu, err := protocol.ParseFastPilotPositionPDU(rawPacket) + if err != nil { + var fsdError *protocol.FSDError + result := NewProcessorResult() + if errors.As(err, &fsdError) { + result.AddReply(fsdError.Serialize()) + } + result.Disconnect(true) + return result + } + + // Check for valid source callsign + if pdu.From != client.Callsign { + result := NewProcessorResult() + result.AddReply(protocol.NewGenericFSDError(protocol.PDUSourceInvalidError).Serialize()) + result.Disconnect(true) + return result + } + + // Update location for post office if slow/stopped type + switch pdu.Type { + case protocol.FastPilotPositionTypeSlow, protocol.FastPilotPositionTypeStopped: + PO.SetLocation(client, pdu.Lat, pdu.Lng) + } + + // Broadcast position update + mail := NewMail(client) + mail.SetType(MailTypeBroadcastRanged) + mail.AddPacket(rawPacket) + + result := NewProcessorResult() + result.AddMail(*mail) + return result +} + +func AuthChallengeProcessor(client *FSDClient, rawPacket string) *ProcessorResult { + // Parse & validate packet + pdu, err := protocol.ParseAuthChallengePDU(rawPacket) + if err != nil { + var fsdError *protocol.FSDError + result := NewProcessorResult() + if errors.As(err, &fsdError) { + result.AddReply(fsdError.Serialize()) + } + result.Disconnect(true) + return result + } + + // Check for valid source callsign + if pdu.From != client.Callsign { + result := NewProcessorResult() + result.AddReply(protocol.NewGenericFSDError(protocol.PDUSourceInvalidError).Serialize()) + result.Disconnect(true) + return result + } + + // Generate response for the client's challenge + challengeResponse := client.AuthSelf.GenerateResponse(pdu.Challenge) + client.AuthSelf.UpdateState(challengeResponse) + challengeResponsePDU := protocol.AuthChallengeResponsePDU{ + From: protocol.ServerCallsign, + To: client.Callsign, + ChallengeResponse: challengeResponse, + } + + // Send the client a new auth challenge + client.PendingAuthVerifyChallenge, err = vatsimauth.GenerateChallenge() + if err != nil { + log.Println("Error generating challenge string") + result := NewProcessorResult() + result.Disconnect(true) + return result + } + + newChallengePDU := protocol.AuthChallengePDU{ + From: protocol.ServerCallsign, + To: client.Callsign, + Challenge: client.PendingAuthVerifyChallenge, + } + + result := NewProcessorResult() + result.AddReply(challengeResponsePDU.Serialize()) + result.AddReply(newChallengePDU.Serialize()) + + return result +} + +func AuthChallengeResponseProcessor(client *FSDClient, rawPacket string) *ProcessorResult { + // Parse & validate packet + pdu, err := protocol.ParseAuthChallengeResponsePDU(rawPacket) + if err != nil { + var fsdError *protocol.FSDError + result := NewProcessorResult() + if errors.As(err, &fsdError) { + result.AddReply(fsdError.Serialize()) + } + result.Disconnect(true) + return result + } + + // Check for valid source callsign + if pdu.From != client.Callsign { + result := NewProcessorResult() + result.AddReply(protocol.NewGenericFSDError(protocol.PDUSourceInvalidError).Serialize()) + result.Disconnect(true) + return result + } + + // Verify the response with the stored pending challenge + challengeResponse := client.AuthVerify.GenerateResponse(client.PendingAuthVerifyChallenge) + if challengeResponse != pdu.ChallengeResponse { + result := NewProcessorResult() + result.AddReply(protocol.NewGenericFSDError(protocol.UnauthorizedSoftwareError).Serialize()) + result.Disconnect(true) + return result + } + client.AuthVerify.UpdateState(challengeResponse) + + result := NewProcessorResult() + return result +} + +func DeletePilotProcessor(client *FSDClient, rawPacket string) *ProcessorResult { + // Parse & validate packet + pdu, err := protocol.ParseDeletePilotPDU(rawPacket) + if err != nil { + var fsdError *protocol.FSDError + result := NewProcessorResult() + if errors.As(err, &fsdError) { + result.AddReply(fsdError.Serialize()) + } + result.Disconnect(true) + return result + } + + // Check for valid source callsign + if pdu.From != client.Callsign { + result := NewProcessorResult() + result.AddReply(protocol.NewGenericFSDError(protocol.PDUSourceInvalidError).Serialize()) + result.Disconnect(true) + return result + } + + // Check for valid CID + if pdu.CID != client.CID { + result := NewProcessorResult() + + result.AddReply(protocol.NewGenericFSDError(protocol.SyntaxError).Serialize()) + result.Disconnect(true) + + return result + } + + result := NewProcessorResult() + result.Disconnect(true) + + return result +} + +func PingProcessor(client *FSDClient, rawPacket string) *ProcessorResult { + // Parse & validate packet + pdu, err := protocol.ParsePingPDU(rawPacket) + if err != nil { + var fsdError *protocol.FSDError + result := NewProcessorResult() + if errors.As(err, &fsdError) { + result.AddReply(fsdError.Serialize()) + } + result.Disconnect(true) + return result + } + + // Check for valid source callsign + if pdu.From != client.Callsign { + result := NewProcessorResult() + result.AddReply(protocol.NewGenericFSDError(protocol.PDUSourceInvalidError).Serialize()) + result.Disconnect(true) + return result + } + + // Ignore if the ping isn't for the server + if pdu.To != protocol.ServerCallsign { + result := NewProcessorResult() + return result + } + + result := NewProcessorResult() + pongPDU := protocol.PongPDU{ + From: protocol.ServerCallsign, + To: client.Callsign, + Timestamp: pdu.Timestamp, + } + result.AddReply(pongPDU.Serialize()) + + return result +} + +func BroadcastMessageProcessor(client *FSDClient, rawPacket string) *ProcessorResult { + // Parse & validate packet + pdu, err := protocol.ParseTextMessagePDU(rawPacket) + if err != nil { + var fsdError *protocol.FSDError + result := NewProcessorResult() + if errors.As(err, &fsdError) { + result.AddReply(fsdError.Serialize()) + } + result.Disconnect(true) + return result + } + + // Check for valid source callsign + if pdu.From != client.Callsign { + result := NewProcessorResult() + result.AddReply(protocol.NewGenericFSDError(protocol.PDUSourceInvalidError).Serialize()) + result.Disconnect(true) + return result + } + + // Ignore if the user doesn't have permission + if client.NetworkRating < protocol.NetworkRatingSUP { + return NewProcessorResult() + } + + // Verify the To field is set to `*` + if pdu.To != "*" { + return NewProcessorResult() + } + + mail := NewMail(client) + mail.SetType(MailTypeBroadcastAll) + mail.AddPacket(rawPacket) + + result := NewProcessorResult() + result.AddMail(*mail) + + return result +} + +func WallopMessageProcessor(client *FSDClient, rawPacket string) *ProcessorResult { + // Parse & validate packet + pdu, err := protocol.ParseTextMessagePDU(rawPacket) + if err != nil { + var fsdError *protocol.FSDError + result := NewProcessorResult() + if errors.As(err, &fsdError) { + result.AddReply(fsdError.Serialize()) + } + result.Disconnect(true) + return result + } + + // Check for valid source callsign + if pdu.From != client.Callsign { + result := NewProcessorResult() + result.AddReply(protocol.NewGenericFSDError(protocol.PDUSourceInvalidError).Serialize()) + result.Disconnect(true) + return result + } + + // Verify the To field is set to `*S` + if pdu.To != "*S" { + return NewProcessorResult() + } + + mail := NewMail(client) + mail.SetType(MailTypeBroadcastSupervisors) + mail.AddPacket(rawPacket) + + result := NewProcessorResult() + result.AddMail(*mail) + + return result +} + +func RadioMessageProcessor(client *FSDClient, rawPacket string) *ProcessorResult { + // Parse & validate packet + pdu, err := protocol.ParseTextMessagePDU(rawPacket) + if err != nil { + var fsdError *protocol.FSDError + result := NewProcessorResult() + if errors.As(err, &fsdError) { + result.AddReply(fsdError.Serialize()) + } + result.Disconnect(true) + return result + } + + // Check for valid source callsign + if pdu.From != client.Callsign { + result := NewProcessorResult() + result.AddReply(protocol.NewGenericFSDError(protocol.PDUSourceInvalidError).Serialize()) + result.Disconnect(true) + return result + } + + // Verify the To field is a radio frequency + if !strings.HasPrefix(pdu.To, "@") && len(pdu.To) != 6 { + return NewProcessorResult() + } + + mail := NewMail(client) + mail.SetType(MailTypeBroadcastRanged) + mail.AddPacket(rawPacket) + + result := NewProcessorResult() + result.AddMail(*mail) + + return result +} + +func DirectMessageProcessor(client *FSDClient, rawPacket string) *ProcessorResult { + // Parse & validate packet + pdu, err := protocol.ParseTextMessagePDU(rawPacket) + if err != nil { + var fsdError *protocol.FSDError + result := NewProcessorResult() + if errors.As(err, &fsdError) { + result.AddReply(fsdError.Serialize()) + } + result.Disconnect(true) + return result + } + + // Check for valid source callsign + if pdu.From != client.Callsign { + result := NewProcessorResult() + result.AddReply(protocol.NewGenericFSDError(protocol.PDUSourceInvalidError).Serialize()) + result.Disconnect(true) + return result + } + + mail := NewMail(client) + mail.SetType(MailTypeDirect) + mail.AddRecipient(pdu.To) + mail.AddPacket(rawPacket) + + result := NewProcessorResult() + result.AddMail(*mail) + + return result +} + +func KillRequestProcessor(client *FSDClient, rawPacket string) *ProcessorResult { + // Parse & validate packet + pdu, err := protocol.ParseKillRequestPDU(rawPacket) + if err != nil { + var fsdError *protocol.FSDError + result := NewProcessorResult() + if errors.As(err, &fsdError) { + result.AddReply(fsdError.Serialize()) + } + result.Disconnect(true) + return result + } + + // Check for valid source callsign + if pdu.From != client.Callsign { + result := NewProcessorResult() + result.AddReply(protocol.NewGenericFSDError(protocol.PDUSourceInvalidError).Serialize()) + result.Disconnect(true) + return result + } + + // Check if client has permission + if client.NetworkRating < protocol.NetworkRatingSUP { + return NewProcessorResult() + } + + victim, err := PO.GetClient(pdu.To) + if err != nil { + result := NewProcessorResult() + if errors.Is(err, CallsignNotRegisteredError) { + result.AddReply(protocol.NewGenericFSDError(protocol.NoSuchCallsignError).Serialize()) + } + return result + } + + result := NewProcessorResult() + replyPDU := protocol.TextMessagePDU{ + From: protocol.ServerCallsign, + To: client.Callsign, + Message: "", + } + + // Attempt a non-blocking kill + select { + case victim.Kill <- rawPacket: + replyPDU.Message = "Killed " + pdu.To + default: + replyPDU.Message = "ERROR: unable to kill " + pdu.To + } + + result.AddReply(replyPDU.Serialize()) + return result +} diff --git a/protocol/add_pilot.go b/protocol/add_pilot.go new file mode 100644 index 0000000..76c54ea --- /dev/null +++ b/protocol/add_pilot.go @@ -0,0 +1,69 @@ +package protocol + +import ( + "fmt" + "strconv" + "strings" +) + +type AddPilotPDU struct { + From string `validate:"required,alphanum,max=7"` + To string `validate:"required,alphanum,max=7"` + CID int `validate:"required,min=100000,max=9999999"` + Token string `validate:"required,jwt"` + NetworkRating int `validate:"min=1,max=12"` + ProtocolRevision int `validate:""` + SimulatorType int `validate:"min=0,max=6"` + RealName string `validate:"required,max=32"` +} + +func (p *AddPilotPDU) Serialize() string { + return fmt.Sprintf("#AP%s:%s:%d:%s:%d:%d:%d:%s%s", p.From, p.To, p.CID, p.Token, p.NetworkRating, p.ProtocolRevision, p.SimulatorType, p.RealName, PacketDelimeter) +} + +func ParseAddPilotPDU(rawPacket string) (*AddPilotPDU, error) { + rawPacket = strings.TrimSuffix(rawPacket, PacketDelimeter) + rawPacket = strings.TrimPrefix(rawPacket, "#AP") + fields := strings.Split(rawPacket, Delimeter) + if len(fields) != 8 { + return nil, NewGenericFSDError(SyntaxError) + } + + cid, err := strconv.Atoi(fields[2]) + if err != nil { + return nil, NewGenericFSDError(SyntaxError) + } + + networkRating, err := strconv.Atoi(fields[4]) + if err != nil { + return nil, NewGenericFSDError(SyntaxError) + } + + protocolRevision, err := strconv.Atoi(fields[5]) + if err != nil { + return nil, NewGenericFSDError(SyntaxError) + } + + simulatorType, err := strconv.Atoi(fields[6]) + if err != nil { + return nil, NewGenericFSDError(SyntaxError) + } + + pdu := AddPilotPDU{ + From: fields[0], + To: fields[1], + CID: cid, + Token: fields[3], + NetworkRating: networkRating, + ProtocolRevision: protocolRevision, + SimulatorType: simulatorType, + RealName: fields[7], + } + + err = V.Struct(&pdu) + if err != nil { + return nil, NewGenericFSDError(SyntaxError) + } + + return &pdu, nil +} diff --git a/protocol/add_pilot_test.go b/protocol/add_pilot_test.go new file mode 100644 index 0000000..824ec54 --- /dev/null +++ b/protocol/add_pilot_test.go @@ -0,0 +1,121 @@ +package protocol + +import ( + "github.com/go-playground/validator/v10" + "github.com/stretchr/testify/assert" + "testing" +) + +func TestAddPilotPDU_Serialize(t *testing.T) { + // Define test case with expected serialization output + tests := []struct { + name string + pdu AddPilotPDU + wantStr string + }{ + { + name: "Valid Serialization", + pdu: AddPilotPDU{ + From: "CLIENT", + To: "SERVER", + CID: 1234567, + Token: "jwtTokenExample", + NetworkRating: 5, + ProtocolRevision: 2, + SimulatorType: 3, + RealName: "John Smith", + }, + wantStr: "#APCLIENT:SERVER:1234567:jwtTokenExample:5:2:3:John Smith\r\n", + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + // Perform serialization + gotStr := tc.pdu.Serialize() + + // Assert that the serialization output matches the expected string + assert.Equal(t, tc.wantStr, gotStr) + }) + } +} + +func TestParseAddPilotPDU(t *testing.T) { + V = validator.New(validator.WithRequiredStructEnabled()) + + tests := []struct { + name string + packet string + want *AddPilotPDU + wantErr error + }{ + { + "Valid", + "#APCLIENT:SERVER:1234567:eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpc3MiOiJodHRwczovL2F1dGgudmF0c2ltLm5ldC9hcGkvZnNkLWp3dCIsInN1YiI6IjEwMDAwMDAiLCJhdWQiOlsiZnNkLWxpdmUiXSwiZXhwIjoxNzExOTA4OTM4LCJuYmYiOjE3MTE5MDgzOTgsImlhdCI6MTcxMTkwODUxOCwianRpIjoiRDFCS1BPdUdKelAzZE5NdnV6d1JNZz09IiwiY29udHJvbGxlcl9yYXRpbmciOjAsInBpbG90X3JhdGluZyI6MH0.kg23HhANM6aUI9mRUUGX-Vx8HKjTpzkDxOXlvWkjnC8:5:2:3:John Smith\r\n", + &AddPilotPDU{ + From: "CLIENT", + To: "SERVER", + CID: 1234567, + Token: "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpc3MiOiJodHRwczovL2F1dGgudmF0c2ltLm5ldC9hcGkvZnNkLWp3dCIsInN1YiI6IjEwMDAwMDAiLCJhdWQiOlsiZnNkLWxpdmUiXSwiZXhwIjoxNzExOTA4OTM4LCJuYmYiOjE3MTE5MDgzOTgsImlhdCI6MTcxMTkwODUxOCwianRpIjoiRDFCS1BPdUdKelAzZE5NdnV6d1JNZz09IiwiY29udHJvbGxlcl9yYXRpbmciOjAsInBpbG90X3JhdGluZyI6MH0.kg23HhANM6aUI9mRUUGX-Vx8HKjTpzkDxOXlvWkjnC8", + NetworkRating: 5, + ProtocolRevision: 2, + SimulatorType: 3, + RealName: "John Smith", + }, + nil, + }, + { + "Invalid CID length", + "#APCLIENT:SERVER:12345:eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpc3MiOiJodHRwczovL2F1dGgudmF0c2ltLm5ldC9hcGkvZnNkLWp3dCIsInN1YiI6IjEwMDAwMDAiLCJhdWQiOlsiZnNkLWxpdmUiXSwiZXhwIjoxNzExOTA4OTM4LCJuYmYiOjE3MTE5MDgzOTgsImlhdCI6MTcxMTkwODUxOCwianRpIjoiRDFCS1BPdUdKelAzZE5NdnV6d1JNZz09IiwiY29udHJvbGxlcl9yYXRpbmciOjAsInBpbG90X3JhdGluZyI6MH0.kg23HhANM6aUI9mRUUGX-Vx8HKjTpzkDxOXlvWkjnC8:5:2:3:John Smith\r\n", + nil, + NewGenericFSDError(SyntaxError), + }, + { + "Invalid CID format", + "#APCLIENT:SERVER:ABCDE12:eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpc3MiOiJodHRwczovL2F1dGgudmF0c2ltLm5ldC9hcGkvZnNkLWp3dCIsInN1YiI6IjEwMDAwMDAiLCJhdWQiOlsiZnNkLWxpdmUiXSwiZXhwIjoxNzExOTA4OTM4LCJuYmYiOjE3MTE5MDgzOTgsImlhdCI6MTcxMTkwODUxOCwianRpIjoiRDFCS1BPdUdKelAzZE5NdnV6d1JNZz09IiwiY29udHJvbGxlcl9yYXRpbmciOjAsInBpbG90X3JhdGluZyI6MH0.kg23HhANM6aUI9mRUUGX-Vx8HKjTpzkDxOXlvWkjnC8:5:2:3:John Smith\r\n", + nil, + NewGenericFSDError(SyntaxError), + }, + { + "Invalid Network Rating", + "#APCLIENT:SERVER:1234567:eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpc3MiOiJodHRwczovL2F1dGgudmF0c2ltLm5ldC9hcGkvZnNkLWp3dCIsInN1YiI6IjEwMDAwMDAiLCJhdWQiOlsiZnNkLWxpdmUiXSwiZXhwIjoxNzExOTA4OTM4LCJuYmYiOjE3MTE5MDgzOTgsImlhdCI6MTcxMTkwODUxOCwianRpIjoiRDFCS1BPdUdKelAzZE5NdnV6d1JNZz09IiwiY29udHJvbGxlcl9yYXRpbmciOjAsInBpbG90X3JhdGluZyI6MH0.kg23HhANM6aUI9mRUUGX-Vx8HKjTpzkDxOXlvWkjnC8:13:2:3:John Smith\r\n", + nil, + NewGenericFSDError(SyntaxError), + }, + { + "Invalid Simulator Type", + "#APCLIENT:SERVER:1234567:eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpc3MiOiJodHRwczovL2F1dGgudmF0c2ltLm5ldC9hcGkvZnNkLWp3dCIsInN1YiI6IjEwMDAwMDAiLCJhdWQiOlsiZnNkLWxpdmUiXSwiZXhwIjoxNzExOTA4OTM4LCJuYmYiOjE3MTE5MDgzOTgsImlhdCI6MTcxMTkwODUxOCwianRpIjoiRDFCS1BPdUdKelAzZE5NdnV6d1JNZz09IiwiY29udHJvbGxlcl9yYXRpbmciOjAsInBpbG90X3JhdGluZyI6MH0.kg23HhANM6aUI9mRUUGX-Vx8HKjTpzkDxOXlvWkjnC8:5:2:7:John Smith\r\n", + nil, + NewGenericFSDError(SyntaxError), + }, + { + "Real name too long", + "#APCLIENT:SERVER:1234567:eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpc3MiOiJodHRwczovL2F1dGgudmF0c2ltLm5ldC9hcGkvZnNkLWp3dCIsInN1YiI6IjEwMDAwMDAiLCJhdWQiOlsiZnNkLWxpdmUiXSwiZXhwIjoxNzExOTA4OTM4LCJuYmYiOjE3MTE5MDgzOTgsImlhdCI6MTcxMTkwODUxOCwianRpIjoiRDFCS1BPdUdKelAzZE5NdnV6d1JNZz09IiwiY29udHJvbGxlcl9yYXRpbmciOjAsInBpbG90X3JhdGluZyI6MH0.kg23HhANM6aUI9mRUUGX-Vx8HKjTpzkDxOXlvWkjnC8:5:2:3:This Is Way Too Long For A Full Real Name\r\n", + nil, + NewGenericFSDError(SyntaxError), + }, + { + "Missing Delimeters", + "APCLIENTSERVER1234567eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpc3MiOiJodHRwczovL2F1dGgudmF0c2ltLm5ldC9hcGkvZnNkLWp3dCIsInN1YiI6IjEwMDAwMDAiLCJhdWQiOlsiZnNkLWxpdmUiXSwiZXhwIjoxNzExOTA4OTM4LCJuYmYiOjE3MTE5MDgzOTgsImlhdCI6MTcxMTkwODUxOCwianRpIjoiRDFCS1BPdUdKelAzZE5NdnV6d1JNZz09IiwiY29udHJvbGxlcl9yYXRpbmciOjAsInBpbG90X3JhdGluZyI6MH0.kg23HhANM6aUI9mRUUGX-Vx8HKjTpzkDxOXlvWkjnC8523John Smith", + nil, + NewGenericFSDError(SyntaxError), + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + // Perform the parsing + result, err := ParseAddPilotPDU(tc.packet) + + // Check the error + if tc.wantErr != nil { + assert.EqualError(t, err, tc.wantErr.Error()) + } else { + assert.NoError(t, err) + } + + // Verify the result + assert.Equal(t, tc.want, result) + }) + } +} diff --git a/protocol/auth_challenge.go b/protocol/auth_challenge.go new file mode 100644 index 0000000..31a78da --- /dev/null +++ b/protocol/auth_challenge.go @@ -0,0 +1,39 @@ +package protocol + +import ( + "fmt" + "strings" +) + +type AuthChallengePDU struct { + From string `validate:"required,alphanum,max=7"` + To string `validate:"required,alphanum,max=7"` + Challenge string `validate:"required,hexadecimal,min=4,max=32"` +} + +func (p *AuthChallengePDU) Serialize() string { + return fmt.Sprintf("$ZC%s:%s:%s%s", p.From, p.To, p.Challenge, PacketDelimeter) +} + +func ParseAuthChallengePDU(rawPacket string) (*AuthChallengePDU, error) { + rawPacket = strings.TrimSuffix(rawPacket, PacketDelimeter) + rawPacket = strings.TrimPrefix(rawPacket, "$ZC") + fields := strings.Split(rawPacket, Delimeter) + + if len(fields) != 3 { + return nil, NewGenericFSDError(SyntaxError) + } + + pdu := AuthChallengePDU{ + From: fields[0], + To: fields[1], + Challenge: fields[2], + } + + err := V.Struct(pdu) + if err != nil { + return nil, NewGenericFSDError(SyntaxError) + } + + return &pdu, nil +} diff --git a/protocol/auth_challenge_response.go b/protocol/auth_challenge_response.go new file mode 100644 index 0000000..f1728e0 --- /dev/null +++ b/protocol/auth_challenge_response.go @@ -0,0 +1,39 @@ +package protocol + +import ( + "fmt" + "strings" +) + +type AuthChallengeResponsePDU struct { + From string `validate:"required,alphanum,max=7"` + To string `validate:"required,alphanum,max=7"` + ChallengeResponse string `validate:"required,hexadecimal,md5"` +} + +func (p *AuthChallengeResponsePDU) Serialize() string { + return fmt.Sprintf("$ZR%s:%s:%s%s", p.From, p.To, p.ChallengeResponse, PacketDelimeter) +} + +func ParseAuthChallengeResponsePDU(rawPacket string) (*AuthChallengeResponsePDU, error) { + rawPacket = strings.TrimSuffix(rawPacket, PacketDelimeter) + rawPacket = strings.TrimPrefix(rawPacket, "$ZR") + fields := strings.Split(rawPacket, Delimeter) + + if len(fields) != 3 { + return nil, NewGenericFSDError(SyntaxError) + } + + pdu := AuthChallengeResponsePDU{ + From: fields[0], + To: fields[1], + ChallengeResponse: fields[2], + } + + err := V.Struct(pdu) + if err != nil { + return nil, NewGenericFSDError(SyntaxError) + } + + return &pdu, nil +} diff --git a/protocol/auth_challenge_response_test.go b/protocol/auth_challenge_response_test.go new file mode 100644 index 0000000..a705009 --- /dev/null +++ b/protocol/auth_challenge_response_test.go @@ -0,0 +1,89 @@ +package protocol + +import ( + "github.com/go-playground/validator/v10" + "github.com/stretchr/testify/assert" + "testing" +) + +func TestAuthChallengeResponsePDU_Serialize(t *testing.T) { + tests := []struct { + name string + pdu *AuthChallengeResponsePDU + want string + }{ + { + name: "Valid Serialization", + pdu: &AuthChallengeResponsePDU{ + From: "SERVER", + To: "CLIENT", + ChallengeResponse: "0c4a96fa1cab961018620f120988cdf9", + }, + want: "$ZRSERVER:CLIENT:0c4a96fa1cab961018620f120988cdf9\r\n", + }, + } + + V = validator.New() + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + result := test.pdu.Serialize() + assert.Equal(t, test.want, result) + }) + } +} + +func TestParseAuthChallengeResponsePDU(t *testing.T) { + V = validator.New(validator.WithRequiredStructEnabled()) + + tests := []struct { + name string + packet string + want *AuthChallengeResponsePDU + wantErr error + }{ + { + name: "Valid Packet", + packet: "$ZRSERVER:CLIENT:0c4a96fa1cab961018620f120988cdf9\r\n", + want: &AuthChallengeResponsePDU{ + From: "SERVER", + To: "CLIENT", + ChallengeResponse: "0c4a96fa1cab961018620f120988cdf9", + }, + wantErr: nil, + }, + { + name: "Missing From Field", + packet: "$ZR:CLIENT:0c4a96fa1cab961018620f120988cdf9\r\n", + want: nil, + wantErr: NewGenericFSDError(SyntaxError), + }, + { + name: "Missing To Field", + packet: "$ZRSERVER::0c4a96fa1cab961018620f120988cdf9\r\n", + want: nil, + wantErr: NewGenericFSDError(SyntaxError), + }, + { + name: "Challenge response not md5", + packet: "$ZRSERVER:CLIENT:0c4a96fa1cab9610\r\n", + want: nil, + wantErr: NewGenericFSDError(SyntaxError), + }, + { + name: "Invalid Hexadecimal", + packet: "$ZRSERVER:CLIENT:ghij7890\r\n", + want: nil, + wantErr: NewGenericFSDError(SyntaxError), + }, + } + + V = validator.New() + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + result, err := ParseAuthChallengeResponsePDU(tc.packet) + assert.Equal(t, tc.want, result) + assert.Equal(t, tc.wantErr, err) + }) + } +} diff --git a/protocol/auth_challenge_test.go b/protocol/auth_challenge_test.go new file mode 100644 index 0000000..7ebfce3 --- /dev/null +++ b/protocol/auth_challenge_test.go @@ -0,0 +1,89 @@ +package protocol + +import ( + "github.com/go-playground/validator/v10" + "github.com/stretchr/testify/assert" + "testing" +) + +func TestAuthChallengePDU_Serialize(t *testing.T) { + tests := []struct { + name string + pdu *AuthChallengePDU + want string + }{ + { + name: "Valid Serialization", + pdu: &AuthChallengePDU{ + From: "SERVER", + To: "CLIENT", + Challenge: "abcd1234ef", + }, + want: "$ZCSERVER:CLIENT:abcd1234ef\r\n", + }, + } + + V = validator.New() + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + result := test.pdu.Serialize() + assert.Equal(t, test.want, result) + }) + } +} + +func TestParseAuthChallengePDU(t *testing.T) { + V = validator.New(validator.WithRequiredStructEnabled()) + + tests := []struct { + name string + packet string + want *AuthChallengePDU + wantErr error + }{ + { + name: "Valid Packet", + packet: "$ZCSERVER:CLIENT:abcd1234ef\r\n", + want: &AuthChallengePDU{ + From: "SERVER", + To: "CLIENT", + Challenge: "abcd1234ef", + }, + wantErr: nil, + }, + { + name: "Missing From Field", + packet: "$ZC:CLIENT:abcd1234ef\r\n", + want: nil, + wantErr: NewGenericFSDError(SyntaxError), + }, + { + name: "Missing To Field", + packet: "$ZCSERVER::abcd1234ef\r\n", + want: nil, + wantErr: NewGenericFSDError(SyntaxError), + }, + { + name: "Invalid Challenge Length", + packet: "$ZCSERVER:CLIENT:ab\r\n", + want: nil, + wantErr: NewGenericFSDError(SyntaxError), + }, + { + name: "Invalid Hexadecimal Challenge", + packet: "$ZCSERVER:CLIENT:ghij7890\r\n", + want: nil, + wantErr: NewGenericFSDError(SyntaxError), + }, + } + + V = validator.New() + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + result, err := ParseAuthChallengePDU(tc.packet) + assert.Equal(t, tc.want, result) + assert.Equal(t, tc.wantErr, err) + }) + } +} diff --git a/protocol/client_identification.go b/protocol/client_identification.go new file mode 100644 index 0000000..e6c5e9a --- /dev/null +++ b/protocol/client_identification.go @@ -0,0 +1,87 @@ +package protocol + +import ( + "encoding/binary" + "encoding/hex" + "fmt" + "strconv" + "strings" +) + +type ClientIdentificationPDU struct { + From string `validate:"required,alphanum,max=7"` + To string `validate:"required,alphanum,max=7"` + ClientID uint16 `validate:"required"` + ClientName string `validate:"required,alphanum,max=16"` + MajorVersion int `validate:""` + MinorVersion int `validate:""` + CID int `validate:"required,min=100000,max=9999999"` + SysUID int `validate:"required,number"` + InitialChallenge string `validate:"required,hexadecimal,min=2,max=32"` +} + +func (p *ClientIdentificationPDU) Serialize() string { + clientIDBytes := make([]byte, 2) + binary.BigEndian.PutUint16(clientIDBytes, p.ClientID) + clientIDStr := hex.EncodeToString(clientIDBytes) + return fmt.Sprintf("$ID%s:%s:%s:%s:%d:%d:%d:%d:%s%s", p.From, p.To, clientIDStr, p.ClientName, p.MajorVersion, p.MinorVersion, p.CID, p.SysUID, p.InitialChallenge, PacketDelimeter) +} + +func ParseClientIdentificationPDU(rawPacket string) (*ClientIdentificationPDU, error) { + rawPacket = strings.TrimSuffix(rawPacket, PacketDelimeter) + rawPacket = strings.TrimPrefix(rawPacket, "$ID") + fields := strings.Split(rawPacket, Delimeter) + if len(fields) != 9 { + return nil, NewGenericFSDError(SyntaxError) + } + + // fields[2] == uint16 in hexadecimal + if len(fields[2]) != 4 { + return nil, NewGenericFSDError(SyntaxError) + } + + clientIDBytes, err := hex.DecodeString(fields[2]) + if err != nil || len(clientIDBytes) != 2 { + return nil, NewGenericFSDError(SyntaxError) + } + clientID := binary.BigEndian.Uint16(clientIDBytes) + + majorVersion, err := strconv.Atoi(fields[4]) + if err != nil { + return nil, NewGenericFSDError(SyntaxError) + } + + minorVersion, err := strconv.Atoi(fields[5]) + if err != nil { + return nil, NewGenericFSDError(SyntaxError) + } + + cid, err := strconv.Atoi(fields[6]) + if err != nil { + return nil, NewGenericFSDError(SyntaxError) + } + + sysUID, err := strconv.Atoi(fields[7]) + if err != nil { + return nil, NewGenericFSDError(SyntaxError) + } + + pdu := ClientIdentificationPDU{ + From: fields[0], + To: fields[1], + ClientID: clientID, + ClientName: fields[3], + MajorVersion: majorVersion, + MinorVersion: minorVersion, + CID: cid, + SysUID: sysUID, + InitialChallenge: fields[8], + } + + err = V.Struct(&pdu) + if err != nil { + return nil, NewGenericFSDError(SyntaxError) + } + + return &pdu, nil +} diff --git a/protocol/client_identification_test.go b/protocol/client_identification_test.go new file mode 100644 index 0000000..cd177ed --- /dev/null +++ b/protocol/client_identification_test.go @@ -0,0 +1,141 @@ +package protocol + +import ( + "encoding/hex" + "github.com/go-playground/validator/v10" + "github.com/stretchr/testify/assert" + "testing" +) + +func TestParseClientIdentificationPDU(t *testing.T) { + V = validator.New(validator.WithRequiredStructEnabled()) + + tests := []struct { + name string + packet string + want *ClientIdentificationPDU + wantErr error + }{ + { + name: "Valid", + packet: "$IDCLIENT:SERVER:1234:ClientName:1:2:1234567:12345678:abcd1234\r\n", + want: &ClientIdentificationPDU{ + From: "CLIENT", + To: "SERVER", + ClientID: 0x1234, + ClientName: "ClientName", + MajorVersion: 1, + MinorVersion: 2, + CID: 1234567, + SysUID: 12345678, + InitialChallenge: "abcd1234", + }, + wantErr: nil, + }, + { + name: "Missing Field", + packet: "$IDCLIENT:SERVER:1234:ClientName:1:2:1234567:abcd1234\r\n", + want: nil, + wantErr: NewGenericFSDError(SyntaxError), + }, + { + name: "Invalid Major Version", + packet: "$IDCLIENT:SERVER:1234:ClientName:X:2:0001234:12345678:abcd1234\r\n", + want: nil, + wantErr: NewGenericFSDError(SyntaxError), + }, + { + name: "Invalid ClientID", + packet: "$IDCLIENT:SERVER:12:ClientName:1:2:0001234:12345678:abcd1234\r\n", + want: nil, + wantErr: NewGenericFSDError(SyntaxError), + }, + { + name: "Invalid Hexadecimal in Initial Challenge", + packet: "$IDCLIENT:SERVER:1234:ClientName:1:2:0001234:12345678:xyz\r\n", + want: nil, + wantErr: NewGenericFSDError(SyntaxError), + }, + { + name: "Initial Challenge Too Short", + packet: "$IDCLIENT:SERVER:1234:ClientName:1:2:0001234:12345678:a\r\n", + want: nil, + wantErr: NewGenericFSDError(SyntaxError), + }, + { + name: "Invalid Minor Version", + packet: "$IDCLIENT:SERVER:1234:ClientName:1:XY:0001234:12345678:abcd1234\r\n", + want: nil, + wantErr: NewGenericFSDError(SyntaxError), + }, + { + name: "Invalid SysUID - Non-numeric", + packet: "$IDCLIENT:SERVER:1234:ClientName:1:2:0001234:SYSUID:abcd1234\r\n", + want: nil, + wantErr: NewGenericFSDError(SyntaxError), + }, + { + name: "Invalid CID - too long", + packet: "$IDCLIENT:SERVER:1234:ClientName:1:2:0001234567890:12345678:abcd1234\r\n", + want: nil, + wantErr: NewGenericFSDError(SyntaxError), + }, + { + name: "Invalid CID - contains letters", + packet: "$IDCLIENT:SERVER:1234:ClientName:1:2:000ABCD:12345678:abcd1234\r\n", + want: nil, + wantErr: NewGenericFSDError(SyntaxError), + }, + { + name: "Initial Challenge Too Long", + packet: "$IDCLIENT:SERVER:1234:ClientName:1:2:0001234:12345678:" + hex.EncodeToString(make([]byte, 33)) + "\r\n", + want: nil, + wantErr: NewGenericFSDError(SyntaxError), + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + result, err := ParseClientIdentificationPDU(tc.packet) + + if tc.wantErr != nil { + assert.EqualError(t, err, tc.wantErr.Error()) + } else { + assert.NoError(t, err) + } + + assert.Equal(t, tc.want, result) + }) + } +} + +func TestClientIdentificationPDU_Serialize(t *testing.T) { + tests := []struct { + name string + pdu ClientIdentificationPDU + want string + }{ + { + name: "Serialize Valid PDU", + pdu: ClientIdentificationPDU{ + From: "CLIENT", + To: "SERVER", + ClientID: 0x1234, + ClientName: "ClientName", + MajorVersion: 1, + MinorVersion: 2, + CID: 1234567, + SysUID: 12345678, + InitialChallenge: "abcd1234", + }, + want: "$IDCLIENT:SERVER:1234:ClientName:1:2:1234567:12345678:abcd1234\r\n", + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + result := tc.pdu.Serialize() + assert.Equal(t, tc.want, result) + }) + } +} diff --git a/protocol/client_query.go b/protocol/client_query.go new file mode 100644 index 0000000..2ad94b5 --- /dev/null +++ b/protocol/client_query.go @@ -0,0 +1,64 @@ +package protocol + +import ( + "fmt" + "strings" +) + +type ClientQueryPDU struct { + From string `validate:"required,alphanum,max=7"` + To string `validate:"required,max=7"` + QueryType string `validate:"required,ascii,min=2,max=7"` + Payload string `validate:""` +} + +func (p *ClientQueryPDU) Serialize() string { + if p.Payload == "" { + return fmt.Sprintf("$CQ%s:%s:%s%s", p.From, p.To, p.QueryType, PacketDelimeter) + } else { + return fmt.Sprintf("$CQ%s:%s:%s:%s%s", p.From, p.To, p.QueryType, p.Payload, PacketDelimeter) + } +} + +func ParseClientQueryPDU(rawPacket string) (*ClientQueryPDU, error) { + rawPacket = strings.TrimSuffix(rawPacket, PacketDelimeter) + rawPacket = strings.TrimPrefix(rawPacket, "$CQ") + fields := strings.SplitN(rawPacket, Delimeter, 4) + if len(fields) < 3 { + return nil, NewGenericFSDError(SyntaxError) + } + + payload := "" + if len(fields) == 4 { + payload = fields[3] + } + + pdu := ClientQueryPDU{ + From: fields[0], + To: fields[1], + QueryType: fields[2], + Payload: payload, + } + + err := V.Struct(pdu) + if err != nil { + return nil, NewGenericFSDError(SyntaxError) + } + + switch pdu.QueryType { + case "ATC", "CAPS", "C?", + "RN", "SV", "ATIS", + "IP", "INF", "FP", + "IPC", "BY", "HI", + "HLP", "NOHLP", "WH", + "IT", "HT", "DR", + "FA", "TA", "BC", + "SC", "VT", "ACC", + "NEWINFO", "NEWATIS", "EST", + "GD": + default: + return nil, NewGenericFSDError(SyntaxError) + } + + return &pdu, nil +} diff --git a/protocol/client_query_response.go b/protocol/client_query_response.go new file mode 100644 index 0000000..4191cb6 --- /dev/null +++ b/protocol/client_query_response.go @@ -0,0 +1,55 @@ +package protocol + +import ( + "fmt" + "strings" +) + +type ClientQueryResponsePDU struct { + From string `validate:"required,alphanum,max=7"` + To string `validate:"required,max=7"` + QueryType string `validate:"required,ascii,min=2,max=7"` + Payload string `validate:""` +} + +func (p *ClientQueryResponsePDU) Serialize() string { + if p.Payload == "" { + return fmt.Sprintf("$CR%s:%s:%s%s", p.From, p.To, p.QueryType, PacketDelimeter) + } else { + return fmt.Sprintf("$CR%s:%s:%s:%s%s", p.From, p.To, p.QueryType, p.Payload, PacketDelimeter) + } +} + +func ParseClientQueryResponsePDU(rawPacket string) (*ClientQueryResponsePDU, error) { + rawPacket = strings.TrimSuffix(rawPacket, PacketDelimeter) + rawPacket = strings.TrimPrefix(rawPacket, "$CR") + fields := strings.SplitN(rawPacket, Delimeter, 4) + if len(fields) < 3 { + return nil, NewGenericFSDError(SyntaxError) + } + + payload := "" + if len(fields) == 4 { + payload = fields[3] + } + + pdu := ClientQueryResponsePDU{ + From: fields[0], + To: fields[1], + QueryType: fields[2], + Payload: payload, + } + + err := V.Struct(pdu) + if err != nil { + return nil, NewGenericFSDError(SyntaxError) + } + + switch pdu.QueryType { + case "ATC", "CAPS", "C?", "RN", "SV", "ATIS", "IP", "INF", "FP", "IPC", "BY", "HI", "HLP", "NOHLP", "WH", "IT", "HT", "DR", "FA", "TA", "BC", "SC", "VT", "ACC", "NEWINFO", "NEWATIS", "EST", "GD": + default: + return nil, NewGenericFSDError(SyntaxError) + } + + return &pdu, nil +} diff --git a/protocol/client_query_response_test.go b/protocol/client_query_response_test.go new file mode 100644 index 0000000..90afbe5 --- /dev/null +++ b/protocol/client_query_response_test.go @@ -0,0 +1,157 @@ +package protocol + +import ( + "github.com/go-playground/validator/v10" + "github.com/stretchr/testify/assert" + "testing" +) + +func TestParseClientQueryResponsePDU(t *testing.T) { + V = validator.New(validator.WithRequiredStructEnabled()) + + tests := []struct { + name string + packet string + want *ClientQueryResponsePDU + wantErr error + }{ + { + "Valid with Payload", + "$CRFROM:TO:FP:P12345\r\n", + &ClientQueryResponsePDU{ + From: "FROM", + To: "TO", + QueryType: "FP", + Payload: "P12345", + }, + nil, + }, + { + "Valid without Payload", + "$CRFROM:TO:WH\r\n", + &ClientQueryResponsePDU{ + From: "FROM", + To: "TO", + QueryType: "WH", + Payload: "", + }, + nil, + }, + { + "Invalid QueryType", + "$CRFROM:TO:XYZ\r\n", + nil, + NewGenericFSDError(SyntaxError), + }, + { + "From field too long", + "$CRFROMFROM:TO:WH\r\n", + nil, + NewGenericFSDError(SyntaxError), + }, + { + "Missing QueryType", + "$CRFROM:TO:\r\n", + nil, + NewGenericFSDError(SyntaxError), + }, + { + "Empty packet", + "\r\n", + nil, + NewGenericFSDError(SyntaxError), + }, + { + "Missing From field", + "$CR:TO:WH\r\n", + nil, + NewGenericFSDError(SyntaxError), + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + // Perform the parsing + result, err := ParseClientQueryResponsePDU(tc.packet) + + // Check the error + if tc.wantErr != nil { + assert.EqualError(t, err, tc.wantErr.Error()) + } else { + assert.NoError(t, err) + } + + // Verify the result + assert.Equal(t, tc.want, result) + }) + } +} + +func TestClientQueryResponsePDU_Serialize(t *testing.T) { + tests := []struct { + name string + pdu *ClientQueryResponsePDU + want string + }{ + { + name: "Serialize with payload", + pdu: &ClientQueryResponsePDU{ + From: "CLIENT", + To: "SERVER", + QueryType: "IP", + Payload: "192.0.2.1", + }, + want: "$CRCLIENT:SERVER:IP:192.0.2.1\r\n", + }, + { + name: "Serialize without payload", + pdu: &ClientQueryResponsePDU{ + From: "CLIENT", + To: "SERVER", + QueryType: "IP", + Payload: "", + }, + want: "$CRCLIENT:SERVER:IP\r\n", + }, + { + name: "Serialize with minimum query type length", + pdu: &ClientQueryResponsePDU{ + From: "CLIENT", + To: "SERVER", + QueryType: "C?", + Payload: "Question", + }, + want: "$CRCLIENT:SERVER:C?:Question\r\n", + }, + { + name: "Serialize with maximum query type length", + pdu: &ClientQueryResponsePDU{ + From: "CLIENT", + To: "SERVER", + QueryType: "NEWINFO", + Payload: "Updates", + }, + want: "$CRCLIENT:SERVER:NEWINFO:Updates\r\n", + }, + { + name: "Serialize with non-alphanumeric characters in payload", + pdu: &ClientQueryResponsePDU{ + From: "CLIENT", + To: "SERVER", + QueryType: "FP", + Payload: "FPL-ABCDE-IS -KJFK0740-N0460F330 BETTE4 BETTE ACK J174 LICKS N76B MICYY SH9", + }, + want: "$CRCLIENT:SERVER:FP:FPL-ABCDE-IS -KJFK0740-N0460F330 BETTE4 BETTE ACK J174 LICKS N76B MICYY SH9\r\n", + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + // Perform serialization + result := tc.pdu.Serialize() + + // Verify the result + assert.Equal(t, tc.want, result) + }) + } +} diff --git a/protocol/client_query_test.go b/protocol/client_query_test.go new file mode 100644 index 0000000..3e4045a --- /dev/null +++ b/protocol/client_query_test.go @@ -0,0 +1,157 @@ +package protocol + +import ( + "github.com/go-playground/validator/v10" + "github.com/stretchr/testify/assert" + "testing" +) + +func TestParseClientQueryPDU(t *testing.T) { + V = validator.New(validator.WithRequiredStructEnabled()) + + tests := []struct { + name string + packet string + want *ClientQueryPDU + wantErr error + }{ + { + "Valid with Payload", + "$CQFROM:TO:FP:P12345\r\n", + &ClientQueryPDU{ + From: "FROM", + To: "TO", + QueryType: "FP", + Payload: "P12345", + }, + nil, + }, + { + "Valid without Payload", + "$CQFROM:TO:WH\r\n", + &ClientQueryPDU{ + From: "FROM", + To: "TO", + QueryType: "WH", + Payload: "", + }, + nil, + }, + { + "Invalid QueryType", + "$CQFROM:TO:XYZ\r\n", + nil, + NewGenericFSDError(SyntaxError), + }, + { + "From field too long", + "$CQFROMFROM:TO:WH\r\n", + nil, + NewGenericFSDError(SyntaxError), + }, + { + "Missing QueryType", + "$CQFROM:TO:\r\n", + nil, + NewGenericFSDError(SyntaxError), + }, + { + "Empty packet", + "\r\n", + nil, + NewGenericFSDError(SyntaxError), + }, + { + "Missing From field", + "$CQ:TO:WH\r\n", + nil, + NewGenericFSDError(SyntaxError), + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + // Perform the parsing + result, err := ParseClientQueryPDU(tc.packet) + + // Check the error + if tc.wantErr != nil { + assert.EqualError(t, err, tc.wantErr.Error()) + } else { + assert.NoError(t, err) + } + + // Verify the result + assert.Equal(t, tc.want, result) + }) + } +} + +func TestClientQueryPDU_Serialize(t *testing.T) { + tests := []struct { + name string + pdu *ClientQueryPDU + want string + }{ + { + name: "Serialize with payload", + pdu: &ClientQueryPDU{ + From: "CLIENT", + To: "SERVER", + QueryType: "IP", + Payload: "192.0.2.1", + }, + want: "$CQCLIENT:SERVER:IP:192.0.2.1\r\n", + }, + { + name: "Serialize without payload", + pdu: &ClientQueryPDU{ + From: "CLIENT", + To: "SERVER", + QueryType: "IP", + Payload: "", + }, + want: "$CQCLIENT:SERVER:IP\r\n", + }, + { + name: "Serialize with minimum query type length", + pdu: &ClientQueryPDU{ + From: "CLIENT", + To: "SERVER", + QueryType: "C?", + Payload: "Question", + }, + want: "$CQCLIENT:SERVER:C?:Question\r\n", + }, + { + name: "Serialize with maximum query type length", + pdu: &ClientQueryPDU{ + From: "CLIENT", + To: "SERVER", + QueryType: "NEWINFO", + Payload: "Updates", + }, + want: "$CQCLIENT:SERVER:NEWINFO:Updates\r\n", + }, + { + name: "Serialize with non-alphanumeric characters in payload", + pdu: &ClientQueryPDU{ + From: "CLIENT", + To: "SERVER", + QueryType: "FP", + Payload: "FPL-ABCDE-IS -KJFK0740-N0460F330 BETTE4 BETTE ACK J174 LICKS N76B MICYY SH9", + }, + want: "$CQCLIENT:SERVER:FP:FPL-ABCDE-IS -KJFK0740-N0460F330 BETTE4 BETTE ACK J174 LICKS N76B MICYY SH9\r\n", + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + // Perform serialization + result := tc.pdu.Serialize() + + // Verify the result + assert.Equal(t, tc.want, result) + }) + } +} diff --git a/protocol/delete_pilot.go b/protocol/delete_pilot.go new file mode 100644 index 0000000..6a96e35 --- /dev/null +++ b/protocol/delete_pilot.go @@ -0,0 +1,43 @@ +package protocol + +import ( + "fmt" + "strconv" + "strings" +) + +type DeletePilotPDU struct { + From string `validate:"required,alphanum,max=7"` + CID int `validate:"required,min=100000,max=9999999"` +} + +func (p *DeletePilotPDU) Serialize() string { + return fmt.Sprintf("#DP%s:%d%s", p.From, p.CID, PacketDelimeter) +} + +func ParseDeletePilotPDU(rawPacket string) (*DeletePilotPDU, error) { + rawPacket = strings.TrimSuffix(rawPacket, PacketDelimeter) + rawPacket = strings.TrimPrefix(rawPacket, "#DP") + fields := strings.Split(rawPacket, Delimeter) + + if len(fields) != 2 { + return nil, NewGenericFSDError(SyntaxError) + } + + cid, err := strconv.Atoi(fields[1]) + if err != nil { + return nil, NewGenericFSDError(SyntaxError) + } + + pdu := DeletePilotPDU{ + From: fields[0], + CID: cid, + } + + err = V.Struct(pdu) + if err != nil { + return nil, NewGenericFSDError(SyntaxError) + } + + return &pdu, nil +} diff --git a/protocol/delete_pilot_test.go b/protocol/delete_pilot_test.go new file mode 100644 index 0000000..83d1be7 --- /dev/null +++ b/protocol/delete_pilot_test.go @@ -0,0 +1,104 @@ +package protocol + +import ( + "github.com/go-playground/validator/v10" + "github.com/stretchr/testify/assert" + "testing" +) + +func TestParseDeletePilotPDU(t *testing.T) { + V = validator.New(validator.WithRequiredStructEnabled()) + + tests := []struct { + name string + packet string + want *DeletePilotPDU + wantErr error + }{ + { + "Valid Packet", + "#DPCTRLLL:1234567\r\n", + &DeletePilotPDU{ + From: "CTRLLL", + CID: 1234567, + }, + nil, + }, + { + "Invalid From Field (Too Long)", + "#DPCONTROLLER1:1234567\r\n", + nil, + NewGenericFSDError(SyntaxError), + }, + { + "Invalid CID Field (Non-Numeric)", + "#DPCTRLLL:ABCDEF1\r\n", + nil, + NewGenericFSDError(SyntaxError), + }, + { + "Invalid CID Field (Wrong Length)", + "#DPCTRLLL:12345\r\n", + nil, + NewGenericFSDError(SyntaxError), + }, + { + "Extra Fields", + "#DPCTRLLL:1234567:ExtraData\r\n", + nil, + NewGenericFSDError(SyntaxError), + }, + { + "Missing CID", + "#DPCTRLLL:\r\n", + nil, + NewGenericFSDError(SyntaxError), + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + // Perform the parsing + result, err := ParseDeletePilotPDU(tc.packet) + + // Check the error + if tc.wantErr != nil { + assert.EqualError(t, err, tc.wantErr.Error()) + } else { + assert.NoError(t, err) + } + + // Verify the result + assert.Equal(t, tc.want, result) + }) + } +} + +func TestDeletePilotPDU_Serialize(t *testing.T) { + V = validator.New(validator.WithRequiredStructEnabled()) + + tests := []struct { + name string + pdu DeletePilotPDU + want string + }{ + { + name: "Valid Serialization", + pdu: DeletePilotPDU{ + From: "N123", + CID: 7654321, + }, + want: "#DPN123:7654321\r\n", + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + // Perform serialization + result := tc.pdu.Serialize() + + // Verify the result + assert.Equal(t, tc.want, result) + }) + } +} diff --git a/protocol/fast_pilot_position.go b/protocol/fast_pilot_position.go new file mode 100644 index 0000000..9be070b --- /dev/null +++ b/protocol/fast_pilot_position.go @@ -0,0 +1,178 @@ +package protocol + +import ( + "fmt" + "strconv" + "strings" +) + +const ( + FastPilotPositionTypeFast = iota + FastPilotPositionTypeSlow + FastPilotPositionTypeStopped +) + +type VelocityVector struct { + X float64 `validate:"min=-9999.0,max=9999.0"` + Y float64 `validate:"min=-9999.0,max=9999.0"` + Z float64 `validate:"min=-9999.0,max=9999.0"` +} + +type FastPilotPositionPDU struct { + Type int `validate:"min=0,max=2"` + From string `validate:"required,alphanum,max=7"` + Lat float64 `validate:"min=-90.0,max=90.0"` + Lng float64 `validate:"min=-180.0,max=180.0"` + AltitudeTrue float64 `validate:"min=-1500.0,max=99999.0"` + AltitudeAgl float64 `validate:"min=-1500.0,max=99999.0"` + Pitch float64 `validate:"min=-360.0,max=360.0"` + Heading float64 `validate:"min=-360.0,max=360.0"` + Bank float64 `validate:"min=-360.0,max=360.0"` + PositionalVelocityVector VelocityVector `validate:""` // meters per second + RotationalVelocityVector VelocityVector `validate:""` // radians per second + NoseGearAngle float64 `validate:"min=-360.0,max=360.0"` +} + +func (p *FastPilotPositionPDU) Serialize() string { + switch p.Type { + case FastPilotPositionTypeFast: + return fmt.Sprintf("^%s:%.6f:%.6f:%.2f:%.2f:%d:%.4f:%.4f:%.4f:%.4f:%.4f:%.4f:%.2f%s", p.From, p.Lat, p.Lng, p.AltitudeTrue, p.AltitudeAgl, packPitchBankHeading(p.Pitch, p.Bank, p.Heading), p.PositionalVelocityVector.X, p.PositionalVelocityVector.Y, p.PositionalVelocityVector.Z, p.RotationalVelocityVector.X, p.RotationalVelocityVector.Y, p.RotationalVelocityVector.Z, p.NoseGearAngle, PacketDelimeter) + case FastPilotPositionTypeSlow: + return fmt.Sprintf("#SL%s:%.6f:%.6f:%.2f:%.2f:%d:%.4f:%.4f:%.4f:%.4f:%.4f:%.4f:%.2f%s", p.From, p.Lat, p.Lng, p.AltitudeTrue, p.AltitudeAgl, packPitchBankHeading(p.Pitch, p.Bank, p.Heading), p.PositionalVelocityVector.X, p.PositionalVelocityVector.Y, p.PositionalVelocityVector.Z, p.RotationalVelocityVector.X, p.RotationalVelocityVector.Y, p.RotationalVelocityVector.Z, p.NoseGearAngle, PacketDelimeter) + default: // FastPilotPositionTypeStopped + return fmt.Sprintf("#ST%s:%.6f:%.6f:%.2f:%.2f:%d:%.2f%s", p.From, p.Lat, p.Lng, p.AltitudeTrue, p.AltitudeAgl, packPitchBankHeading(p.Pitch, p.Bank, p.Heading), p.NoseGearAngle, PacketDelimeter) + } +} + +func ParseFastPilotPositionPDU(rawPacket string) (*FastPilotPositionPDU, error) { + rawPacket = strings.TrimSuffix(rawPacket, PacketDelimeter) + + var pduType int + if strings.HasPrefix(rawPacket, "^") { + pduType = FastPilotPositionTypeFast + rawPacket = strings.TrimPrefix(rawPacket, "^") + } else if strings.HasPrefix(rawPacket, "#SL") { + pduType = FastPilotPositionTypeSlow + rawPacket = strings.TrimPrefix(rawPacket, "#SL") + } else if strings.HasPrefix(rawPacket, "#ST") { + pduType = FastPilotPositionTypeStopped + rawPacket = strings.TrimPrefix(rawPacket, "#ST") + } else { + return nil, NewGenericFSDError(SyntaxError) + } + + fields := strings.Split(rawPacket, Delimeter) + + if pduType == FastPilotPositionTypeStopped { + if len(fields) != 7 { + return nil, NewGenericFSDError(SyntaxError) + } + } else { + if len(fields) != 13 { + return nil, NewGenericFSDError(SyntaxError) + } + } + + lat, err := strconv.ParseFloat(fields[1], 64) + if err != nil { + return nil, NewGenericFSDError(SyntaxError) + } + + lng, err := strconv.ParseFloat(fields[2], 64) + if err != nil { + return nil, NewGenericFSDError(SyntaxError) + } + + altTrue, err := strconv.ParseFloat(fields[3], 64) + if err != nil { + return nil, NewGenericFSDError(SyntaxError) + } + + altAgl, err := strconv.ParseFloat(fields[4], 64) + if err != nil { + return nil, NewGenericFSDError(SyntaxError) + } + + pbh, err := strconv.ParseUint(fields[5], 10, 32) + if err != nil { + return nil, NewGenericFSDError(SyntaxError) + } + pitch, bank, heading := unpackPitchBankHeading(uint32(pbh)) + + var positionalVector VelocityVector + var rotationalVector VelocityVector + var noseGearAngle float64 + if pduType == FastPilotPositionTypeStopped { + positionalVector = VelocityVector{ + X: 0, + Y: 0, + Z: 0, + } + rotationalVector = VelocityVector{ + X: 0, + Y: 0, + Z: 0, + } + noseGearAngle, err = strconv.ParseFloat(fields[6], 64) + if err != nil { + return nil, NewGenericFSDError(SyntaxError) + } + } else { + positionalVector.X, err = strconv.ParseFloat(fields[6], 64) + if err != nil { + return nil, NewGenericFSDError(SyntaxError) + } + + positionalVector.Y, err = strconv.ParseFloat(fields[7], 64) + if err != nil { + return nil, NewGenericFSDError(SyntaxError) + } + + positionalVector.Z, err = strconv.ParseFloat(fields[8], 64) + if err != nil { + return nil, NewGenericFSDError(SyntaxError) + } + + rotationalVector.X, err = strconv.ParseFloat(fields[9], 64) + if err != nil { + return nil, NewGenericFSDError(SyntaxError) + } + + rotationalVector.Y, err = strconv.ParseFloat(fields[10], 64) + if err != nil { + return nil, NewGenericFSDError(SyntaxError) + } + + rotationalVector.Z, err = strconv.ParseFloat(fields[11], 64) + if err != nil { + return nil, NewGenericFSDError(SyntaxError) + } + + noseGearAngle, err = strconv.ParseFloat(fields[12], 64) + if err != nil { + return nil, NewGenericFSDError(SyntaxError) + } + } + + pdu := FastPilotPositionPDU{ + Type: pduType, + From: fields[0], + Lat: lat, + Lng: lng, + AltitudeTrue: altTrue, + AltitudeAgl: altAgl, + Pitch: pitch, + Heading: heading, + Bank: bank, + PositionalVelocityVector: positionalVector, + RotationalVelocityVector: rotationalVector, + NoseGearAngle: noseGearAngle, + } + + err = V.Struct(pdu) + if err != nil { + return nil, NewGenericFSDError(SyntaxError) + } + + return &pdu, nil +} diff --git a/protocol/fast_pilot_position_test.go b/protocol/fast_pilot_position_test.go new file mode 100644 index 0000000..e99c07b --- /dev/null +++ b/protocol/fast_pilot_position_test.go @@ -0,0 +1,256 @@ +package protocol + +import ( + "github.com/go-playground/validator/v10" + "github.com/stretchr/testify/assert" + "testing" +) + +func TestParseFastPilotPositionPDU(t *testing.T) { + V = validator.New(validator.WithRequiredStructEnabled()) + + tests := []struct { + name string + packet string + want *FastPilotPositionPDU + wantErr error + }{ + { + "Valid Fast Type", + "^PILOT:12.345678:98.765432:300.00:50.00:4177408112:123.4567:345.6789:-234.5678:111.2222:-333.4444:555.6666:90.00\r\n", + &FastPilotPositionPDU{ + Type: FastPilotPositionTypeFast, + From: "PILOT", + Lat: 12.345678, + Lng: 98.765432, + AltitudeTrue: 300, + AltitudeAgl: 50, + Pitch: 10, // Assuming appropriate conversion function + Heading: 10, + Bank: 10, + PositionalVelocityVector: VelocityVector{ + X: 123.4567, + Y: 345.6789, + Z: -234.5678, + }, + RotationalVelocityVector: VelocityVector{ + X: 111.2222, + Y: -333.4444, + Z: 555.6666, + }, + NoseGearAngle: 90.00, + }, + nil, + }, + { + "Valid Slow Type", + "#SLPILOT:12.345678:98.765432:300.00:50.00:4177408112:123.4567:345.6789:-234.5678:111.2222:-333.4444:555.6666:90.00\r\n", + &FastPilotPositionPDU{ + Type: FastPilotPositionTypeSlow, + From: "PILOT", + Lat: 12.345678, + Lng: 98.765432, + AltitudeTrue: 300, + AltitudeAgl: 50, + Pitch: 10, // Assuming appropriate conversion function + Heading: 10, + Bank: 10, + PositionalVelocityVector: VelocityVector{ + X: 123.4567, + Y: 345.6789, + Z: -234.5678, + }, + RotationalVelocityVector: VelocityVector{ + X: 111.2222, + Y: -333.4444, + Z: 555.6666, + }, + NoseGearAngle: 90.00, + }, + nil, + }, + { + "Invalid Type", + "?PILOT:12.345678:98.765432:300.00:50.00:4177408112:123.4567:345.6789:-234.5678:111.2222:-333.4444:555.6666:90.00\r\n", + nil, + NewGenericFSDError(SyntaxError), + }, + { + "Out of range value", + "^PILOT:92.000000:678.765432:300.00:50.00:4177408112:123.4567:345.6789:-234.5678:111.2222:-333.4444:555.6666:90.00\r\n", + nil, + NewGenericFSDError(SyntaxError), + }, + { + "Missing field", + "^PILOT:12.345678:98.765432:300.00:50.00:4177408112:123.4567:345.6789:-234.5678:111.2222:-333.4444\r\n", + nil, + NewGenericFSDError(SyntaxError), + }, + { + "Mismatched type with fields", + "#STPILOT:12.345678:98.765432:300.00:50.00:4177408112:123.4567:345.6789:-234.5678:111.2222:-333.4444:555.6666:90.00\n\r\n", + nil, + NewGenericFSDError(SyntaxError), + }, + { + "Valid Stopped Type", + "#STPILOT:12.345678:98.765432:300.00:50.00:4177408112:0.00\r\n", + &FastPilotPositionPDU{ + Type: FastPilotPositionTypeStopped, + From: "PILOT", + Lat: 12.345678, + Lng: 98.765432, + AltitudeTrue: 300, + AltitudeAgl: 50, + Pitch: 10, + Heading: 10, + Bank: 10, + PositionalVelocityVector: VelocityVector{ + X: 0, + Y: 0, + Z: 0, + }, + RotationalVelocityVector: VelocityVector{ + X: 0, + Y: 0, + Z: 0, + }, + NoseGearAngle: 0.00, + }, + nil, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + // Perform the parsing + result, err := ParseFastPilotPositionPDU(tc.packet) + + // Check the error + if tc.wantErr != nil { + assert.EqualError(t, err, tc.wantErr.Error()) + } else { + assert.NoError(t, err) + } + + // Verify the result + if tc.want != nil { + assert.Equal(t, tc.want.From, result.From) + assert.InDelta(t, tc.want.Lat, result.Lat, 1e-6) + assert.InDelta(t, tc.want.Lng, result.Lng, 1e-6) + assert.InDelta(t, tc.want.AltitudeTrue, result.AltitudeTrue, 1e-2) + assert.InDelta(t, tc.want.AltitudeAgl, result.AltitudeAgl, 1e-2) + assert.InDelta(t, tc.want.Pitch, result.Pitch, 1) + assert.InDelta(t, tc.want.Bank, result.Bank, 1) + assert.InDelta(t, tc.want.Heading, result.Heading, 1) + assert.InDelta(t, tc.want.PositionalVelocityVector.X, result.PositionalVelocityVector.X, 1e-4) + assert.InDelta(t, tc.want.PositionalVelocityVector.Y, result.PositionalVelocityVector.Y, 1e-4) + assert.InDelta(t, tc.want.PositionalVelocityVector.Z, result.PositionalVelocityVector.Z, 1e-4) + assert.InDelta(t, tc.want.RotationalVelocityVector.X, result.RotationalVelocityVector.X, 1e-4) + assert.InDelta(t, tc.want.RotationalVelocityVector.Y, result.RotationalVelocityVector.Y, 1e-4) + assert.InDelta(t, tc.want.RotationalVelocityVector.Z, result.RotationalVelocityVector.Z, 1e-4) + assert.InDelta(t, tc.want.NoseGearAngle, result.NoseGearAngle, 1e-2) + } else { + assert.Nil(t, result) + } + }) + } +} + +func TestFastPilotPositionPDUSerialize(t *testing.T) { + tests := []struct { + name string + pdu *FastPilotPositionPDU + want string + }{ + { + "Serialize Fast", + &FastPilotPositionPDU{ + Type: FastPilotPositionTypeFast, + From: "PILOT", + Lat: 12.345678, + Lng: 98.765432, + AltitudeTrue: 300, + AltitudeAgl: 50, + Pitch: 10, + Heading: 10, + Bank: 10, + PositionalVelocityVector: VelocityVector{ + X: 123.4567, + Y: 345.6789, + Z: -234.5678, + }, + RotationalVelocityVector: VelocityVector{ + X: 111.2222, + Y: -333.4444, + Z: 555.6666, + }, + NoseGearAngle: 90.00, + }, + "^PILOT:12.345678:98.765432:300.00:50.00:4177408112:123.4567:345.6789:-234.5678:111.2222:-333.4444:555.6666:90.00\r\n", + }, + { + "Serialize Slow", + &FastPilotPositionPDU{ + Type: FastPilotPositionTypeSlow, + From: "PILOT", + Lat: 12.345678, + Lng: 98.765432, + AltitudeTrue: 300, + AltitudeAgl: 50, + Pitch: 10, + Heading: 10, + Bank: 10, + PositionalVelocityVector: VelocityVector{ + X: 123.4567, + Y: 345.6789, + Z: -234.5678, + }, + RotationalVelocityVector: VelocityVector{ + X: 111.2222, + Y: -333.4444, + Z: 555.6666, + }, + NoseGearAngle: 90.00, + }, + "#SLPILOT:12.345678:98.765432:300.00:50.00:4177408112:123.4567:345.6789:-234.5678:111.2222:-333.4444:555.6666:90.00\r\n", + }, + { + "Serialize Stopped", + &FastPilotPositionPDU{ + Type: FastPilotPositionTypeStopped, + From: "PILOT", + Lat: 12.345678, + Lng: 98.765432, + AltitudeTrue: 300, + AltitudeAgl: 50, + Pitch: 10, + Heading: 10, + Bank: 10, + PositionalVelocityVector: VelocityVector{ + X: 0, + Y: 0, + Z: 0, + }, + RotationalVelocityVector: VelocityVector{ + X: 0, + Y: 0, + Z: 0, + }, + NoseGearAngle: 90.00, + }, + "#STPILOT:12.345678:98.765432:300.00:50.00:4177408112:90.00\r\n", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Perform serialization + got := tt.pdu.Serialize() + + // Verify result + assert.Equal(t, tt.want, got) + }) + } +} diff --git a/protocol/fsd_error.go b/protocol/fsd_error.go new file mode 100644 index 0000000..031b780 --- /dev/null +++ b/protocol/fsd_error.go @@ -0,0 +1,101 @@ +package protocol + +import ( + "fmt" + "strconv" + "strings" +) + +// FSDError combines an FSD error with a golang error, so they can be used interchangeably + +type ErrorCode int + +const ( + OkError = iota + CallsignInUseError + CallsignInvalidError + AlreadyRegisteredError + SyntaxError + PDUSourceInvalidError + InvalidLogonError + NoSuchCallsignError + NoFlightPlanError + NoWeatherProfileError + InvalidProtocolRevisionError + RequestedLevelTooHighError + ServerFullError + CertificateSuspendedError + InvalidControlError + InvalidPositionForRatingError + UnauthorizedSoftwareError + ClientAuthenticationResponseTimeoutError +) + +var genericErrorMessage = map[ErrorCode]string{ + OkError: "OK", + CallsignInUseError: "Callsign already in use", + CallsignInvalidError: "Invalid callsign", + AlreadyRegisteredError: "Already registered", + SyntaxError: "Syntax error", + PDUSourceInvalidError: "PDU source invalid", + InvalidLogonError: "Invalid login information", + NoSuchCallsignError: "No such callsign", + NoFlightPlanError: "No flight plan filed", + NoWeatherProfileError: "No weather profile", + InvalidProtocolRevisionError: "Invalid protocol revision", + RequestedLevelTooHighError: "Requested level too high", + ServerFullError: "Server full", + CertificateSuspendedError: "Certificate suspended", + InvalidControlError: "Invalid control", + InvalidPositionForRatingError: "Invalid position for rating", + UnauthorizedSoftwareError: "Unauthorized software", + ClientAuthenticationResponseTimeoutError: "Client authentication response timeout", +} + +type FSDError struct { + From string + To string + Code ErrorCode + Param string + Message string +} + +func (e *FSDError) Error() string { + return genericErrorMessage[e.Code] +} + +func (e *FSDError) Serialize() string { + return fmt.Sprintf("$ER%s:%s:%03d:%s:%s%s", e.From, e.To, e.Code, e.Param, e.Message, PacketDelimeter) +} + +func ParseNetworkErrorPDU(rawPacket string) (*FSDError, error) { + rawPacket = strings.TrimSuffix(rawPacket, PacketDelimeter) + rawPacket = strings.TrimPrefix(rawPacket, "$ER") + fields := strings.SplitN(rawPacket, Delimeter, 5) + if len(fields) < 5 { + return nil, NewGenericFSDError(SyntaxError) + } + + errCode, err := strconv.Atoi(fields[2]) + if err != nil { + return nil, NewGenericFSDError(SyntaxError) + } + + return &FSDError{ + From: fields[0], + To: fields[1], + Code: ErrorCode(errCode), + Param: fields[3], + Message: fields[4], + }, nil +} + +func NewGenericFSDError(code ErrorCode) *FSDError { + return &FSDError{ + From: "server", + To: "unknown", + Code: code, + Param: "", + Message: genericErrorMessage[code], + } +} diff --git a/protocol/fsd_error_test.go b/protocol/fsd_error_test.go new file mode 100644 index 0000000..f93c718 --- /dev/null +++ b/protocol/fsd_error_test.go @@ -0,0 +1,67 @@ +package protocol + +import ( + "github.com/stretchr/testify/assert" + "testing" +) + +func TestParseNetworkErrorPDU(t *testing.T) { + + // Normal error test + { + p := "$ERN123:SERVER:0::this is a test error packet\r\n" + pdu, err := ParseNetworkErrorPDU(p) + assert.Nil(t, err) + assert.NotNil(t, pdu) + assert.Equal(t, "N123", pdu.From) + assert.Equal(t, "SERVER", pdu.To) + assert.Equal(t, ErrorCode(0), pdu.Code) + assert.Equal(t, "", pdu.Param) + assert.Equal(t, "this is a test error packet", pdu.Message) + } + + // Test with several colons in message + { + p := "$ERN123:SERVER:0::here:are:a:lot:of:colons\r\n" + pdu, err := ParseNetworkErrorPDU(p) + assert.Nil(t, err) + assert.NotNil(t, pdu) + assert.Equal(t, "N123", pdu.From) + assert.Equal(t, "SERVER", pdu.To) + assert.Equal(t, ErrorCode(0), pdu.Code) + assert.Equal(t, "", pdu.Param) + assert.Equal(t, "here:are:a:lot:of:colons", pdu.Message) + } + + // Missing field + { + p := "$ERN123:SERVER::here:are:a:lot:of:colons\r\n" + pdu, err := ParseNetworkErrorPDU(p) + assert.NotNil(t, err) + assert.IsType(t, &FSDError{}, err) + assert.Nil(t, pdu) + } + + // Non-integer error code + { + p := "$ERN123:SERVER:this-is-clearly-not-a-number::this is a test error packet\r\n" + pdu, err := ParseNetworkErrorPDU(p) + assert.NotNil(t, err) + assert.IsType(t, &FSDError{}, err) + assert.Nil(t, pdu) + } +} + +func TestFSDError_Serialize(t *testing.T) { + { + pdu := FSDError{ + From: "N123", + To: "SERVER", + Code: 5, + Param: "", + Message: "this is a test error packet", + } + s := pdu.Serialize() + assert.Equal(t, "$ERN123:SERVER:005::this is a test error packet\r\n", s) + } +} diff --git a/protocol/kill_request.go b/protocol/kill_request.go new file mode 100644 index 0000000..c139b2f --- /dev/null +++ b/protocol/kill_request.go @@ -0,0 +1,50 @@ +package protocol + +import ( + "fmt" + "strings" +) + +type KillRequestPDU struct { + From string `validate:"required,alphanum,max=7"` + To string `validate:"required,alphanum,max=7"` + Reason string `validate:"max=256"` +} + +func (p *KillRequestPDU) Serialize() string { + if p.Reason == "" { + return fmt.Sprintf("$!!%s:%s%s", p.From, p.To, PacketDelimeter) + } else { + return fmt.Sprintf("$!!%s:%s:%s%s", p.From, p.To, p.Reason, PacketDelimeter) + } +} + +func ParseKillRequestPDU(rawPacket string) (*KillRequestPDU, error) { + rawPacket = strings.TrimSuffix(rawPacket, PacketDelimeter) + rawPacket = strings.TrimPrefix(rawPacket, "$!!") + fields := strings.SplitN(rawPacket, Delimeter, 3) + + if len(fields) < 2 { + return nil, NewGenericFSDError(SyntaxError) + } + + var reason string + if len(fields) == 3 { + reason = fields[2] + } else { + reason = "" + } + + pdu := KillRequestPDU{ + From: fields[0], + To: fields[1], + Reason: reason, + } + + err := V.Struct(pdu) + if err != nil { + return nil, NewGenericFSDError(SyntaxError) + } + + return &pdu, nil +} diff --git a/protocol/kill_request_test.go b/protocol/kill_request_test.go new file mode 100644 index 0000000..762ea23 --- /dev/null +++ b/protocol/kill_request_test.go @@ -0,0 +1,111 @@ +package protocol + +import ( + "github.com/go-playground/validator/v10" + "github.com/stretchr/testify/assert" + "testing" +) + +func TestParseKillRequestPDU(t *testing.T) { + V = validator.New() + + tests := []struct { + name string + packet string + want *KillRequestPDU + wantErr error + }{ + { + name: "Valid Reason", + packet: "$!!JOHN:DOE:you're banned: reason\r\n", + want: &KillRequestPDU{ + From: "JOHN", + To: "DOE", + Reason: "you're banned: reason", + }, + wantErr: nil, + }, + { + name: "Valid with no Reason", + packet: "$!!JOHN:DOE\r\n", + want: &KillRequestPDU{ + From: "JOHN", + To: "DOE", + Reason: "", + }, + wantErr: nil, + }, + { + name: "Invalid from field", + packet: "$!!JOHN99999:DOE:you're banned: reason\r\n", + want: nil, + wantErr: NewGenericFSDError(SyntaxError), + }, + { + name: "Invalid to field", + packet: "$!!JOHN:DOE1234567:you're banned: reason\r\n", + want: nil, + wantErr: NewGenericFSDError(SyntaxError), + }, + { + name: "Missing to field", + packet: "$!!JOHN::Hello, world!\r\n", + want: nil, + wantErr: NewGenericFSDError(SyntaxError), + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + // Perform the parsing + result, err := ParseKillRequestPDU(tc.packet) + + // Check the error + if tc.wantErr != nil { + assert.EqualError(t, err, tc.wantErr.Error()) + } else { + assert.NoError(t, err) + } + + // Verify the result + assert.Equal(t, tc.want, result) + }) + } +} + +func TestKillRequestPDU_Serialize(t *testing.T) { + tests := []struct { + name string + textPDU KillRequestPDU + wantOutput string + }{ + { + name: "Valid", + textPDU: KillRequestPDU{ + From: "ALPHA1", + To: "BRAVO2", + Reason: "Hello, this is a test Reason.", + }, + wantOutput: "$!!ALPHA1:BRAVO2:Hello, this is a test Reason.\r\n", + }, + { + name: "Valid without Reason", + textPDU: KillRequestPDU{ + From: "ALPHA1", + To: "BRAVO2", + Reason: "", + }, + wantOutput: "$!!ALPHA1:BRAVO2\r\n", + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + // Perform serialization + output := tc.textPDU.Serialize() + + // Verify the output + assert.Equal(t, tc.wantOutput, output) + }) + } +} diff --git a/protocol/pdu.go b/protocol/pdu.go new file mode 100644 index 0000000..ad2204e --- /dev/null +++ b/protocol/pdu.go @@ -0,0 +1,103 @@ +package protocol + +import "github.com/go-playground/validator/v10" + +var V *validator.Validate + +// Misc FSD protocol values +const ( + ClientQueryBroadcastRecipient = "@94835" + ClientQueryBroadcastRecipientPilots = "@94386" + Delimeter = ":" + PacketDelimeter = "\r\n" + ServerCallsign = "SERVER" +) + +const ( + NetworkFacilityOBS = iota + NetworkFacilityFSS + NetworkFacilityDEL + NetworkFacilityGND + NetworkFacilityTWR + NetworkFacilityAPP + NetworkFacilityCTR +) + +const ( + NetworkRatingUnknown = iota + NetworkRatingOBS + NetworkRatingS1 + NetworkRatingS2 + NetworkRatingS3 + NetworkRatingC1 + NetworkRatingC2 + NetworkRatingC3 + NetworkRatingI1 + NetworkRatingI2 + NetworkRatingI3 + NetworkRatingSUP + NetworkRatingADM +) + +const ( + SimulatorTypeUnknown = iota + SimulatorTypeMSFS95 + SimulatorTypeMSFS98 + SimulatorTypeMSCFS + SimulatorTypeMSFS2000 + SimulatorTypeMSCFS2 + SimulatorTypeMSFS2002 + SimulatorTypeMSCFS3 + SimulatorTypeMSFS2004 + SimulatorTypeMSFSX + SimulatorTypeXPlane8 = iota + 2 + SimulatorTypeXPlane9 + SimulatorTypeXPlane10 + SimulatorTypeXPlane11 = iota + 3 + SimulatorTypeFlightGear = iota + 11 + SimulatorTypeP3D = iota + 15 +) + +const ( + ProtoRevisionUnknown = 0 + ProtoRevisionClassic = 9 + ProtoRevisionVatsimNoAuth = 10 + ProtoRevisionVatsimAuth = 100 + ProtoRevisionVatsim2022 = 101 +) + +const ( + ClientQueryUnknown = "" + ClientQueryIsValidATC = "ATC" + ClientQueryCapabilities = "CAPS" + ClientQueryCOM1Freq = "C?" + ClientQueryRealName = "RN" + ClientQueryServer = "SV" + ClientQueryATIS = "ATIS" + ClientQueryPublicIP = "IP" + ClientQueryINF = "INF" + ClientQueryFlightPlan = "FP" + ClientQueryIPC = "IPC" + ClientQueryRequestRelief = "BY" + ClientQueryCancelRequestRelief = "HI" + ClientQueryRequestHelp = "HLP" + ClientQueryCancelRequestHelp = "NOHLP" + ClientQueryWhoHas = "WH" + ClientQueryInitiateTrack = "IT" + ClientQueryAcceptHandoff = "HT" + ClientQueryDropTrack = "DR" + ClientQuerySetFinalAltitude = "FA" + ClientQuerySetTempAltitude = "TA" + ClientQuerySetBeaconCode = "BC" + ClientQuerySetScratchpad = "SC" + ClientQuerySetVoiceType = "VT" + ClientQueryAircraftConfiguration = "ACC" + ClientQueryNewInfo = "NEWINFO" + ClientQueryNewATIS = "NEWATIS" + ClientQueryEstimate = "EST" + ClientQuerySetGlobalData = "GD" +) + +type PDU interface { + Serialize() string +} diff --git a/protocol/pilot_position.go b/protocol/pilot_position.go new file mode 100644 index 0000000..828d3d3 --- /dev/null +++ b/protocol/pilot_position.go @@ -0,0 +1,169 @@ +package protocol + +import ( + "fmt" + "strconv" + "strings" +) + +type PilotPositionPDU struct { + SquawkingModeC bool `validate:""` + Identing bool `validate:""` + From string `validate:"required,alphanum,max=7"` + SquawkCode string `validate:"len=4"` + NetworkRating int `validate:"required,min=1,max=12"` + Lat float64 `validate:"min=-90.0,max=90.0"` + Lng float64 `validate:"min=-180.0,max=180.0"` + TrueAltitude int `validate:"min=-1500,max=99999"` + PressureAltitude int `validate:"min=-1500,max=99999"` + GroundSpeed int `validate:"min=0,max=9999"` + Pitch float64 `validate:"min=-360.0,max=360.0"` + Heading float64 `validate:"min=-360.0,max=360.0"` + Bank float64 `validate:"min=-360.0,max=360.0"` +} + +func packPitchBankHeading(pitch, bank, heading float64) uint32 { + p := pitch / -360.0 + if p < 0 { + p += 1.0 + } + p *= 1024.0 + + b := bank / -360.0 + if b < 0 { + b += 1.0 + } + b *= 1024.0 + + h := heading / 360.0 * 1024.0 + + return uint32(p)<<22 | uint32(b)<<12 | uint32(h)<<2 +} + +func unpackPitchBankHeading(pbh uint32) (pitch, bank, heading float64) { + pitchInt := pbh >> 22 + pitch = float64(pitchInt) / 1024.0 * -360.0 + if pitch > 180.0 { + pitch -= 360.0 + } else if pitch <= -180.0 { + pitch += 360.0 + } + + bankInt := (pbh >> 12) & 0x3FF + bank = float64(bankInt) / 1024.0 * -360.0 + if bank > 180.0 { + bank -= 360.0 + } else if bank <= -180.0 { + bank += 360.0 + } + + headingInt := (pbh >> 2) & 0x3FF + heading = float64(headingInt) / 1024.0 * 360.0 + if heading < 0.0 { + heading += 360.0 + } else if heading >= 360.0 { + heading -= 360.0 + } + + return +} + +func (p *PilotPositionPDU) Serialize() string { + xpdrStateStr := "S" + if p.Identing { + xpdrStateStr = "Y" + } else if p.SquawkingModeC { + xpdrStateStr = "N" + } + + return fmt.Sprintf("@%s:%s:%s:%d:%.6f:%.6f:%d:%d:%d:%d%s", xpdrStateStr, p.From, p.SquawkCode, p.NetworkRating, p.Lat, p.Lng, p.TrueAltitude, p.GroundSpeed, packPitchBankHeading(p.Pitch, p.Bank, p.Heading), p.PressureAltitude-p.TrueAltitude, PacketDelimeter) +} + +func ParsePilotPositionPDU(rawPacket string) (*PilotPositionPDU, error) { + rawPacket = strings.TrimSuffix(rawPacket, PacketDelimeter) + rawPacket = strings.TrimPrefix(rawPacket, "@") + fields := strings.Split(rawPacket, Delimeter) + if len(fields) != 10 { + return nil, NewGenericFSDError(SyntaxError) + } + + networkRating, err := strconv.Atoi(fields[3]) + if err != nil { + return nil, NewGenericFSDError(SyntaxError) + } + + lat, err := strconv.ParseFloat(fields[4], 64) + if err != nil { + return nil, NewGenericFSDError(SyntaxError) + } + + lng, err := strconv.ParseFloat(fields[5], 64) + if err != nil { + return nil, NewGenericFSDError(SyntaxError) + } + + trueAlt, err := strconv.Atoi(fields[6]) + if err != nil { + return nil, NewGenericFSDError(SyntaxError) + } + + groundspeed, err := strconv.Atoi(fields[7]) + if err != nil { + return nil, NewGenericFSDError(SyntaxError) + } + + pbh64, err := strconv.ParseUint(fields[8], 10, 32) + if err != nil { + return nil, NewGenericFSDError(SyntaxError) + } + pbh := uint32(pbh64) + + pitch, bank, heading := unpackPitchBankHeading(pbh) + + pressureAlt, err := strconv.Atoi(fields[9]) + if err != nil { + return nil, NewGenericFSDError(SyntaxError) + } + pressureAlt += trueAlt + + identing := false + modeC := false + if fields[0] == "N" { + modeC = true + } else if fields[0] == "Y" { + modeC = true + identing = true + } else if fields[0] != "S" { + return nil, NewGenericFSDError(SyntaxError) + } + + pdu := PilotPositionPDU{ + SquawkingModeC: modeC, + Identing: identing, + From: fields[1], + SquawkCode: fields[2], + NetworkRating: networkRating, + Lat: lat, + Lng: lng, + TrueAltitude: trueAlt, + PressureAltitude: pressureAlt, + GroundSpeed: groundspeed, + Pitch: pitch, + Heading: bank, + Bank: heading, + } + + err = V.Struct(pdu) + if err != nil { + return nil, NewGenericFSDError(SyntaxError) + } + + // Check transponder code validity + for _, char := range pdu.SquawkCode { + if char < '0' || char > '7' { + return nil, NewGenericFSDError(SyntaxError) + } + } + + return &pdu, nil +} diff --git a/protocol/pilot_position_test.go b/protocol/pilot_position_test.go new file mode 100644 index 0000000..7a1252d --- /dev/null +++ b/protocol/pilot_position_test.go @@ -0,0 +1,286 @@ +package protocol + +import ( + "github.com/go-playground/validator/v10" + "github.com/stretchr/testify/assert" + "testing" +) + +func TestParsePilotPositionPDU(t *testing.T) { + V = validator.New(validator.WithRequiredStructEnabled()) + + tests := []struct { + name string + packet string + want *PilotPositionPDU + wantErr error + }{ + { + "Valid (squawk standby)", + "@S:N123:1200:1:40.6452667:-73.7738611:16:0:4177408112:336\r\n", + &PilotPositionPDU{ + SquawkingModeC: false, + Identing: false, + From: "N123", + SquawkCode: "1200", + NetworkRating: 1, + Lat: 40.6452667, + Lng: -73.7738611, + TrueAltitude: 16, + GroundSpeed: 0, + Pitch: 10, + Bank: 10, + Heading: 10, + PressureAltitude: 352, + }, + nil, + }, + { + "Valid (squawk normal)", + "@N:N123:1200:1:40.6452667:-73.7738611:16:0:4177408112:336\r\n", + &PilotPositionPDU{ + SquawkingModeC: true, + Identing: false, + From: "N123", + SquawkCode: "1200", + NetworkRating: 1, + Lat: 40.6452667, + Lng: -73.7738611, + TrueAltitude: 16, + GroundSpeed: 0, + Pitch: 10, + Bank: 10, + Heading: 10, + PressureAltitude: 352, + }, + nil, + }, + { + "Valid (squawk ident)", + "@Y:N123:1200:1:40.6452667:-73.7738611:16:0:4177408112:336\r\n", + &PilotPositionPDU{ + SquawkingModeC: true, + Identing: true, + From: "N123", + SquawkCode: "1200", + NetworkRating: 1, + Lat: 40.6452667, + Lng: -73.7738611, + TrueAltitude: 16, + GroundSpeed: 0, + Pitch: 10, + Bank: 10, + Heading: 10, + PressureAltitude: 352, + }, + nil, + }, + { + "Invalid transponder mode", + "@foo:N123:1200:1:40.6452667:-73.7738611:16:0:4177408112:336\r\n", + nil, + NewGenericFSDError(SyntaxError), + }, + { + "Missing transponder mode", + "@:N123:1200:1:40.6452667:-73.7738611:16:0:4177408112:336\r\n", + nil, + NewGenericFSDError(SyntaxError), + }, + { + "Invalid callsign", + "@N:N172SP99:1200:1:40.6452667:-73.7738611:16:0:4177408112:336\r\n", + nil, + NewGenericFSDError(SyntaxError), + }, + { + "Missing callsign", + "@N::1200:1:40.6452667:-73.7738611:16:0:4177408112:336\r\n", + nil, + NewGenericFSDError(SyntaxError), + }, + { + "Invalid transponder code length", + "@N:N123:00000:1:40.6452667:-73.7738611:16:0:4177408112:336\r\n", + nil, + NewGenericFSDError(SyntaxError), + }, + { + "Out of range transponder code", + "@N:N123:7778:1:40.6452667:-73.7738611:16:0:4177408112:336\r\n", + nil, + NewGenericFSDError(SyntaxError), + }, + { + "Missing transponder code", + "@N:N123::1:40.6452667:-73.7738611:16:0:4177408112:336\r\n", + nil, + NewGenericFSDError(SyntaxError), + }, + { + "Network rating too low", + "@N:N123:1200:0:40.6452667:-73.7738611:16:0:4177408112:336\r\n", + nil, + NewGenericFSDError(SyntaxError), + }, + { + "Network rating too high", + "@N:N123:1200:13:40.6452667:-73.7738611:16:0:4177408112:336\r\n", + nil, + NewGenericFSDError(SyntaxError), + }, + { + "NaN network rating", + "@N:N123:1200:foo:40.6452667:-73.7738611:16:0:4177408112:336\r\n", + nil, + NewGenericFSDError(SyntaxError), + }, + { + "Out of bounds lat", + "@N:N123:1200:1:-91.1231:-73.7738611:16:0:4177408112:336\r\n", + nil, + NewGenericFSDError(SyntaxError), + }, + { + "Missing lat", + "@N:N123:1200:1::-73.7738611:16:0:4177408112:336\r\n", + nil, + NewGenericFSDError(SyntaxError), + }, + { + "NaN lat", + "@N:N123:1200:1:foo:-73.7738611:16:0:4177408112:336\r\n", + nil, + NewGenericFSDError(SyntaxError), + }, + { + "Out of bounds lng", + "@N:N123:1200:1:40.6452667:181.32133:16:0:4177408112:336\r\n", + nil, + NewGenericFSDError(SyntaxError), + }, + { + "Missing lng", + "@N:N123:1200:1:40.6452667::16:0:4177408112:336\r\n", + nil, + NewGenericFSDError(SyntaxError), + }, + { + "NaN lng", + "@N:N123:1200:1:40.6452667:foo:16:0:4177408112:336\r\n", + nil, + NewGenericFSDError(SyntaxError), + }, + { + "Too low true alt", + "@N:N123:1200:1:40.6452667:-73.7738611:-1501:0:4177408112:336\r\n", + nil, + NewGenericFSDError(SyntaxError), + }, + { + "Too high true alt", + "@N:N123:1200:1:40.6452667:-73.7738611:100000:0:4177408112:336\r\n", + nil, + NewGenericFSDError(SyntaxError), + }, + { + "Missing true alt", + "@N:N123:1200:1:40.6452667:-73.7738611::0:4177408112:336\r\n", + nil, + NewGenericFSDError(SyntaxError), + }, + { + "NaN true alt", + "@N:N123:1200:1:40.6452667:-73.7738611:foo:0:4177408112:336\r\n", + nil, + NewGenericFSDError(SyntaxError), + }, + { + "Negative groundspeed", + "@N:N123:1200:1:40.6452667:-73.7738611:16:-1:4177408112:336\r\n", + nil, + NewGenericFSDError(SyntaxError), + }, + { + "Unrealistic groundspeed", + "@N:N123:1200:1:40.6452667:-73.7738611:16:10000:4177408112:336\r\n", + nil, + NewGenericFSDError(SyntaxError), + }, + { + "Missing groundspeed", + "@N:N123:1200:1:40.6452667:-73.7738611:16::4177408112:336\r\n", + nil, + NewGenericFSDError(SyntaxError), + }, + { + "pbh uint32 too long", + "@N:N123:1200:1:40.6452667:-73.7738611:16:0:4294967296:336\r\n", + nil, + NewGenericFSDError(SyntaxError), + }, + { + "pbh missing", + "@N:N123:1200:1:40.6452667:-73.7738611:16:0::336\r\n", + nil, + NewGenericFSDError(SyntaxError), + }, + { + "Pressure alt too high", + "@N:N123:1200:1:40.6452667:-73.7738611:16:0:4177408112:999999\r\n", + nil, + NewGenericFSDError(SyntaxError), + }, + { + "Pressure alt too low", + "@N:N123:1200:1:40.6452667:-73.7738611:16:0:4177408112:-10016\r\n", + nil, + NewGenericFSDError(SyntaxError), + }, + { + "Missing pressure alt", + "@N:N123:1200:1:40.6452667:-73.7738611:16:0:4177408112:\r\n", + nil, + NewGenericFSDError(SyntaxError), + }, + { + "Empty packet", + "@\r\n", + nil, + NewGenericFSDError(SyntaxError), + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + // Perform the parsing + result, err := ParsePilotPositionPDU(tc.packet) + + // Check the error + if tc.wantErr != nil { + assert.EqualError(t, err, tc.wantErr.Error()) + } else { + assert.NoError(t, err) + } + + // Verify the result + if tc.want != nil { + assert.Equal(t, tc.want.SquawkingModeC, result.SquawkingModeC) + assert.Equal(t, tc.want.Identing, result.Identing) + assert.Equal(t, tc.want.From, result.From) + assert.Equal(t, tc.want.SquawkCode, result.SquawkCode) + assert.Equal(t, tc.want.NetworkRating, result.NetworkRating) + assert.InDelta(t, tc.want.Lat, result.Lat, 1e-3) + assert.InDelta(t, tc.want.Lng, result.Lng, 1e-3) + assert.Equal(t, tc.want.TrueAltitude, result.TrueAltitude) + assert.Equal(t, tc.want.GroundSpeed, result.GroundSpeed) + assert.InDelta(t, tc.want.Pitch, result.Pitch, 1) + assert.InDelta(t, tc.want.Bank, result.Bank, 1) + assert.InDelta(t, tc.want.Heading, result.Heading, 1) + assert.Equal(t, tc.want.PressureAltitude, result.PressureAltitude) + } else { + assert.Nil(t, result) + } + }) + } +} diff --git a/protocol/ping.go b/protocol/ping.go new file mode 100644 index 0000000..a7ffbd4 --- /dev/null +++ b/protocol/ping.go @@ -0,0 +1,39 @@ +package protocol + +import ( + "fmt" + "strings" +) + +type PingPDU struct { + From string `validate:"required,alphanum,max=7"` + To string `validate:"required,alphanum,max=7"` + Timestamp string `validate:"required,max=32"` +} + +func (p *PingPDU) Serialize() string { + return fmt.Sprintf("$PI%s:%s:%s%s", p.From, p.To, p.Timestamp, PacketDelimeter) +} + +func ParsePingPDU(rawPacket string) (*PingPDU, error) { + rawPacket = strings.TrimSuffix(rawPacket, PacketDelimeter) + rawPacket = strings.TrimPrefix(rawPacket, "$PI") + fields := strings.Split(rawPacket, Delimeter) + + if len(fields) != 3 { + return nil, NewGenericFSDError(SyntaxError) + } + + pdu := PingPDU{ + From: fields[0], + To: fields[1], + Timestamp: fields[2], + } + + err := V.Struct(pdu) + if err != nil { + return nil, NewGenericFSDError(SyntaxError) + } + + return &pdu, nil +} diff --git a/protocol/ping_test.go b/protocol/ping_test.go new file mode 100644 index 0000000..4c2d2aa --- /dev/null +++ b/protocol/ping_test.go @@ -0,0 +1,91 @@ +package protocol + +import ( + "github.com/go-playground/validator/v10" + "github.com/stretchr/testify/assert" + "testing" +) + +func TestPingPDU_Serialize(t *testing.T) { + tests := []struct { + name string + pdu PingPDU + expected string + }{ + { + name: "Valid PingPDU", + pdu: PingPDU{ + From: "SOURCE", + To: "TARGET", + Timestamp: "1609459200", + }, + expected: "$PISOURCE:TARGET:1609459200\r\n", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + assert.Equal(t, tt.expected, tt.pdu.Serialize()) + }) + } +} + +func TestParsePingPDU(t *testing.T) { + V = validator.New(validator.WithRequiredStructEnabled()) + + tests := []struct { + name string + packet string + want *PingPDU + wantErr error + }{ + { + name: "Valid PingPDU Parse", + packet: "$PISOURCE:TARGET:1609459200\r\n", + want: &PingPDU{ + From: "SOURCE", + To: "TARGET", + Timestamp: "1609459200", + }, + wantErr: nil, + }, + { + name: "Missing fields", + packet: "$PISOURCE:1609459200\r\n", + want: nil, + wantErr: NewGenericFSDError(SyntaxError), + }, + { + name: "Exceeds max field length", + packet: "$PISOURCESOURCE:TARGETTARGET:16094592000000000000000000000000\r\n", + want: nil, + wantErr: NewGenericFSDError(SyntaxError), + }, + { + name: "Invalid From field", + packet: "$PI12345678:TARGET:1609459200\r\n", + want: nil, + wantErr: NewGenericFSDError(SyntaxError), + }, + { + name: "Invalid To field", + packet: "$PISOURCE:12345678:1609459200\r\n", + want: nil, + wantErr: NewGenericFSDError(SyntaxError), + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + result, err := ParsePingPDU(tc.packet) + + if tc.wantErr != nil { + assert.EqualError(t, err, tc.wantErr.Error()) + } else { + assert.NoError(t, err) + } + + assert.Equal(t, tc.want, result) + }) + } +} diff --git a/protocol/plane_info_request.go b/protocol/plane_info_request.go new file mode 100644 index 0000000..8ab7b43 --- /dev/null +++ b/protocol/plane_info_request.go @@ -0,0 +1,40 @@ +package protocol + +import ( + "fmt" + "strings" +) + +type PlaneInfoRequestPDU struct { + From string `validate:"required,alphanum,max=7"` + To string `validate:"required,alphanum,max=7"` +} + +func (p *PlaneInfoRequestPDU) Serialize() string { + return fmt.Sprintf("#SB%s:%s:PIR%s", p.From, p.To, PacketDelimeter) +} + +func ParsePlaneInfoRequestPDU(rawPacket string) (*PlaneInfoRequestPDU, error) { + rawPacket = strings.TrimSuffix(rawPacket, PacketDelimeter) + rawPacket = strings.TrimPrefix(rawPacket, "#SB") + fields := strings.Split(rawPacket, Delimeter) + if len(fields) != 3 { + return nil, NewGenericFSDError(SyntaxError) + } + + if fields[2] != "PIR" { + return nil, NewGenericFSDError(SyntaxError) + } + + pdu := &PlaneInfoRequestPDU{ + From: fields[0], + To: fields[1], + } + + err := V.Struct(pdu) + if err != nil { + return nil, NewGenericFSDError(SyntaxError) + } + + return pdu, nil +} diff --git a/protocol/plane_info_request_test.go b/protocol/plane_info_request_test.go new file mode 100644 index 0000000..4f0f707 --- /dev/null +++ b/protocol/plane_info_request_test.go @@ -0,0 +1,131 @@ +package protocol + +import ( + "testing" + + "github.com/go-playground/validator/v10" + "github.com/stretchr/testify/assert" +) + +func TestParsePlaneInfoRequestPDU(t *testing.T) { + V = validator.New(validator.WithRequiredStructEnabled()) + + tests := []struct { + name string + packet string + want *PlaneInfoRequestPDU + wantErr error + }{ + { + name: "Valid input", + packet: "#SBPILOT:ATC:PIR\r\n", + want: &PlaneInfoRequestPDU{ + From: "PILOT", + To: "ATC", + }, + wantErr: nil, + }, + { + name: "Last element not PIR", + packet: "#SBPILOT:ATC:PI\r\n", + want: nil, + wantErr: NewGenericFSDError(SyntaxError), + }, + { + name: "Missing To field", + packet: "#SBPILOT::PIR\r\n", + want: nil, + wantErr: NewGenericFSDError(SyntaxError), + }, + { + name: "Missing From field", + packet: "#SB:ATC:PIR\r\n", + want: nil, + wantErr: NewGenericFSDError(SyntaxError), + }, + { + name: "Extra fields", + packet: "#SBPILOT:ATC:EXTRA:PIR\r\n", + want: nil, + wantErr: NewGenericFSDError(SyntaxError), + }, + { + name: "Invalid From field (non-alphanumerical)", + packet: "#SBP!@#$:ATC:PIR\r\n", + want: nil, + wantErr: NewGenericFSDError(SyntaxError), + }, + { + name: "Invalid To field (too long)", + packet: "#SBPILOT:12345678:PIR\r\n", + want: nil, + wantErr: NewGenericFSDError(SyntaxError), + }, + { + name: "Input with incorrect prefix", + packet: "$DIPILOT:ATC:PIR\r\n", + want: nil, + wantErr: NewGenericFSDError(SyntaxError), + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + // Perform the parsing + result, err := ParsePlaneInfoRequestPDU(tc.packet) + + // Check the error + if tc.wantErr != nil { + assert.EqualError(t, err, tc.wantErr.Error()) + } else { + assert.NoError(t, err) + } + + // Verify the result + assert.Equal(t, tc.want, result) + }) + } +} + +func TestPlaneInfoRequestPDU_Serialize(t *testing.T) { + tests := []struct { + name string + pdu PlaneInfoRequestPDU + want string + }{ + { + name: "Valid serialization", + pdu: PlaneInfoRequestPDU{ + From: "PILOT", + To: "ATC", + }, + want: "#SBPILOT:ATC:PIR\r\n", // assuming correct values for constants + }, + { + name: "Minimal length fields", + pdu: PlaneInfoRequestPDU{ + From: "A", + To: "B", + }, + want: "#SBA:B:PIR\r\n", // assuming correct values for constants + }, + { + name: "Max length fields", + pdu: PlaneInfoRequestPDU{ + From: "1234567", + To: "7654321", + }, + want: "#SB1234567:7654321:PIR\r\n", // assuming correct values for constants + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + // Perform serialization + result := tc.pdu.Serialize() + + // Verify the result + assert.Equal(t, tc.want, result) + }) + } +} diff --git a/protocol/plane_info_response.go b/protocol/plane_info_response.go new file mode 100644 index 0000000..1aa0a5e --- /dev/null +++ b/protocol/plane_info_response.go @@ -0,0 +1,88 @@ +package protocol + +import ( + "fmt" + "strings" +) + +type PlaneInfoResponsePDU struct { + From string `validate:"required,alphanum,max=7"` + To string `validate:"required,alphanum,max=7"` + Equipment string `validate:"required,max=32"` + Airline string `validate:"max=32"` + Livery string `validate:"max=32"` + CSL string `validate:"max=32"` +} + +func (p *PlaneInfoResponsePDU) Serialize() string { + str := fmt.Sprintf("#SB%s:%s:PI:GEN:EQUIPMENT=%s", p.From, p.To, p.Equipment) + if p.Airline != "" { + str += fmt.Sprintf(":AIRLINE=%s", p.Airline) + } + if p.Livery != "" { + str += fmt.Sprintf(":LIVERY=%s", p.Livery) + } + if p.CSL != "" { + str += fmt.Sprintf(":CSL=%s", p.CSL) + } + + str += PacketDelimeter + return str +} + +func ParsePlaneInfoResponsePDU(rawPacket string) (*PlaneInfoResponsePDU, error) { + rawPacket = strings.TrimSuffix(rawPacket, PacketDelimeter) + rawPacket = strings.TrimPrefix(rawPacket, "#SB") + fields := strings.Split(rawPacket, Delimeter) + if len(fields) < 5 || len(fields) > 8 { + return nil, NewGenericFSDError(SyntaxError) + } + + if fields[2] != "PI" || fields[3] != "GEN" { + return nil, NewGenericFSDError(SyntaxError) + } + + var equipment, airline, livery, csl string + + if !strings.HasPrefix(fields[4], "EQUIPMENT=") || len(fields[4]) < len("EQUIPMENT=")+1 { + return nil, NewGenericFSDError(SyntaxError) + } + equipment = strings.SplitN(fields[4], "=", 2)[1] + + if len(fields) > 5 { + if !strings.HasPrefix(fields[5], "AIRLINE=") || len(fields[5]) < len("AIRLINE=")+1 { + return nil, NewGenericFSDError(SyntaxError) + } + airline = strings.SplitN(fields[5], "=", 2)[1] + } + + if len(fields) > 6 { + if !strings.HasPrefix(fields[6], "LIVERY=") || len(fields[6]) < len("LIVERY=")+1 { + return nil, NewGenericFSDError(SyntaxError) + } + livery = strings.SplitN(fields[6], "=", 2)[1] + } + + if len(fields) > 7 { + if !strings.HasPrefix(fields[7], "CSL=") || len(fields[6]) < len("CSL=")+1 { + return nil, NewGenericFSDError(SyntaxError) + } + csl = strings.SplitN(fields[7], "=", 2)[1] + } + + pdu := &PlaneInfoResponsePDU{ + From: fields[0], + To: fields[1], + Equipment: equipment, + Airline: airline, + Livery: livery, + CSL: csl, + } + + err := V.Struct(pdu) + if err != nil { + return nil, NewGenericFSDError(SyntaxError) + } + + return pdu, nil +} diff --git a/protocol/plane_info_response_test.go b/protocol/plane_info_response_test.go new file mode 100644 index 0000000..594929c --- /dev/null +++ b/protocol/plane_info_response_test.go @@ -0,0 +1,169 @@ +package protocol + +import ( + "github.com/go-playground/validator/v10" + "github.com/stretchr/testify/assert" + "testing" +) + +func TestParsePlaneInfoResponsePDU(t *testing.T) { + V = validator.New(validator.WithRequiredStructEnabled()) + + tests := []struct { + name string + packet string + want *PlaneInfoResponsePDU + wantErr error + }{ + { + "Valid Info With All Fields", + "#SBATC:PILOT:PI:GEN:EQUIPMENT=A320:AIRLINE=Delta:LIVERY=Standard:CSL=ModelABC\r\n", + &PlaneInfoResponsePDU{ + From: "ATC", + To: "PILOT", + Equipment: "A320", + Airline: "Delta", + Livery: "Standard", + CSL: "ModelABC", + }, + nil, + }, + { + "Valid Info With Minimum Required Fields", + "#SBATC:PILOT:PI:GEN:EQUIPMENT=A320\r\n", + &PlaneInfoResponsePDU{ + From: "ATC", + To: "PILOT", + Equipment: "A320", + Airline: "", + Livery: "", + CSL: "", + }, + nil, + }, + { + "Invalid - Missing EQUIPMENT Prefix", + "#SBATC:PILOT:PI:GEN:A320:LIVERY=Standard\r\n", + nil, + NewGenericFSDError(SyntaxError), + }, + { + "Invalid - Wrong HEADER Prefix", + "$SBATC:PILOT:PI:GEN:EQUIPMENT=A320\r\n", + nil, + NewGenericFSDError(SyntaxError), + }, + { + "Invalid - Field Count Less", + "#SBATC:PILOT:PI:GEN\r\n", + nil, + NewGenericFSDError(SyntaxError), + }, + { + "Invalid - Field Count More", + "#SBATC:PILOT:PI:GEN:EQUIPMENT=A320:AIRLINE=Delta:LIVERY=Standard:CSL=ModelABC:ExtraField\r\n", + nil, + NewGenericFSDError(SyntaxError), + }, + { + "Invalid - No From Field", + "#SB:PILOT:PI:GEN:EQUIPMENT=A320\r\n", + nil, + NewGenericFSDError(SyntaxError), + }, + { + "Invalid - No To Field", + "#SBATC::PI:GEN:EQUIPMENT=A320\r\n", + nil, + NewGenericFSDError(SyntaxError), + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + // Perform the parsing + result, err := ParsePlaneInfoResponsePDU(tc.packet) + + // Check the error + if tc.wantErr != nil { + assert.Error(t, err) + assert.EqualError(t, err, tc.wantErr.Error()) + } else { + assert.NoError(t, err) + } + + // Verify the result + assert.Equal(t, tc.want, result) + }) + } +} + +func TestPlaneInfoResponsePDU_Serialize(t *testing.T) { + tests := []struct { + name string + pdu *PlaneInfoResponsePDU + wantStr string + }{ + { + "All Fields Present", + &PlaneInfoResponsePDU{ + From: "ATC", + To: "PILOT", + Equipment: "A320", + Airline: "Delta", + Livery: "Standard", + CSL: "ModelABC", + }, + "#SBATC:PILOT:PI:GEN:EQUIPMENT=A320:AIRLINE=Delta:LIVERY=Standard:CSL=ModelABC\r\n", + }, + { + "Only Required Fields", + &PlaneInfoResponsePDU{ + From: "ATC", + To: "PILOT", + Equipment: "B737", + }, + "#SBATC:PILOT:PI:GEN:EQUIPMENT=B737\r\n", + }, + { + "With Airline", + &PlaneInfoResponsePDU{ + From: "CTRL", + To: "PLANE", + Equipment: "E170", + Airline: "United", + }, + "#SBCTRL:PLANE:PI:GEN:EQUIPMENT=E170:AIRLINE=United\r\n", + }, + { + "With Livery", + &PlaneInfoResponsePDU{ + From: "GROUND", + To: "ACFT", + Equipment: "CRJ2", + Livery: "BlueSky", + }, + "#SBGROUND:ACFT:PI:GEN:EQUIPMENT=CRJ2:LIVERY=BlueSky\r\n", + }, + { + "With CSL", + &PlaneInfoResponsePDU{ + From: "TWR", + To: "JET", + Equipment: "CONC", + CSL: "Vintage", + }, + "#SBTWR:JET:PI:GEN:EQUIPMENT=CONC:CSL=Vintage\r\n", + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + // Perform the serialization + gotStr := tc.pdu.Serialize() + + // Verify the serialized output + assert.Equal(t, tc.wantStr, gotStr) + }) + } +} diff --git a/protocol/pong.go b/protocol/pong.go new file mode 100644 index 0000000..e1ca45f --- /dev/null +++ b/protocol/pong.go @@ -0,0 +1,39 @@ +package protocol + +import ( + "fmt" + "strings" +) + +type PongPDU struct { + From string `validate:"required,alphanum,max=7"` + To string `validate:"required,alphanum,max=7"` + Timestamp string `validate:"required,max=32"` +} + +func (p *PongPDU) Serialize() string { + return fmt.Sprintf("$PO%s:%s:%s%s", p.From, p.To, p.Timestamp, PacketDelimeter) +} + +func ParsePongPDU(rawPacket string) (*PongPDU, error) { + rawPacket = strings.TrimSuffix(rawPacket, PacketDelimeter) + rawPacket = strings.TrimPrefix(rawPacket, "$PO") + fields := strings.Split(rawPacket, Delimeter) + + if len(fields) != 3 { + return nil, NewGenericFSDError(SyntaxError) + } + + pdu := PongPDU{ + From: fields[0], + To: fields[1], + Timestamp: fields[2], + } + + err := V.Struct(pdu) + if err != nil { + return nil, NewGenericFSDError(SyntaxError) + } + + return &pdu, nil +} diff --git a/protocol/pong_test.go b/protocol/pong_test.go new file mode 100644 index 0000000..a2a71dd --- /dev/null +++ b/protocol/pong_test.go @@ -0,0 +1,91 @@ +package protocol + +import ( + "github.com/go-playground/validator/v10" + "github.com/stretchr/testify/assert" + "testing" +) + +func TestPongPDU_Serialize(t *testing.T) { + tests := []struct { + name string + pdu PongPDU + expected string + }{ + { + name: "Valid PongPDU", + pdu: PongPDU{ + From: "SOURCE", + To: "TARGET", + Timestamp: "1609459200", + }, + expected: "$POSOURCE:TARGET:1609459200\r\n", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + assert.Equal(t, tt.expected, tt.pdu.Serialize()) + }) + } +} + +func TestParsePongPDU(t *testing.T) { + V = validator.New(validator.WithRequiredStructEnabled()) + + tests := []struct { + name string + packet string + want *PongPDU + wantErr error + }{ + { + name: "Valid PongPDU Parse", + packet: "$POSOURCE:TARGET:1609459200\r\n", + want: &PongPDU{ + From: "SOURCE", + To: "TARGET", + Timestamp: "1609459200", + }, + wantErr: nil, + }, + { + name: "Missing fields", + packet: "$POSOURCE:1609459200\r\n", + want: nil, + wantErr: NewGenericFSDError(SyntaxError), + }, + { + name: "Exceeds max field length", + packet: "$POSOURCESOURCE:TARGETTARGET:16094592000000000000000000000000\r\n", + want: nil, + wantErr: NewGenericFSDError(SyntaxError), + }, + { + name: "Invalid From field", + packet: "$PO12345678:TARGET:1609459200\r\n", + want: nil, + wantErr: NewGenericFSDError(SyntaxError), + }, + { + name: "Invalid To field", + packet: "$POSOURCE:12345678:1609459200\r\n", + want: nil, + wantErr: NewGenericFSDError(SyntaxError), + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + result, err := ParsePongPDU(tc.packet) + + if tc.wantErr != nil { + assert.EqualError(t, err, tc.wantErr.Error()) + } else { + assert.NoError(t, err) + } + + assert.Equal(t, tc.want, result) + }) + } +} diff --git a/protocol/send_fast.go b/protocol/send_fast.go new file mode 100644 index 0000000..8930947 --- /dev/null +++ b/protocol/send_fast.go @@ -0,0 +1,50 @@ +package protocol + +import ( + "fmt" + "strconv" + "strings" +) + +type SendFastPDU struct { + From string `validate:"required,alphanum,max=7"` + To string `validate:"required,alphanum,max=7"` + DoSendFast bool `validate:""` +} + +func (p *SendFastPDU) Serialize() string { + var doSendFastInt int + if p.DoSendFast { + doSendFastInt = 1 + } + return fmt.Sprintf("$SF%s:%s:%d%s", p.From, p.To, doSendFastInt, PacketDelimeter) +} + +func ParseSendFastPDU(rawPacket string) (*SendFastPDU, error) { + rawPacket = strings.TrimSuffix(rawPacket, PacketDelimeter) + rawPacket = strings.TrimPrefix(rawPacket, "$SF") + fields := strings.Split(rawPacket, Delimeter) + + if len(fields) != 3 { + return nil, NewGenericFSDError(SyntaxError) + } + + doSendFastInt, err := strconv.Atoi(fields[2]) + if err != nil { + return nil, NewGenericFSDError(SyntaxError) + } + doSendFast := doSendFastInt == 1 + + pdu := SendFastPDU{ + From: fields[0], + To: fields[1], + DoSendFast: doSendFast, + } + + err = V.Struct(pdu) + if err != nil { + return nil, NewGenericFSDError(SyntaxError) + } + + return &pdu, nil +} diff --git a/protocol/send_fast_test.go b/protocol/send_fast_test.go new file mode 100644 index 0000000..6139806 --- /dev/null +++ b/protocol/send_fast_test.go @@ -0,0 +1,113 @@ +package protocol + +import ( + "github.com/go-playground/validator/v10" + "github.com/stretchr/testify/assert" + "testing" +) + +func TestSendFastPDU_Serialize(t *testing.T) { + V = validator.New(validator.WithRequiredStructEnabled()) + + tests := []struct { + name string + pdu SendFastPDU + want string + }{ + { + name: "SendFastPDU with DoSendFast true", + pdu: SendFastPDU{ + From: "SERVER", + To: "CLIENT", + DoSendFast: true, + }, + want: "$SFSERVER:CLIENT:1\r\n", + }, + { + name: "SendFastPDU with DoSendFast false", + pdu: SendFastPDU{ + From: "SERVER", + To: "CLIENT", + DoSendFast: false, + }, + want: "$SFSERVER:CLIENT:0\r\n", + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + // Perform the serialization + result := tc.pdu.Serialize() + + // Verify the result + assert.Equal(t, tc.want, result) + }) + } +} + +func TestParseSendFastPDU(t *testing.T) { + V = validator.New(validator.WithRequiredStructEnabled()) + + tests := []struct { + name string + packet string + want *SendFastPDU + wantErr bool + }{ + { + name: "Valid packet with DoSendFast true", + packet: "$SFSERVER:CLIENT:1\r\n", + want: &SendFastPDU{ + From: "SERVER", + To: "CLIENT", + DoSendFast: true, + }, + wantErr: false, + }, + { + name: "Valid packet with DoSendFast false", + packet: "$SFSERVER:CLIENT:0\r\n", + want: &SendFastPDU{ + From: "SERVER", + To: "CLIENT", + DoSendFast: false, + }, + wantErr: false, + }, + { + name: "Packet with missing DoSendFast field", + packet: "$SFSERVER:CLIENT:\r\n", + want: nil, + wantErr: true, + }, + { + name: "Packet with invalid DoSendFast field", + packet: "$SFSERVER:CLIENT:not_a_number\r\n", + want: nil, + wantErr: true, + }, + { + name: "Incorrect packet format", + packet: "SFSERVERCLIENT0\r\n", + want: nil, + wantErr: true, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + // Perform the parsing + result, err := ParseSendFastPDU(tc.packet) + + // Check the error + if tc.wantErr { + assert.Error(t, err) + } else { + assert.NoError(t, err) + } + + // Verify the result + assert.Equal(t, tc.want, result) + }) + } +} diff --git a/protocol/server_identification.go b/protocol/server_identification.go new file mode 100644 index 0000000..7a6925f --- /dev/null +++ b/protocol/server_identification.go @@ -0,0 +1,40 @@ +package protocol + +import ( + "fmt" + "strings" +) + +type ServerIdentificationPDU struct { + From string `validate:"required,alphanum,max=7"` + To string `validate:"required,alphanum,max=7"` + Version string `validate:"required,max=32"` + InitialChallenge string `validate:"required,hexadecimal,max=32"` +} + +func (p *ServerIdentificationPDU) Serialize() string { + return fmt.Sprintf("$DI%s:%s:%s:%s%s", p.From, p.To, p.Version, p.InitialChallenge, PacketDelimeter) +} + +func ParseServerIdentificationPDU(rawPacket string) (*ServerIdentificationPDU, error) { + rawPacket = strings.TrimSuffix(rawPacket, PacketDelimeter) + rawPacket = strings.TrimPrefix(rawPacket, "$DI") + fields := strings.Split(rawPacket, Delimeter) + if len(fields) != 4 { + return nil, NewGenericFSDError(SyntaxError) + } + + pdu := &ServerIdentificationPDU{ + From: fields[0], + To: fields[1], + Version: fields[2], + InitialChallenge: fields[3], + } + + err := V.Struct(pdu) + if err != nil { + return nil, NewGenericFSDError(SyntaxError) + } + + return pdu, nil +} diff --git a/protocol/server_identification_test.go b/protocol/server_identification_test.go new file mode 100644 index 0000000..5786213 --- /dev/null +++ b/protocol/server_identification_test.go @@ -0,0 +1,79 @@ +package protocol + +import ( + "github.com/go-playground/validator/v10" + "github.com/stretchr/testify/assert" + "testing" +) + +func TestServerIdentificationPDU_Serialize2(t *testing.T) { + V = validator.New(validator.WithRequiredStructEnabled()) + + tests := []struct { + name string + packet string + want *ServerIdentificationPDU + wantErr error + }{ + { + "Valid", + "$DISERVER:CLIENT:fsd server:0123456789abcdef\r\n", + &ServerIdentificationPDU{ + From: "SERVER", + To: "CLIENT", + Version: "fsd server", + InitialChallenge: "0123456789abcdef", + }, + nil, + }, + { + "Missing field", + "$DI:CLIENT:fsd server:0123456789abcdef\r\n", + nil, + NewGenericFSDError(SyntaxError), + }, + { + "Non-hexadecimal challenge", + "$DISERVER:CLIENT:fsd server:ghijklmnop\r\n", + nil, + NewGenericFSDError(SyntaxError), + }, + { + "Challenge too long", + "$DISERVER:CLIENT:fsd server:fd9bb85563fc21920f352a74a0917ea88\r\n", + nil, + NewGenericFSDError(SyntaxError), + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + // Perform the parsing + result, err := ParseServerIdentificationPDU(tc.packet) + + // Check the error + if tc.wantErr != nil { + assert.EqualError(t, err, tc.wantErr.Error()) + } else { + assert.NoError(t, err) + } + + // Verify the result + assert.Equal(t, tc.want, result) + }) + } +} + +func TestServerIdentificationPDU_Serialize(t *testing.T) { + { + pdu := ServerIdentificationPDU{ + From: "SERVER", + To: "CLIENT", + Version: "fsd server", + InitialChallenge: "12345", + } + + s := pdu.Serialize() + assert.Equal(t, "$DISERVER:CLIENT:fsd server:12345\r\n", s) + } +} diff --git a/protocol/text_message.go b/protocol/text_message.go new file mode 100644 index 0000000..10983f9 --- /dev/null +++ b/protocol/text_message.go @@ -0,0 +1,39 @@ +package protocol + +import ( + "fmt" + "strings" +) + +type TextMessagePDU struct { + From string `validate:"required,alphanum,max=7"` + To string `validate:"required,max=7"` + Message string `validate:"required"` +} + +func (p *TextMessagePDU) Serialize() string { + return fmt.Sprintf("#TM%s:%s:%s%s", p.From, p.To, p.Message, PacketDelimeter) +} + +func ParseTextMessagePDU(rawPacket string) (*TextMessagePDU, error) { + rawPacket = strings.TrimSuffix(rawPacket, PacketDelimeter) + rawPacket = strings.TrimPrefix(rawPacket, "#TM") + fields := strings.SplitN(rawPacket, Delimeter, 3) + + if len(fields) < 3 { + return nil, NewGenericFSDError(SyntaxError) + } + + pdu := TextMessagePDU{ + From: fields[0], + To: fields[1], + Message: fields[2], + } + + err := V.Struct(pdu) + if err != nil { + return nil, NewGenericFSDError(SyntaxError) + } + + return &pdu, nil +} diff --git a/protocol/text_message_test.go b/protocol/text_message_test.go new file mode 100644 index 0000000..386b7e0 --- /dev/null +++ b/protocol/text_message_test.go @@ -0,0 +1,99 @@ +package protocol + +import ( + "fmt" + "github.com/go-playground/validator/v10" + "github.com/stretchr/testify/assert" + "testing" +) + +func TestParseTextMessagePDU(t *testing.T) { + V = validator.New() + + tests := []struct { + name string + packet string + want *TextMessagePDU + wantErr error + }{ + { + name: "Valid message", + packet: "#TMJOHN:DOE:Hello, world!\r\n", + want: &TextMessagePDU{ + From: "JOHN", + To: "DOE", + Message: "Hello, world!", + }, + wantErr: nil, + }, + { + name: "Invalid from field", + packet: "#TMJOHN99999:DOE:Hello, world!\r\n", + want: nil, + wantErr: NewGenericFSDError(SyntaxError), + }, + { + name: "Invalid to field", + packet: "#TMJOHN:DOE1234567:Hello, world!\r\n", + want: nil, + wantErr: NewGenericFSDError(SyntaxError), + }, + { + name: "Missing to field", + packet: "#TMJOHN::Hello, world!\r\n", + want: nil, + wantErr: NewGenericFSDError(SyntaxError), + }, + { + name: "Missing message", + packet: "#TMJOHN:DOE:\r\n", + want: nil, + wantErr: NewGenericFSDError(SyntaxError), + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + // Perform the parsing + result, err := ParseTextMessagePDU(tc.packet) + + // Check the error + if tc.wantErr != nil { + assert.EqualError(t, err, tc.wantErr.Error()) + } else { + assert.NoError(t, err) + } + + // Verify the result + assert.Equal(t, tc.want, result) + }) + } +} + +func TestTextMessagePDU_Serialize(t *testing.T) { + tests := []struct { + name string + textPDU TextMessagePDU + wantOutput string + }{ + { + name: "Valid message", + textPDU: TextMessagePDU{ + From: "ALPHA1", + To: "BRAVO2", + Message: "Hello, this is a test message.", + }, + wantOutput: "#TMALPHA1:BRAVO2:Hello, this is a test message.\r\n", + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + // Perform serialization + output := tc.textPDU.Serialize() + + // Verify the output + assert.Equal(t, tc.wantOutput, output, fmt.Sprintf("Test %s: expected serialized output does not match actual output", tc.name)) + }) + } +} diff --git a/vatsimauth/vatsimauth.go b/vatsimauth/vatsimauth.go new file mode 100644 index 0000000..18f4570 --- /dev/null +++ b/vatsimauth/vatsimauth.go @@ -0,0 +1,72 @@ +package vatsimauth + +import ( + "crypto/md5" + "crypto/rand" + "encoding/hex" +) + +var Keys = map[uint16]string{ + 35044: "fe28334fb753cf0e3d19942197b9ce3e", + 55538: "ImuL1WbbhVuD8d3MuKpWn2rrLZRa9iVP", + 47546: "727d1efd5cb9f8d2c28372469d922bb4", +} + +type VatsimAuth struct { + clientID uint16 + init string + state string +} + +func NewVatsimAuth(clientID uint16, privateKey string) *VatsimAuth { + return &VatsimAuth{ + clientID: clientID, + init: "", + state: privateKey, + } +} + +// SetInitialChallenge stores the initial challenge for this authentication state. This should be called once before running GenerateResponse +func (v *VatsimAuth) SetInitialChallenge(initialChallenge string) { + v.init = v.GenerateResponse(initialChallenge) + v.state = v.init +} + +// GenerateResponse returns the response for the provided challenge. +func (v *VatsimAuth) GenerateResponse(challenge string) string { + c1, c2 := challenge[0:(len(challenge)/2)], challenge[(len(challenge)/2):] + + if (v.clientID & 1) == 1 { + c1, c2 = c2, c1 + } + + s1, s2, s3 := v.state[0:12], v.state[12:22], v.state[22:32] + + h := "" + switch v.clientID % 3 { + case 0: + h = s1 + c1 + s2 + c2 + s3 + case 1: + h = s2 + c1 + s3 + c2 + s1 + } + + hash := md5.Sum([]byte(h)) + return hex.EncodeToString(hash[:]) +} + +// UpdateState updates this authentication state with a response hash. This is conventionally called with the return value of GenerateResponse +func (v *VatsimAuth) UpdateState(hash string) { + newStateHash := md5.Sum([]byte(v.init + hash)) + v.state = hex.EncodeToString(newStateHash[:]) +} + +// GenerateChallenge returns a cryptographically secure random challenge string +func GenerateChallenge() (string, error) { + challenge := make([]byte, 4) + _, err := rand.Read(challenge) + if err != nil { + return "", err + } + + return hex.EncodeToString(challenge), nil +} diff --git a/vatsimauth/vatsimauth_test.go b/vatsimauth/vatsimauth_test.go new file mode 100644 index 0000000..024d1cf --- /dev/null +++ b/vatsimauth/vatsimauth_test.go @@ -0,0 +1,36 @@ +package vatsimauth + +import ( + "github.com/stretchr/testify/assert" + "testing" +) + +func TestVatsimAuth(t *testing.T) { + + // vPilot clientID and key + v := NewVatsimAuth(35044, "fe28334fb753cf0e3d19942197b9ce3e") + assert.NotNil(t, v) + + // Initial challenge + // $DISERVER:CLIENT:VATSIM FSD V3.43:30984979d8caed23 + v.SetInitialChallenge("30984979d8caed23") + + // First challenge from server + // $ZCSERVER:N12345:de6acb8e + res := v.GenerateResponse("de6acb8e") + + // Expected response: + // $ZRN12345:SERVER:f8ee97157f66455ed6108fccef6ccf5f + assert.Equal(t, "f8ee97157f66455ed6108fccef6ccf5f", res) + v.UpdateState(res) + + // Second challenge from server + // $ZCSERVER:N12345:65b479573b0e + res = v.GenerateResponse("65b479573b0e") + + // Expected response + // $ZRN12345:SERVER:8953f545c4e0ffd20943ad89b8ddd087 + assert.Equal(t, "8953f545c4e0ffd20943ad89b8ddd087", res) + v.UpdateState(res) + +}