فهرست منبع

WebSocket handshake copied from master branch (firefox fix, use header token value).
Fixed #151 (Websocket cause server crash on text message with payload of length 65536).

Damian Kołakowski 10 سال پیش
والد
کامیت
e6f2270040
2فایلهای تغییر یافته به همراه25 افزوده شده و 38 حذف شده
  1. 7 0
      Sources/HttpRequest.swift
  2. 18 38
      Sources/WebSockets.swift

+ 7 - 0
Sources/HttpRequest.swift

@@ -21,6 +21,13 @@ public class HttpRequest {
     public var address: String? = ""
     public var params: [String: String] = [:]
     
+    public func hasTokenForHeader(_ headerName: String, token: String) -> Bool {
+        guard let headerValue = headers[headerName] else {
+            return false
+        }
+        return headerValue.split(",").filter({ $0.trim().lowercaseString == token }).count > 0
+    }
+    
     public func parseUrlencodedForm() -> [(String, String)] {
         guard let contentTypeHeader = headers["content-type"] else {
             return []

+ 18 - 38
Sources/WebSockets.swift

@@ -15,13 +15,10 @@ public func websocket(
         text: ((WebSocketSession, String) -> Void)?,
     _ binary: ((WebSocketSession, [UInt8]) -> Void)?) -> (HttpRequest -> HttpResponse) {
     return { r in
-        guard let upgradeHeader = r.headers["upgrade"] where
-        upgradeHeader.lowercaseString == "websocket" else {
-            print(r.headers["upgrade"])
+        guard r.hasTokenForHeader("upgrade", token: "websocket") else {
             return .BadRequest(.Text("Invalid value of 'Upgrade' header: \(r.headers["upgrade"])"))
         }
-        guard let connectionHeader = r.headers["connection"] where
-            connectionHeader.lowercaseString == "upgrade" else {
+        guard r.hasTokenForHeader("connection", token: "upgrade") else {
             return .BadRequest(.Text("Invalid value of 'Connection' header: \(r.headers["connection"])"))
         }
         guard let secWebSocketKey = r.headers["sec-websocket-key"] else {
@@ -52,7 +49,7 @@ public func websocket(
 public class WebSocketSession: Hashable, Equatable  {
     
     public enum Error: ErrorType { case UnknownOpCode(String), UnMaskedFrame }
-    public enum OpCode { case Continue, Close, Ping, Pong, Text, Binary }
+    public enum OpCode: UInt8 { case Continue = 0x00, Close = 0x08, Ping = 0x09, Pong = 0x0A, Text = 0x01, Binary = 0x02 }
     
     public class Frame {
         public var opcode = OpCode.Close
@@ -79,7 +76,7 @@ public class WebSocketSession: Hashable, Equatable  {
     }
     
     private func writeFrame(data: ArraySlice<UInt8>, _ op: OpCode, _ fin: Bool = true) {
-        let finAndOpCode = encodeFinAndOpCode(fin, op: op)
+        let finAndOpCode = UInt8(fin ? 0x80 : 0x00) | op.rawValue
         let maskAndLngth = encodeLengthAndMaskFlag(UInt64(data.count), false)
         do {
             try self.socket.writeUInt8([finAndOpCode])
@@ -90,19 +87,6 @@ public class WebSocketSession: Hashable, Equatable  {
         }
     }
     
-    private func encodeFinAndOpCode(fin: Bool, op: OpCode) -> UInt8 {
-        var encodedByte = UInt8(fin ? 0x80 : 0x00);
-        switch op {
-        case .Continue : encodedByte |= 0x00 & 0x0F;
-        case .Text     : encodedByte |= 0x01 & 0x0F;
-        case .Binary   : encodedByte |= 0x02 & 0x0F;
-        case .Close    : encodedByte |= 0x08 & 0x0F;
-        case .Ping     : encodedByte |= 0x09 & 0x0F;
-        case .Pong     : encodedByte |= 0x0A & 0x0F;
-        }
-        return encodedByte
-    }
-    
     private func encodeLengthAndMaskFlag(len: UInt64, _ masked: Bool) -> [UInt8] {
         let encodedLngth = UInt8(masked ? 0x80 : 0x00)
         var encodedBytes = [UInt8]()
@@ -111,18 +95,18 @@ public class WebSocketSession: Hashable, Equatable  {
             encodedBytes.append(encodedLngth | UInt8(len));
         case 126...UInt64(UINT16_MAX):
             encodedBytes.append(encodedLngth | 0x7E);
-            encodedBytes.append(UInt8(len >> 8));
-            encodedBytes.append(UInt8(len & 0xFF));
+            encodedBytes.append(UInt8(len >> 8 & 0xFF));
+            encodedBytes.append(UInt8(len >> 0 & 0xFF));
         default:
             encodedBytes.append(encodedLngth | 0x7F);
-            encodedBytes.append(UInt8(len >> 56) & 0xFF);
-            encodedBytes.append(UInt8(len >> 48) & 0xFF);
-            encodedBytes.append(UInt8(len >> 40) & 0xFF);
-            encodedBytes.append(UInt8(len >> 32) & 0xFF);
-            encodedBytes.append(UInt8(len >> 24) & 0xFF);
-            encodedBytes.append(UInt8(len >> 16) & 0xFF);
-            encodedBytes.append(UInt8(len >> 08) & 0xFF);
-            encodedBytes.append(UInt8(len >> 00) & 0xFF);
+            encodedBytes.append(UInt8(len >> 56 & 0xFF));
+            encodedBytes.append(UInt8(len >> 48 & 0xFF));
+            encodedBytes.append(UInt8(len >> 40 & 0xFF));
+            encodedBytes.append(UInt8(len >> 32 & 0xFF));
+            encodedBytes.append(UInt8(len >> 24 & 0xFF));
+            encodedBytes.append(UInt8(len >> 16 & 0xFF));
+            encodedBytes.append(UInt8(len >> 08 & 0xFF));
+            encodedBytes.append(UInt8(len >> 00 & 0xFF));
         }
         return encodedBytes
     }
@@ -132,17 +116,12 @@ public class WebSocketSession: Hashable, Equatable  {
         let fst = try socket.read()
         frm.fin = fst & 0x80 != 0
         let opc = fst & 0x0F
-        switch opc {
-            case 0x00: frm.opcode = OpCode.Continue
-            case 0x01: frm.opcode = OpCode.Text
-            case 0x02: frm.opcode = OpCode.Binary
-            case 0x08: frm.opcode = OpCode.Close
-            case 0x09: frm.opcode = OpCode.Ping
-            case 0x0A: frm.opcode = OpCode.Pong
+        guard let opcode = OpCode(rawValue: opc) else {
             // "If an unknown opcode is received, the receiving endpoint MUST _Fail the WebSocket Connection_."
             // http://tools.ietf.org/html/rfc6455#section-5.2 ( Page 29 )
-            default  : throw Error.UnknownOpCode("\(opc)")
+            throw Error.UnknownOpCode("\(opc)")
         }
+        frm.opcode = opcode
         let sec = try socket.read()
         let msk = sec & 0x80 != 0
         guard msk else {
@@ -166,6 +145,7 @@ public class WebSocketSession: Hashable, Equatable  {
             let b7 = UInt64(try socket.read())
             len = UInt64(littleEndian: b0 << 54 | b1 << 48 | b2 << 40 | b3 << 32 | b4 << 24 | b5 << 16 | b6 << 8 | b7)
         }
+        print(len)
         let mask = [try socket.read(), try socket.read(), try socket.read(), try socket.read()]
         for i in 0..<len {
             frm.payload.append(try socket.read() ^ mask[Int(i % 4)])