Files
openfsd/client/connection.go
2024-10-22 12:46:06 -07:00

395 lines
10 KiB
Go

package client
import (
"bufio"
"context"
"errors"
"github.com/golang-jwt/jwt/v5"
auth2 "github.com/renorris/openfsd/auth"
"github.com/renorris/openfsd/database"
"github.com/renorris/openfsd/postoffice"
"github.com/renorris/openfsd/protocol"
"github.com/renorris/openfsd/protocol/vatsimauth"
"github.com/renorris/openfsd/servercontext"
"golang.org/x/crypto/bcrypt"
"log"
"net"
"strconv"
"strings"
"time"
)
const maxFSDPacketSize = 1536
const writeFlushInterval = 50 * time.Millisecond
var AlwaysImmediate = false
type writePayload struct {
packet string
immediate bool
}
type Connection struct {
ctx context.Context
cancelCtx func()
conn *net.TCPConn
readerChan chan string
writerChan chan writePayload
doneSig chan struct{}
}
func NewConnection(ctx context.Context, conn *net.TCPConn) (*Connection, error) {
// Ensure any lingering data will be flushed
if err := conn.SetLinger(1); err != nil {
return nil, err
}
c, cancel := context.WithCancel(ctx)
connection := Connection{
ctx: c,
cancelCtx: cancel,
conn: conn,
readerChan: make(chan string),
writerChan: make(chan writePayload),
doneSig: make(chan struct{}),
}
// Start reader
go func() {
defer connection.cancelCtx()
if err := connection.reader(); err != nil {
// Attempt to send the error to the client if it implements FSDError
var fsdError *protocol.FSDError
if errors.As(err, &fsdError) {
connection.WritePacket(fsdError.Serialize())
}
}
}()
// Start writer
go func() {
defer close(connection.doneSig)
defer connection.cancelCtx()
if err := connection.writer(); err != nil {
// TODO: properly handle error
}
}()
return &connection, nil
}
func (c *Connection) writePacket(packet string, immediate bool) error {
payload := writePayload{
packet: packet,
immediate: immediate,
}
select {
case <-c.ctx.Done():
return errors.New("context closed")
case c.writerChan <- payload:
return nil
}
}
func (c *Connection) WritePacket(packet string) error {
return c.writePacket(packet, AlwaysImmediate)
}
func (c *Connection) WritePacketImmediately(packet string) error {
return c.writePacket(packet, true)
}
func (c *Connection) ReadPacket() (packet string, err error) {
select {
case <-c.ctx.Done():
return "", errors.New("context closed")
case packet = <-c.readerChan:
return packet, nil
}
}
func (c *Connection) reader() error {
reader := bufio.NewReaderSize(c.conn, maxFSDPacketSize)
for {
if err := c.conn.SetReadDeadline(time.Now().Add(10 * time.Second)); err != nil {
// TODO: properly handle error
return err
}
var packet string
if buf, err := reader.ReadSlice('\n'); err != nil {
return err
} else if len(buf) > maxFSDPacketSize {
return protocol.NewGenericFSDError(protocol.SyntaxError, "", "packet too long")
} else if !strings.HasSuffix(string(buf), protocol.PacketDelimiter) {
return protocol.NewGenericFSDError(protocol.SyntaxError, "", "invalid packet delimeter")
} else {
// Copy the buffer since ReadSlice() will overwrite it eventually.
bufCopy := make([]byte, len(buf))
copy(bufCopy, buf)
packet = string(buf)
}
// Send the packet over the channel
select {
case <-c.ctx.Done():
return nil
case c.readerChan <- packet:
}
}
}
func (c *Connection) writer() error {
writer := bufio.NewWriter(c.conn)
defer writer.Flush()
ticker := time.NewTicker(writeFlushInterval)
defer ticker.Stop()
ticker.Stop()
tickerActive := false
for {
// Read incoming packet
select {
case <-c.ctx.Done():
return nil
case payload, ok := <-c.writerChan:
if !ok {
return nil
}
if err := c.conn.SetWriteDeadline(time.Now().Add(10 * time.Second)); err != nil {
// TODO: properly handle error
return err
}
if _, err := writer.WriteString(payload.packet); err != nil {
// TODO: properly handle error
return err
}
if payload.immediate {
if err := writer.Flush(); err != nil {
// TODO: properly handle error
return err
}
}
if !tickerActive {
tickerActive = true
ticker.Reset(writeFlushInterval)
}
case <-ticker.C:
if err := c.conn.SetWriteDeadline(time.Now().Add(10 * time.Second)); err != nil {
// TODO: properly handle error
return err
}
if err := writer.Flush(); err != nil {
// TODO: properly handle error
return err
}
ticker.Stop()
tickerActive = false
}
}
}
// Start handles the logical lifetime of this connection
func (c *Connection) Start() {
// Ensure connection is closed when finished
defer c.conn.Close()
// Wait until connection state cleans up before closing connection
defer func() {
<-c.doneSig
}()
// Ensure context is always cancelled
defer c.cancelCtx()
// Attempt to login
var fsdClient *FSDClient
var err error
if fsdClient, err = c.attemptLogin(); err != nil {
var fsdError *protocol.FSDError
if errors.As(err, &fsdError) {
c.WritePacketImmediately(fsdError.Serialize())
}
return
}
// Register to post office
if err = servercontext.PostOffice().RegisterAddress(fsdClient); err != nil {
if errors.Is(err, postoffice.KeyInUseError) {
c.WritePacketImmediately(protocol.NewGenericFSDError(protocol.CallsignInUseError, fsdClient.callsign, "").Serialize())
}
return
}
defer servercontext.PostOffice().DeregisterAddress(fsdClient)
// Broadcast add pilot message
fsdClient.broadcastAddPilot()
defer fsdClient.broadcastDeletePilot()
if err = fsdClient.sendMOTD(); err != nil {
return
}
// Now that we're logged in, run the event loop until an error occurs
fsdClient.EventLoop()
}
// attemptLogin attempts to log in the connection
func (c *Connection) attemptLogin() (fsdClient *FSDClient, err error) {
// Generate the initial challenge
var initChallenge string
if initChallenge, err = vatsimauth.GenerateChallenge(); err != nil {
log.Println("error calling vatsimauth.GenerateChallenge(): " + err.Error())
err = protocol.NewGenericFSDError(protocol.SyntaxError, "", "internal server error (error generating initial challenge)")
return
}
// Generate server identification packet
serverIdentPDU := protocol.ServerIdentificationPDU{
From: protocol.ServerCallsign,
To: "CLIENT",
Version: servercontext.VersionIdentifier,
InitialChallenge: initChallenge,
}
serverIdentPacket := serverIdentPDU.Serialize()
// Write server identification packet
if err = c.WritePacketImmediately(serverIdentPacket); err != nil {
err = protocol.NewGenericFSDError(protocol.SyntaxError, "", "internal server error (error writing $DI server identification packet)")
return
}
// Read the first expected packet: client identification
var packet string
if packet, err = c.ReadPacket(); err != nil {
err = protocol.NewGenericFSDError(protocol.SyntaxError, "", "error reading $ID client identification packet")
return
}
// Parse it
var clientIdentPDU protocol.ClientIdentificationPDU
if err = clientIdentPDU.Parse(packet); err != nil {
return
}
// Read the second expected packet: add pilot
if packet, err = c.ReadPacket(); err != nil {
err = protocol.NewGenericFSDError(protocol.SyntaxError, "", "error reading #AP add pilot packet")
return
}
var addPilotPDU protocol.AddPilotPDU
if err = addPilotPDU.Parse(packet); err != nil {
return nil, err
}
// Handle authentication
var networkRating protocol.NetworkRating
var pilotRating protocol.PilotRating
if networkRating, pilotRating, err = handleAuthentication(addPilotPDU.CID, addPilotPDU.Token); err != nil {
return
}
// Check for valid network rating
if addPilotPDU.NetworkRating > networkRating {
err = protocol.NewGenericFSDError(protocol.RequestedLevelTooHighError, strconv.Itoa(int(networkRating)), "try again at or below your assigned rating")
return
}
// Check for disallowed callsign
switch clientIdentPDU.From {
case protocol.ServerCallsign, protocol.ClientQueryBroadcastRecipient, protocol.ClientQueryBroadcastRecipientPilots:
err = protocol.NewGenericFSDError(protocol.CallsignInvalidError, clientIdentPDU.From, "forbidden callsign")
return
}
// Verify protocol revision
if addPilotPDU.ProtocolRevision != protocol.ProtoRevisionVatsim2022 {
err = protocol.NewGenericFSDError(protocol.InvalidProtocolRevisionError, strconv.Itoa(addPilotPDU.ProtocolRevision), "please connect with a client that supports the Vatsim2022 (101) protocol revision")
return
}
// Verify if this browser is supported by vatsimauth
if _, ok := vatsimauth.Keys[clientIdentPDU.ClientID]; !ok {
err = protocol.NewGenericFSDError(protocol.UnauthorizedSoftwareError, "", "provided client ID is not supported by vatsimauth")
return
}
fsdClient = NewFSDClient(c, &clientIdentPDU, &addPilotPDU, initChallenge, pilotRating)
return fsdClient, nil
}
func handleAuthentication(cid int, tokenField string) (networkRating protocol.NetworkRating, pilotRating protocol.PilotRating, err error) {
// Check if token field actually contains a JWT token, otherwise treat as plaintext password
err = protocol.V.Var(tokenField, "required,jwt")
if err == nil {
// Treat as JWT token
return verifyJWTToken(cid, tokenField)
} else {
// Treat as plaintext password
return verifyPassword(cid, tokenField)
}
}
func verifyJWTToken(cid int, tokenField string) (networkRating protocol.NetworkRating, pilotRating protocol.PilotRating, err error) {
var token *jwt.Token
var verifier auth2.DefaultVerifier
if token, err = verifier.VerifyJWT(tokenField); err != nil {
err = protocol.NewGenericFSDError(protocol.InvalidLogonError, "", "invalid token")
return
}
claims := auth2.FSDJWTClaims{}
if err = claims.Parse(token); err != nil {
err = protocol.NewGenericFSDError(protocol.InvalidLogonError, "", "invalid token claims")
return
}
if claims.CID() != cid {
err = protocol.NewGenericFSDError(protocol.InvalidLogonError, "", "invalid token claims (CID)")
return
}
networkRating = claims.ControllerRating()
pilotRating = claims.PilotRating()
return
}
// verifyPassword verifies a password in the cases when PLAINTEXT_PASSWORDS is in use.
func verifyPassword(cid int, password string) (networkRating protocol.NetworkRating, pilotRating protocol.PilotRating, err error) {
var userRecord database.FSDUserRecord
if err = userRecord.LoadByCID(servercontext.DB(), cid); err != nil {
return -1, -1, err
}
err = bcrypt.CompareHashAndPassword([]byte(userRecord.FSDPassword), []byte(password))
if err != nil {
return -1, -1, err
}
return userRecord.NetworkRating, userRecord.PilotRating, nil
}