diff --git a/cvmts/src/WSServer.ts b/cvmts/src/CollabVMServer.ts similarity index 89% rename from cvmts/src/WSServer.ts rename to cvmts/src/CollabVMServer.ts index 6e3164f..d3dab83 100644 --- a/cvmts/src/WSServer.ts +++ b/cvmts/src/CollabVMServer.ts @@ -1,7 +1,6 @@ import { WebSocketServer, WebSocket } from 'ws'; import * as http from 'http'; import IConfig from './IConfig.js'; -import internal from 'stream'; import * as Utilities from './Utilities.js'; import { User, Rank } from './User.js'; import * as guacutils from './guacutils.js'; @@ -36,12 +35,9 @@ type VoteTally = { no: number; }; -export default class WSServer { +export default class CollabVMServer { private Config: IConfig; - private httpServer: http.Server; - private wsServer: WebSocketServer; - private clients: User[]; private ChatHistory: CircularBuffer; @@ -105,14 +101,6 @@ export default class WSServer { this.indefiniteTurn = null; this.ModPerms = Utilities.MakeModPerms(this.Config.collabvm.moderatorPermissions); - this.httpServer = http.createServer(); - this.wsServer = new WebSocketServer({ noServer: true }); - this.httpServer.on('upgrade', (req: http.IncomingMessage, socket: internal.Duplex, head: Buffer) => this.httpOnUpgrade(req, socket, head)); - this.httpServer.on('request', (req, res) => { - res.writeHead(426); - res.write('This server only accepts WebSocket connections.'); - res.end(); - }); let initSize = vm.GetDisplay().Size() || { width: 0, @@ -130,131 +118,14 @@ export default class WSServer { this.auth = auth; } - listen() { - this.httpServer.listen(this.Config.http.port, this.Config.http.host); - } - - private httpOnUpgrade(req: http.IncomingMessage, socket: internal.Duplex, head: Buffer) { - var killConnection = () => { - socket.write('HTTP/1.1 400 Bad Request\n\n400 Bad Request'); - socket.destroy(); - }; - - if (req.headers['sec-websocket-protocol'] !== 'guacamole') { - killConnection(); - return; - } - - if (this.Config.http.origin) { - // If the client is not sending an Origin header, kill the connection. - if (!req.headers.origin) { - killConnection(); - return; - } - - // Try to parse the Origin header sent by the client, if it fails, kill the connection. - var _uri; - var _host; - try { - _uri = new URL(req.headers.origin.toLowerCase()); - _host = _uri.host; - } catch { - killConnection(); - return; - } - - // detect fake origin headers - if (_uri.pathname !== '/' || _uri.search !== '') { - killConnection(); - return; - } - - // If the domain name is not in the list of allowed origins, kill the connection. - if (!this.Config.http.originAllowedDomains.includes(_host)) { - killConnection(); - return; - } - } - - let ip: string; - if (this.Config.http.proxying) { - // If the requesting IP isn't allowed to proxy, kill it - if (this.Config.http.proxyAllowedIps.indexOf(req.socket.remoteAddress!) === -1) { - killConnection(); - return; - } - // Make sure x-forwarded-for is set - if (req.headers['x-forwarded-for'] === undefined) { - killConnection(); - return; - } - try { - // Get the first IP from the X-Forwarded-For variable - ip = req.headers['x-forwarded-for']?.toString().replace(/\ /g, '').split(',')[0]; - } catch { - // If we can't get the IP, kill the connection - killConnection(); - return; - } - // If for some reason the IP isn't defined, kill it - if (!ip) { - killConnection(); - return; - } - // Make sure the IP is valid. If not, kill the connection. - if (!isIP(ip)) { - killConnection(); - return; - } - } else { - if (!req.socket.remoteAddress) return; - ip = req.socket.remoteAddress; - } - - // Get the amount of active connections coming from the requesting IP. - let connections = this.clients.filter((client) => client.IP.address == ip); - // If it exceeds the limit set in the config, reject the connection with a 429. - if (connections.length + 1 > this.Config.http.maxConnections) { - socket.write('HTTP/1.1 429 Too Many Requests\n\n429 Too Many Requests'); - socket.destroy(); - } - - this.wsServer.handleUpgrade(req, socket, head, (ws: WebSocket) => { - this.wsServer.emit('connection', ws, req); - this.onConnection(ws, req, ip); - }); - } - - private onConnection(ws: WebSocket, req: http.IncomingMessage, ip: string) { - let user = new User(ws, IPDataManager.GetIPData(ip), this.Config); + public addUser(user: User) { this.clients.push(user); - - ws.on('error', (e) => { - this.logger.Error(`${e} (caused by connection ${ip})`); - ws.close(); - }); - - ws.on('close', () => this.connectionClosed(user)); - - ws.on('message', (buf: Buffer, isBinary: boolean) => { - var msg; - - // Close the user's connection if they send a non-string message - if (isBinary) { - user.closeConnection(); - return; - } - - try { - this.onMessage(user, buf.toString()); - } catch {} - }); - + user.socket.on('msg', (msg: string) => this.onMessage(user, msg)); + user.socket.on('disconnect', () => this.connectionClosed(user)); if (this.Config.auth.enabled) { user.sendMsg(guacutils.encode('auth', this.Config.auth.apiEndpoint)); } user.sendMsg(this.getAdduserMsg()); - this.logger.Info(`Connect from ${user.IP.address}`); } private connectionClosed(user: User) { diff --git a/cvmts/src/NetworkClient.ts b/cvmts/src/NetworkClient.ts new file mode 100644 index 0000000..ae1c36e --- /dev/null +++ b/cvmts/src/NetworkClient.ts @@ -0,0 +1,8 @@ +export default interface NetworkClient { + getIP() : string; + send(msg: string) : Promise; + close() : void; + on(event: string, listener: (...args: any[]) => void) : void; + off(event: string, listener: (...args: any[]) => void) : void; + isOpen() : boolean; +} \ No newline at end of file diff --git a/cvmts/src/NetworkServer.ts b/cvmts/src/NetworkServer.ts new file mode 100644 index 0000000..fd3ec24 --- /dev/null +++ b/cvmts/src/NetworkServer.ts @@ -0,0 +1,6 @@ +export default interface NetworkServer { + start() : void; + stop() : void; + on(event: string, listener: (...args: any[]) => void) : void; + off(event: string, listener: (...args: any[]) => void) : void; +} \ No newline at end of file diff --git a/cvmts/src/User.ts b/cvmts/src/User.ts index 81cea6a..187efa2 100644 --- a/cvmts/src/User.ts +++ b/cvmts/src/User.ts @@ -1,14 +1,14 @@ import * as Utilities from './Utilities.js'; import * as guacutils from './guacutils.js'; -import { WebSocket } from 'ws'; import { IPData } from './IPData.js'; import IConfig from './IConfig.js'; import RateLimiter from './RateLimiter.js'; import { execa, execaCommand, ExecaSyncError } from 'execa'; import { Logger } from '@cvmts/shared'; +import NetworkClient from './NetworkClient.js'; export class User { - socket: WebSocket; + socket: NetworkClient; nopSendInterval: NodeJS.Timeout; msgRecieveInterval: NodeJS.Timeout; nopRecieveTimeout?: NodeJS.Timeout; @@ -28,17 +28,18 @@ export class User { private logger = new Logger('CVMTS.User'); - constructor(ws: WebSocket, ip: IPData, config: IConfig, username?: string, node?: string) { + constructor(socket: NetworkClient, ip: IPData, config: IConfig, username?: string, node?: string) { this.IP = ip; this.connectedToNode = false; this.viewMode = -1; this.Config = config; - this.socket = ws; + this.socket = socket; this.msgsSent = 0; - this.socket.on('close', () => { + this.socket.on('disconnect', () => { clearInterval(this.nopSendInterval); + clearInterval(this.msgRecieveInterval); }); - this.socket.on('message', (e) => { + this.socket.on('msg', (e) => { clearTimeout(this.nopRecieveTimeout); clearInterval(this.msgRecieveInterval); this.msgRecieveInterval = setInterval(() => this.onNoMsg(), 10000); @@ -73,8 +74,8 @@ export class User { this.socket.send('3.nop;'); } - sendMsg(msg: string | Buffer) { - if (this.socket.readyState !== this.socket.OPEN) return; + sendMsg(msg: string) { + if (!this.socket.isOpen()) return; clearInterval(this.nopSendInterval); this.nopSendInterval = setInterval(() => this.sendNop(), 5000); this.socket.send(msg); diff --git a/cvmts/src/WebSocket/WSClient.ts b/cvmts/src/WebSocket/WSClient.ts new file mode 100644 index 0000000..96b9a36 --- /dev/null +++ b/cvmts/src/WebSocket/WSClient.ts @@ -0,0 +1,55 @@ +import { WebSocket } from "ws"; +import NetworkClient from "../NetworkClient.js"; +import EventEmitter from "events"; +import { Logger } from "@cvmts/shared"; + +export default class WSClient extends EventEmitter implements NetworkClient { + socket: WebSocket; + ip: string; + logger: Logger; + + constructor(ws: WebSocket, ip: string) { + super(); + this.socket = ws; + this.ip = ip; + this.logger = new Logger("CVMTS.WSClient"); + this.socket.on('message', (buf: Buffer, isBinary: boolean) => { + // Close the user's connection if they send a non-string message + if (isBinary) { + this.close(); + return; + } + + this.emit('msg', buf.toString("utf-8")); + }); + + this.socket.on('close', () => { + this.emit('disconnect'); + + }); + } + + isOpen(): boolean { + return this.socket.readyState === WebSocket.OPEN; + } + + getIP(): string { + return this.ip; + } + send(msg: string): Promise { + return new Promise((res,rej) => { + this.socket.send(msg, (err) => { + if (err) { + rej(err); + return; + } + res(); + }); + }); + } + + close(): void { + this.socket.close(); + } + +} \ No newline at end of file diff --git a/cvmts/src/WebSocket/WSServer.ts b/cvmts/src/WebSocket/WSServer.ts new file mode 100644 index 0000000..dd37da5 --- /dev/null +++ b/cvmts/src/WebSocket/WSServer.ts @@ -0,0 +1,152 @@ +import * as http from 'http'; +import NetworkServer from '../NetworkServer.js'; +import EventEmitter from 'events'; +import { WebSocketServer, WebSocket } from 'ws'; +import internal from 'stream'; +import IConfig from '../IConfig.js'; +import { isIP } from 'net'; +import { IPDataManager } from '../IPData.js'; +import WSClient from './WSClient.js'; +import { User } from '../User.js'; +import { Logger } from '@cvmts/shared'; + +export default class WSServer extends EventEmitter implements NetworkServer { + private httpServer: http.Server; + private wsServer: WebSocketServer; + private clients: WSClient[]; + private Config: IConfig; + private logger: Logger; + + constructor(config : IConfig) { + super(); + this.Config = config; + this.clients = []; + this.logger = new Logger("CVMTS.WSServer"); + this.httpServer = http.createServer(); + this.wsServer = new WebSocketServer({ noServer: true }); + this.httpServer.on('upgrade', (req: http.IncomingMessage, socket: internal.Duplex, head: Buffer) => this.httpOnUpgrade(req, socket, head)); + this.httpServer.on('request', (req, res) => { + res.writeHead(426); + res.write('This server only accepts WebSocket connections.'); + res.end(); + }); + } + + start(): void { + this.httpServer.listen(this.Config.http.port, this.Config.http.host, () => { + this.logger.Info(`WebSocket server listening on ${this.Config.http.host}:${this.Config.http.port}`); + }); + } + + stop(): void { + this.httpServer.close(); + } + + private httpOnUpgrade(req: http.IncomingMessage, socket: internal.Duplex, head: Buffer) { + var killConnection = () => { + socket.write('HTTP/1.1 400 Bad Request\n\n400 Bad Request'); + socket.destroy(); + }; + + if (req.headers['sec-websocket-protocol'] !== 'guacamole') { + killConnection(); + return; + } + + if (this.Config.http.origin) { + // If the client is not sending an Origin header, kill the connection. + if (!req.headers.origin) { + killConnection(); + return; + } + + // Try to parse the Origin header sent by the client, if it fails, kill the connection. + var _uri; + var _host; + try { + _uri = new URL(req.headers.origin.toLowerCase()); + _host = _uri.host; + } catch { + killConnection(); + return; + } + + // detect fake origin headers + if (_uri.pathname !== '/' || _uri.search !== '') { + killConnection(); + return; + } + + // If the domain name is not in the list of allowed origins, kill the connection. + if (!this.Config.http.originAllowedDomains.includes(_host)) { + killConnection(); + return; + } + } + + let ip: string; + if (this.Config.http.proxying) { + // If the requesting IP isn't allowed to proxy, kill it + if (this.Config.http.proxyAllowedIps.indexOf(req.socket.remoteAddress!) === -1) { + killConnection(); + return; + } + // Make sure x-forwarded-for is set + if (req.headers['x-forwarded-for'] === undefined) { + killConnection(); + return; + } + try { + // Get the first IP from the X-Forwarded-For variable + ip = req.headers['x-forwarded-for']?.toString().replace(/\ /g, '').split(',')[0]; + } catch { + // If we can't get the IP, kill the connection + killConnection(); + return; + } + // If for some reason the IP isn't defined, kill it + if (!ip) { + killConnection(); + return; + } + // Make sure the IP is valid. If not, kill the connection. + if (!isIP(ip)) { + killConnection(); + return; + } + } else { + if (!req.socket.remoteAddress) return; + ip = req.socket.remoteAddress; + } + + // TODO: Implement + + // Get the amount of active connections coming from the requesting IP. + //let connections = this.clients.filter((client) => client.IP.address == ip); + // If it exceeds the limit set in the config, reject the connection with a 429. + //if (connections.length + 1 > this.Config.http.maxConnections) { + // socket.write('HTTP/1.1 429 Too Many Requests\n\n429 Too Many Requests'); + // socket.destroy(); + //} + + this.wsServer.handleUpgrade(req, socket, head, (ws: WebSocket) => { + this.wsServer.emit('connection', ws, req); + this.onConnection(ws, req, ip); + }); + } + + private onConnection(ws: WebSocket, req: http.IncomingMessage, ip: string) { + let client = new WSClient(ws, ip); + this.clients.push(client); + let user = new User(client, IPDataManager.GetIPData(ip), this.Config); + + this.emit('connect', user); + + ws.on('error', (e) => { + this.logger.Error(`${e} (caused by connection ${ip})`); + ws.close(); + }); + + this.logger.Info(`Connect from ${user.IP.address}`); + } +} \ No newline at end of file diff --git a/cvmts/src/index.ts b/cvmts/src/index.ts index 89cd218..980c984 100644 --- a/cvmts/src/index.ts +++ b/cvmts/src/index.ts @@ -1,12 +1,14 @@ import * as toml from 'toml'; import IConfig from './IConfig.js'; import * as fs from 'fs'; -import WSServer from './WSServer.js'; +import CollabVMServer from './CollabVMServer.js'; import { QemuVM, QemuVmDefinition } from '@cvmts/qemu'; import * as Shared from '@cvmts/shared'; import AuthManager from './AuthManager.js'; +import WSServer from './WebSocket/WSServer.js'; +import { User } from './User.js'; let logger = new Shared.Logger('CVMTS.Init'); @@ -50,8 +52,11 @@ async function start() { var VM = new QemuVM(def); await VM.Start(); - // Start up the websocket server - var WS = new WSServer(Config, VM, auth); - WS.listen(); + // Start up the server + var CVM = new CollabVMServer(Config, VM, auth); + + var WS = new WSServer(Config); + WS.on('connect', (client: User) => CVM.addUser(client)); + WS.start(); } start();