WebSockets.swift 7.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185
  1. //
  2. // HttpHandlers+WebSockets.swift
  3. // Swifter
  4. //
  5. // Copyright © 2014-2016 Damian Kołakowski. All rights reserved.
  6. //
  7. #if os(Linux)
  8. import Glibc
  9. #else
  10. import Foundation
  11. #endif
  /// Builds a request handler that performs the RFC 6455 opening handshake and then
  /// runs a WebSocket session on the upgraded socket.
  ///
  /// - Parameter text: Invoked for every received text frame, payload decoded with `String.fromUInt8`.
  /// - Parameter binary: Invoked for every received binary frame with the raw payload.
  /// - Returns: A handler answering `.BadRequest` when the handshake headers are missing or
  ///   invalid, and `.SwitchProtocols` (with the computed `Sec-WebSocket-Accept`) otherwise.
  public func websocket(
      text: ((WebSocketSession, String) -> Void)?,
      _ binary: ((WebSocketSession, [UInt8]) -> Void)?) -> (HttpRequest -> HttpResponse) {
      return { r in
          // Unwrap once so error bodies show the raw value instead of "Optional(...)" / "nil".
          let upgradeValue = r.headers["upgrade"] ?? ""
          guard upgradeValue.lowercaseString == "websocket" else {
              return .BadRequest(.Text("Invalid value of 'Upgrade' header: \(upgradeValue)"))
          }
          let connectionValue = r.headers["connection"] ?? ""
          // RFC 6455 §4.2.1: the Connection header must *include* the "Upgrade" token.
          // Browsers may send e.g. "keep-alive, Upgrade", so an exact match is too strict.
          guard headerValue(connectionValue, containsToken: "upgrade") else {
              return .BadRequest(.Text("Invalid value of 'Connection' header: \(connectionValue)"))
          }
          guard let secWebSocketKey = r.headers["sec-websocket-key"] else {
              return .BadRequest(.Text("Missing 'Sec-Websocket-Key' header"))
          }
          let protocolSessionClosure: (Socket -> Void) = { socket in
              let session = WebSocketSession(socket)
              // Any read error (socket failure, unmasked frame, unknown opcode) ends the session.
              while let frame = try? session.readFrame() {
                  switch frame.opcode {
                  case .Text:
                      if let handleText = text {
                          handleText(session, String.fromUInt8(frame.payload))
                      }
                  case .Binary:
                      if let handleBinary = binary {
                          handleBinary(session, frame.payload)
                      }
                  case .Ping:
                      // RFC 6455 §5.5.2: a Ping MUST be answered with a Pong carrying the same payload.
                      session.writeFrame(ArraySlice(frame.payload), .Pong)
                  case .Close:
                      // RFC 6455 §5.5.1: answer a Close with a Close (echoing the status code)
                      // and stop reading from the connection.
                      session.writeFrame(ArraySlice(frame.payload.prefix(2)), .Close)
                      return
                  case .Continue, .Pong:
                      // NOTE(review): continuation frames of fragmented messages are still ignored.
                      break
                  }
              }
          }
          let secWebSocketAccept = String.toBase64((secWebSocketKey + "258EAFA5-E914-47DA-95CA-C5AB0DC85B11").SHA1())
          let headers = [ "Upgrade": "WebSocket", "Connection": "Upgrade", "Sec-WebSocket-Accept": secWebSocketAccept]
          return HttpResponse.SwitchProtocols(headers, protocolSessionClosure)
      }
  }

  /// Returns `true` when the comma-separated header `value` lists `token`.
  /// Comparison is case-insensitive and ignores spaces/tabs around each element.
  private func headerValue(value: String, containsToken token: String) -> Bool {
      let wanted = token.lowercaseString.characters
      for element in value.lowercaseString.characters.split(",") {
          let stripped = element.filter { $0 != " " && $0 != "\t" }
          if stripped.elementsEqual(wanted) {
              return true
          }
      }
      return false
  }
  /// One WebSocket connection on top of a raw `Socket`: reads client frames and
  /// writes unmasked server frames (RFC 6455 §5).
  public class WebSocketSession: Hashable, Equatable {

      /// Failures raised while decoding an incoming frame.
      public enum Error: ErrorType { case UnknownOpCode(String), UnMaskedFrame }

      /// Frame opcodes defined by RFC 6455 §5.2.
      public enum OpCode { case Continue, Close, Ping, Pong, Text, Binary }

      /// A single decoded frame: its opcode, FIN flag and unmasked payload.
      public class Frame {
          public var opcode = OpCode.Close
          public var fin = false
          public var payload = [UInt8]()
      }

      private let socket: Socket

      public init(_ socket: Socket) {
          self.socket = socket
      }

      /// Sends `text` as a single, final UTF-8 text frame.
      public func writeText(text: String) -> Void {
          self.writeFrame(ArraySlice(text.utf8), OpCode.Text)
      }

      /// Sends `binary` as a single, final binary frame.
      public func writeBinary(binary: [UInt8]) -> Void {
          self.writeBinary(ArraySlice(binary))
      }

      /// Sends `binary` as a single, final binary frame.
      public func writeBinary(binary: ArraySlice<UInt8>) -> Void {
          self.writeFrame(binary, OpCode.Binary)
      }

      /// Writes one frame (header + payload). Server frames are never masked (RFC 6455 §5.1).
      /// Write failures are printed rather than propagated, keeping the public writers non-throwing.
      private func writeFrame(data: ArraySlice<UInt8>, _ op: OpCode, _ fin: Bool = true) {
          let headerByte = encodeFinAndOpCode(fin, op: op)
          let lengthBytes = encodeLengthAndMaskFlag(UInt64(data.count), false)
          do {
              try self.socket.writeUInt8([headerByte])
              try self.socket.writeUInt8(lengthBytes)
              try self.socket.writeUInt8(data)
          } catch {
              print(error)
          }
      }

      /// First header byte: FIN bit (0x80) combined with the 4-bit opcode.
      private func encodeFinAndOpCode(fin: Bool, op: OpCode) -> UInt8 {
          let finBit: UInt8 = fin ? 0x80 : 0x00
          let opBits: UInt8
          switch op {
          case .Continue: opBits = 0x00
          case .Text:     opBits = 0x01
          case .Binary:   opBits = 0x02
          case .Close:    opBits = 0x08
          case .Ping:     opBits = 0x09
          case .Pong:     opBits = 0x0A
          }
          return finBit | opBits
      }

      /// Second header byte plus any extended payload length, in network (big-endian) order.
      /// Lengths up to 125 fit in 7 bits; up to 65535 use marker 126 + 2 bytes; larger use
      /// marker 127 + 8 bytes.
      private func encodeLengthAndMaskFlag(len: UInt64, _ masked: Bool) -> [UInt8] {
          let maskBit = UInt8(masked ? 0x80 : 0x00)
          switch len {
          case 0...125:
              return [maskBit | UInt8(len)]
          case 126...UInt64(UInt16.max):
              return [maskBit | 0x7E, UInt8((len >> 8) & 0xFF), UInt8(len & 0xFF)]
          default:
              // Mask *before* converting: `UInt8(x)` traps when x > 255, which previously
              // crashed on every frame longer than 65535 bytes.
              var encodedBytes: [UInt8] = [maskBit | 0x7F]
              let shifts: [UInt64] = [56, 48, 40, 32, 24, 16, 8, 0]
              for shift in shifts {
                  encodedBytes.append(UInt8((len >> shift) & 0xFF))
              }
              return encodedBytes
          }
      }

      /// Reads one complete frame from the socket and unmasks its payload.
      ///
      /// - Throws: `Error.UnknownOpCode` for an opcode outside RFC 6455 §5.2,
      ///   `Error.UnMaskedFrame` when the client did not mask the frame, and any error
      ///   thrown by `socket.read()`.
      public func readFrame() throws -> Frame {
          let frm = Frame()
          let fst = try socket.read()
          frm.fin = fst & 0x80 != 0
          let opc = fst & 0x0F
          switch opc {
          case 0x00: frm.opcode = OpCode.Continue
          case 0x01: frm.opcode = OpCode.Text
          case 0x02: frm.opcode = OpCode.Binary
          case 0x08: frm.opcode = OpCode.Close
          case 0x09: frm.opcode = OpCode.Ping
          case 0x0A: frm.opcode = OpCode.Pong
          // "If an unknown opcode is received, the receiving endpoint MUST _Fail the WebSocket Connection_."
          // http://tools.ietf.org/html/rfc6455#section-5.2 ( Page 29 )
          default : throw Error.UnknownOpCode("\(opc)")
          }
          let sec = try socket.read()
          let msk = sec & 0x80 != 0
          guard msk else {
              // "...a client MUST mask all frames that it sends to the server..."
              // http://tools.ietf.org/html/rfc6455#section-5.1
              throw Error.UnMaskedFrame
          }
          var len = UInt64(sec & 0x7F)
          if len == 0x7E {
              len = try readNetworkOrderLength(2)
          } else if len == 0x7F {
              len = try readNetworkOrderLength(8)
          }
          let mask = [try socket.read(), try socket.read(), try socket.read(), try socket.read()]
          for i in 0..<len {
              frm.payload.append(try socket.read() ^ mask[Int(i % 4)])
          }
          return frm
      }

      /// Reads `byteCount` bytes and assembles them most-significant byte first, as the
      /// extended payload length is sent in network byte order (RFC 6455 §5.2). The value
      /// is built by shifting, so no host-endianness conversion is needed (the previous
      /// `UInt64(littleEndian:)` wrapping was only correct on little-endian hosts, and the
      /// 8-byte case shifted the top byte by 54 instead of 56).
      private func readNetworkOrderLength(byteCount: Int) throws -> UInt64 {
          var value: UInt64 = 0
          for _ in 0..<byteCount {
              let byte = try socket.read()
              value = (value << 8) | UInt64(byte)
          }
          return value
      }

      /// Sessions hash by their underlying socket, consistent with `==`.
      public var hashValue: Int {
          get {
              return socket.hashValue
          }
      }
  }
  /// Two sessions are equal when they wrap the same underlying socket.
  public func ==(webSocketSession1: WebSocketSession, webSocketSession2: WebSocketSession) -> Bool {
      let (lhsSocket, rhsSocket) = (webSocketSession1.socket, webSocketSession2.socket)
      return lhsSocket == rhsSocket
  }