1
0

WebSockets.swift 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297
  1. //
  2. // HttpHandlers+WebSockets.swift
  3. // Swifter
  4. //
  5. // Copyright © 2014-2016 Damian Kołakowski. All rights reserved.
  6. //
  7. import Foundation
  8. @available(*, deprecated, message: "Use websocket(text:binary:pong:connected:disconnected:) instead.")
  9. public func websocket(_ text: @escaping (WebSocketSession, String) -> Void,
  10. _ binary: @escaping (WebSocketSession, [UInt8]) -> Void,
  11. _ pong: @escaping (WebSocketSession, [UInt8]) -> Void) -> ((HttpRequest) -> HttpResponse) {
  12. return websocket(text: text, binary: binary, pong: pong)
  13. }
  14. // swiftlint:disable function_body_length
  15. public func websocket(
  16. text: ((WebSocketSession, String) -> Void)? = nil,
  17. binary: ((WebSocketSession, [UInt8]) -> Void)? = nil,
  18. pong: ((WebSocketSession, [UInt8]) -> Void)? = nil,
  19. connected: ((WebSocketSession) -> Void)? = nil,
  20. disconnected: ((WebSocketSession) -> Void)? = nil) -> ((HttpRequest) -> HttpResponse) {
  21. return { request in
  22. guard request.hasTokenForHeader("upgrade", token: "websocket") else {
  23. return .badRequest(.text("Invalid value of 'Upgrade' header: \(request.headers["upgrade"] ?? "unknown")"))
  24. }
  25. guard request.hasTokenForHeader("connection", token: "upgrade") else {
  26. return .badRequest(.text("Invalid value of 'Connection' header: \(request.headers["connection"] ?? "unknown")"))
  27. }
  28. guard let secWebSocketKey = request.headers["sec-websocket-key"] else {
  29. return .badRequest(.text("Invalid value of 'Sec-Websocket-Key' header: \(request.headers["sec-websocket-key"] ?? "unknown")"))
  30. }
  31. let protocolSessionClosure: ((Socket) -> Void) = { socket in
  32. let session = WebSocketSession(socket)
  33. var fragmentedOpCode = WebSocketSession.OpCode.close
  34. var payload = [UInt8]() // Used for fragmented frames.
  35. func handleTextPayload(_ frame: WebSocketSession.Frame) throws {
  36. if let handleText = text {
  37. if frame.fin {
  38. if payload.count > 0 {
  39. throw WebSocketSession.WsError.protocolError("Continuing fragmented frame cannot have an operation code.")
  40. }
  41. var textFramePayload = frame.payload.map { Int8(bitPattern: $0) }
  42. textFramePayload.append(0)
  43. if let text = String(validatingUTF8: textFramePayload) {
  44. handleText(session, text)
  45. } else {
  46. throw WebSocketSession.WsError.invalidUTF8("")
  47. }
  48. } else {
  49. payload.append(contentsOf: frame.payload)
  50. fragmentedOpCode = .text
  51. }
  52. }
  53. }
  54. func handleBinaryPayload(_ frame: WebSocketSession.Frame) throws {
  55. if let handleBinary = binary {
  56. if frame.fin {
  57. if payload.count > 0 {
  58. throw WebSocketSession.WsError.protocolError("Continuing fragmented frame cannot have an operation code.")
  59. }
  60. handleBinary(session, frame.payload)
  61. } else {
  62. payload.append(contentsOf: frame.payload)
  63. fragmentedOpCode = .binary
  64. }
  65. }
  66. }
  67. func handleOperationCode(_ frame: WebSocketSession.Frame) throws {
  68. switch frame.opcode {
  69. case .continue:
  70. // There is no message to continue, failed immediatelly.
  71. if fragmentedOpCode == .close {
  72. socket.close()
  73. }
  74. frame.opcode = fragmentedOpCode
  75. if frame.fin {
  76. payload.append(contentsOf: frame.payload)
  77. frame.payload = payload
  78. // Clean the buffer.
  79. payload = []
  80. // Reset the OpCode.
  81. fragmentedOpCode = WebSocketSession.OpCode.close
  82. }
  83. try handleOperationCode(frame)
  84. case .text:
  85. try handleTextPayload(frame)
  86. case .binary:
  87. try handleBinaryPayload(frame)
  88. case .close:
  89. throw WebSocketSession.Control.close
  90. case .ping:
  91. if frame.payload.count > 125 {
  92. throw WebSocketSession.WsError.protocolError("Payload gretter than 125 octets.")
  93. } else {
  94. session.writeFrame(ArraySlice(frame.payload), .pong)
  95. }
  96. case .pong:
  97. if let handlePong = pong {
  98. handlePong(session, frame.payload)
  99. }
  100. }
  101. }
  102. func read() throws {
  103. while true {
  104. let frame = try session.readFrame()
  105. try handleOperationCode(frame)
  106. }
  107. }
  108. connected?(session)
  109. do {
  110. try read()
  111. } catch let error {
  112. switch error {
  113. case WebSocketSession.Control.close:
  114. // Normal close
  115. break
  116. case WebSocketSession.WsError.unknownOpCode:
  117. print("Unknown Op Code: \(error)")
  118. case WebSocketSession.WsError.unMaskedFrame:
  119. print("Unmasked frame: \(error)")
  120. case WebSocketSession.WsError.invalidUTF8:
  121. print("Invalid UTF8 character: \(error)")
  122. case WebSocketSession.WsError.protocolError:
  123. print("Protocol error: \(error)")
  124. default:
  125. print("Unkown error \(error)")
  126. }
  127. // If an error occurs, send the close handshake.
  128. session.writeCloseFrame()
  129. }
  130. disconnected?(session)
  131. }
  132. guard let secWebSocketAccept = String.toBase64((secWebSocketKey + "258EAFA5-E914-47DA-95CA-C5AB0DC85B11").sha1()) else {
  133. return HttpResponse.internalServerError
  134. }
  135. let headers = ["Upgrade": "WebSocket", "Connection": "Upgrade", "Sec-WebSocket-Accept": secWebSocketAccept]
  136. return HttpResponse.switchProtocols(headers, protocolSessionClosure)
  137. }
  138. }
  139. public class WebSocketSession: Hashable, Equatable {
  140. public enum WsError: Error { case unknownOpCode(String), unMaskedFrame(String), protocolError(String), invalidUTF8(String) }
  141. public enum OpCode: UInt8 { case `continue` = 0x00, close = 0x08, ping = 0x09, pong = 0x0A, text = 0x01, binary = 0x02 }
  142. public enum Control: Error { case close }
  143. public class Frame {
  144. public var opcode = OpCode.close
  145. public var fin = false
  146. public var rsv1: UInt8 = 0
  147. public var rsv2: UInt8 = 0
  148. public var rsv3: UInt8 = 0
  149. public var payload = [UInt8]()
  150. }
  151. public let socket: Socket
  152. public init(_ socket: Socket) {
  153. self.socket = socket
  154. }
  155. deinit {
  156. writeCloseFrame()
  157. socket.close()
  158. }
  159. public func writeText(_ text: String) {
  160. self.writeFrame(ArraySlice(text.utf8), OpCode.text)
  161. }
  162. public func writeBinary(_ binary: [UInt8]) {
  163. self.writeBinary(ArraySlice(binary))
  164. }
  165. public func writeBinary(_ binary: ArraySlice<UInt8>) {
  166. self.writeFrame(binary, OpCode.binary)
  167. }
  168. public func writeFrame(_ data: ArraySlice<UInt8>, _ op: OpCode, _ fin: Bool = true) {
  169. let finAndOpCode = UInt8(fin ? 0x80 : 0x00) | op.rawValue
  170. let maskAndLngth = encodeLengthAndMaskFlag(UInt64(data.count), false)
  171. do {
  172. try self.socket.writeUInt8([finAndOpCode])
  173. try self.socket.writeUInt8(maskAndLngth)
  174. try self.socket.writeUInt8(data)
  175. } catch {
  176. print(error)
  177. }
  178. }
  179. public func writeCloseFrame() {
  180. writeFrame(ArraySlice("".utf8), .close)
  181. }
  182. private func encodeLengthAndMaskFlag(_ len: UInt64, _ masked: Bool) -> [UInt8] {
  183. let encodedLngth = UInt8(masked ? 0x80 : 0x00)
  184. var encodedBytes = [UInt8]()
  185. switch len {
  186. case 0...125:
  187. encodedBytes.append(encodedLngth | UInt8(len))
  188. case 126...UInt64(UINT16_MAX):
  189. encodedBytes.append(encodedLngth | 0x7E)
  190. encodedBytes.append(UInt8(len >> 8 & 0xFF))
  191. encodedBytes.append(UInt8(len >> 0 & 0xFF))
  192. default:
  193. encodedBytes.append(encodedLngth | 0x7F)
  194. encodedBytes.append(UInt8(len >> 56 & 0xFF))
  195. encodedBytes.append(UInt8(len >> 48 & 0xFF))
  196. encodedBytes.append(UInt8(len >> 40 & 0xFF))
  197. encodedBytes.append(UInt8(len >> 32 & 0xFF))
  198. encodedBytes.append(UInt8(len >> 24 & 0xFF))
  199. encodedBytes.append(UInt8(len >> 16 & 0xFF))
  200. encodedBytes.append(UInt8(len >> 08 & 0xFF))
  201. encodedBytes.append(UInt8(len >> 00 & 0xFF))
  202. }
  203. return encodedBytes
  204. }
  205. // swiftlint:disable function_body_length
  206. public func readFrame() throws -> Frame {
  207. let frm = Frame()
  208. let fst = try socket.read()
  209. frm.fin = fst & 0x80 != 0
  210. frm.rsv1 = fst & 0x40
  211. frm.rsv2 = fst & 0x20
  212. frm.rsv3 = fst & 0x10
  213. guard frm.rsv1 == 0 && frm.rsv2 == 0 && frm.rsv3 == 0
  214. else {
  215. throw WsError.protocolError("Reserved frame bit has not been negocitated.")
  216. }
  217. let opc = fst & 0x0F
  218. guard let opcode = OpCode(rawValue: opc) else {
  219. // "If an unknown opcode is received, the receiving endpoint MUST _Fail the WebSocket Connection_."
  220. // http://tools.ietf.org/html/rfc6455#section-5.2 ( Page 29 )
  221. throw WsError.unknownOpCode("\(opc)")
  222. }
  223. if frm.fin == false {
  224. switch opcode {
  225. case .ping, .pong, .close:
  226. // Control frames must not be fragmented
  227. // https://tools.ietf.org/html/rfc6455#section-5.5 ( Page 35 )
  228. throw WsError.protocolError("Control frames must not be fragmented.")
  229. default:
  230. break
  231. }
  232. }
  233. frm.opcode = opcode
  234. let sec = try socket.read()
  235. let msk = sec & 0x80 != 0
  236. guard msk else {
  237. // "...a client MUST mask all frames that it sends to the server."
  238. // http://tools.ietf.org/html/rfc6455#section-5.1
  239. throw WsError.unMaskedFrame("A client must mask all frames that it sends to the server.")
  240. }
  241. var len = UInt64(sec & 0x7F)
  242. if len == 0x7E {
  243. let b0 = UInt64(try socket.read()) << 8
  244. let b1 = UInt64(try socket.read())
  245. len = UInt64(littleEndian: b0 | b1)
  246. } else if len == 0x7F {
  247. let b0 = UInt64(try socket.read()) << 54
  248. let b1 = UInt64(try socket.read()) << 48
  249. let b2 = UInt64(try socket.read()) << 40
  250. let b3 = UInt64(try socket.read()) << 32
  251. let b4 = UInt64(try socket.read()) << 24
  252. let b5 = UInt64(try socket.read()) << 16
  253. let b6 = UInt64(try socket.read()) << 8
  254. let b7 = UInt64(try socket.read())
  255. len = UInt64(littleEndian: b0 | b1 | b2 | b3 | b4 | b5 | b6 | b7)
  256. }
  257. let mask = [try socket.read(), try socket.read(), try socket.read(), try socket.read()]
  258. //Read payload all at once, then apply mask (calling `socket.read` byte-by-byte is super slow).
  259. frm.payload = try socket.read(length: Int(len))
  260. for index in 0..<len {
  261. frm.payload[Int(index)] ^= mask[Int(index % 4)]
  262. }
  263. return frm
  264. }
  265. public func hash(into hasher: inout Hasher) {
  266. hasher.combine(socket)
  267. }
  268. }
  269. public func == (webSocketSession1: WebSocketSession, webSocketSession2: WebSocketSession) -> Bool {
  270. return webSocketSession1.socket == webSocketSession2.socket
  271. }