WebSockets.swift 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276
  1. //
  2. // HttpHandlers+WebSockets.swift
  3. // Swifter
  4. //
  5. // Copyright © 2014-2016 Damian Kołakowski. All rights reserved.
  6. //
  7. #if os(Linux)
  8. import Glibc
  9. #else
  10. import Foundation
  11. #endif
  /// Builds an HTTP handler that upgrades the connection to the WebSocket protocol
  /// (RFC 6455) and then dispatches incoming messages to the given callbacks.
  ///
  /// - Parameter text: Called with every complete text message after UTF-8 decoding.
  ///   When `nil`, text messages are still read from the socket but discarded.
  /// - Parameter binary: Called with every complete binary message. When `nil`,
  ///   binary messages are still read from the socket but discarded.
  /// - Returns: A handler that answers `.BadRequest` when the `Upgrade`, `Connection`
  ///   or `Sec-WebSocket-Key` headers are missing/invalid, and otherwise
  ///   `.SwitchProtocols` with a session closure that loops over incoming frames until
  ///   the peer sends Close or an error occurs, then writes a Close frame.
  public func websocket(
      text: ((WebSocketSession, String) -> Void)?,
      _ binary: ((WebSocketSession, [UInt8]) -> Void)?) -> (HttpRequest -> HttpResponse) {
      return { r in
          guard r.hasTokenForHeader("upgrade", token: "websocket") else {
              return .BadRequest(.Text("Invalid value of 'Upgrade' header: \(r.headers["upgrade"])"))
          }
          guard r.hasTokenForHeader("connection", token: "upgrade") else {
              return .BadRequest(.Text("Invalid value of 'Connection' header: \(r.headers["connection"])"))
          }
          guard let secWebSocketKey = r.headers["sec-websocket-key"] else {
              return .BadRequest(.Text("Invalid value of 'Sec-Websocket-Key' header: \(r.headers["sec-websocket-key"])"))
          }
          let protocolSessionClosure: (Socket -> Void) = { socket in
              let session = WebSocketSession(socket)
              // Opcode (.Text or .Binary) of the fragmented message being reassembled.
              // `.Close` is used as the "no fragmented message in progress" sentinel.
              var fragmentedOpCode = WebSocketSession.OpCode.Close
              // Bytes of the fragmented message received so far.
              var payload = [UInt8]()

              // Hands one complete message to the matching user callback, if any.
              func deliverMessage(opcode: WebSocketSession.OpCode, _ bytes: [UInt8]) throws {
                  switch opcode {
                  case .Text:
                      if let handleText = text {
                          // NOTE(review): String(UTF8String:) reads a NUL-terminated C string, so a
                          // text message containing U+0000 is truncated at that byte.
                          var cString = bytes.map { Int8(bitPattern: $0) }
                          cString.append(0)
                          guard let message = String(UTF8String: cString) else {
                              throw WebSocketSession.Error.InvalidUTF8("")
                          }
                          handleText(session, message)
                      }
                  case .Binary:
                      if let handleBinary = binary {
                          handleBinary(session, bytes)
                      }
                  default:
                      break
                  }
              }

              // Handles a Text/Binary frame: delivers it directly when FIN is set,
              // otherwise starts buffering a fragmented message.
              func handleDataFrame(frame: WebSocketSession.Frame) throws {
                  // RFC 6455 5.4: fragments of one message must not be interleaved with a new message.
                  guard fragmentedOpCode == .Close else {
                      throw WebSocketSession.Error.ProtocolError("Continuing fragmented frame cannot have an operation code.")
                  }
                  if frame.fin {
                      try deliverMessage(frame.opcode, frame.payload)
                  } else {
                      payload = frame.payload
                      fragmentedOpCode = frame.opcode
                  }
              }

              // Handles a Continue frame: appends to the buffered message and delivers
              // it once the final (FIN) fragment arrives.
              func handleContinuationFrame(frame: WebSocketSession.Frame) throws {
                  // There is no message to continue: fail the connection.
                  guard fragmentedOpCode != .Close else {
                      throw WebSocketSession.Error.ProtocolError("There is no fragmented message to continue.")
                  }
                  payload.appendContentsOf(frame.payload)
                  if frame.fin {
                      let opcode = fragmentedOpCode
                      let message = payload
                      // Reset the reassembly state before invoking user code.
                      payload = []
                      fragmentedOpCode = .Close
                      try deliverMessage(opcode, message)
                  }
              }

              func handleOperationCode(frame: WebSocketSession.Frame) throws {
                  switch frame.opcode {
                  case .Continue:
                      try handleContinuationFrame(frame)
                  case .Text, .Binary:
                      try handleDataFrame(frame)
                  case .Close:
                      throw WebSocketSession.Control.Close
                  case .Ping:
                      if frame.payload.count > 125 {
                          throw WebSocketSession.Error.ProtocolError("Payload gretter than 125 octets.")
                      }
                      // Answer with a Pong carrying the same application data.
                      session.writeFrame(ArraySlice(frame.payload), .Pong)
                  case .Pong:
                      break
                  }
              }

              do {
                  while true {
                      let frame = try session.readFrame()
                      try handleOperationCode(frame)
                  }
              } catch let error {
                  switch error {
                  case WebSocketSession.Control.Close:
                      // Normal close
                      break
                  case WebSocketSession.Error.UnknownOpCode:
                      print("Unknown Op Code: \(error)")
                  case WebSocketSession.Error.UnMaskedFrame:
                      print("Unmasked frame: \(error)")
                  case WebSocketSession.Error.InvalidUTF8:
                      print("Invalid UTF8 character: \(error)")
                  case WebSocketSession.Error.ProtocolError:
                      print("Protocol error: \(error)")
                  default:
                      print("Unkown error \(error)")
                  }
                  // Whether the loop ended by a peer Close or by an error, answer with a Close frame.
                  session.writeCloseFrame()
              }
          }
          // RFC 6455 4.2.2: accept key = base64(SHA-1(client key + fixed GUID)).
          let secWebSocketAccept = String.toBase64((secWebSocketKey + "258EAFA5-E914-47DA-95CA-C5AB0DC85B11").SHA1())
          let headers = [ "Upgrade": "WebSocket", "Connection": "Upgrade", "Sec-WebSocket-Accept": secWebSocketAccept]
          return HttpResponse.SwitchProtocols(headers, protocolSessionClosure)
      }
  }
  /// A server-side WebSocket connection over a `Socket`: reads (masked) client frames
  /// and writes (unmasked) server frames.
  public class WebSocketSession: Hashable, Equatable {

      /// Failures raised while reading frames; the associated string is a human-readable detail.
      public enum Error: ErrorType { case UnknownOpCode(String), UnMaskedFrame(String), ProtocolError(String), InvalidUTF8(String) }

      /// Frame opcodes defined by RFC 6455, section 5.2.
      public enum OpCode: UInt8 { case Continue = 0x00, Close = 0x08, Ping = 0x09, Pong = 0x0A, Text = 0x01, Binary = 0x02 }

      /// Control-flow signal used to end a session loop when a Close frame is received.
      public enum Control: ErrorType { case Close }

      /// One frame as returned by `readFrame()`; `payload` is already unmasked.
      public class Frame {
          public var opcode = OpCode.Close
          public var fin = false
          public var rsv1: UInt8 = 0
          public var rsv2: UInt8 = 0
          public var rsv3: UInt8 = 0
          public var payload = [UInt8]()
      }

      private let socket: Socket

      /// Set once a Close frame has been written, so the close handshake is sent at most
      /// once even when both the session loop and `deinit` request it.
      private var closeFrameSent = false

      public init(_ socket: Socket) {
          self.socket = socket
      }

      deinit {
          writeCloseFrame()
          socket.shutdwn()
      }

      /// Sends `text` as a single, unfragmented Text frame.
      public func writeText(text: String) -> Void {
          writeFrame(ArraySlice(text.utf8), .Text)
      }

      /// Sends `binary` as a single, unfragmented Binary frame.
      public func writeBinary(binary: [UInt8]) -> Void {
          writeBinary(ArraySlice(binary))
      }

      /// Sends `binary` as a single, unfragmented Binary frame.
      public func writeBinary(binary: ArraySlice<UInt8>) -> Void {
          writeFrame(binary, .Binary)
      }

      /// Writes one frame: the FIN/opcode byte, the length bytes, then `data`.
      /// Server frames are never masked (RFC 6455 5.1). Socket errors are printed, not thrown.
      private func writeFrame(data: ArraySlice<UInt8>, _ op: OpCode, _ fin: Bool = true) {
          let finAndOpCode = UInt8(fin ? 0x80 : 0x00) | op.rawValue
          let maskAndLength = encodeLengthAndMaskFlag(UInt64(data.count), false)
          do {
              try socket.writeUInt8([finAndOpCode])
              try socket.writeUInt8(maskAndLength)
              try socket.writeUInt8(data)
          } catch {
              print(error)
          }
      }

      /// Writes an empty Close frame, unless one has already been written.
      private func writeCloseFrame() {
          guard !closeFrameSent else {
              return
          }
          closeFrameSent = true
          writeFrame(ArraySlice("".utf8), .Close)
      }

      /// Encodes the second header byte (mask bit + 7-bit length) followed by the
      /// 16-bit or 64-bit extended length in network byte order (RFC 6455 5.2).
      private func encodeLengthAndMaskFlag(len: UInt64, _ masked: Bool) -> [UInt8] {
          let maskBit = UInt8(masked ? 0x80 : 0x00)
          switch len {
          case 0...125:
              return [maskBit | UInt8(len)]
          case 126...UInt64(UInt16.max):
              return [maskBit | 0x7E, UInt8(len >> 8 & 0xFF), UInt8(len & 0xFF)]
          default:
              var encoded = [maskBit | 0x7F]
              for shift: UInt64 in [56, 48, 40, 32, 24, 16, 8, 0] {
                  encoded.append(UInt8(len >> shift & 0xFF))
              }
              return encoded
          }
      }

      /// Reads one client frame from the socket and unmasks its payload.
      ///
      /// - Throws: `Error.ProtocolError` when a reserved bit is set, a control frame is
      ///   fragmented or longer than 125 bytes, or a 64-bit length has its most significant
      ///   bit set; `Error.UnknownOpCode` for undefined opcodes; `Error.UnMaskedFrame` for
      ///   unmasked frames; and any error thrown by `socket.read()`.
      public func readFrame() throws -> Frame {
          let frame = Frame()
          let first = try socket.read()
          frame.fin = first & 0x80 != 0
          frame.rsv1 = first & 0x40
          frame.rsv2 = first & 0x20
          frame.rsv3 = first & 0x10
          guard frame.rsv1 == 0 && frame.rsv2 == 0 && frame.rsv3 == 0 else {
              // No extension is negotiated, so any reserved bit is a protocol violation.
              throw Error.ProtocolError("Reserved frame bit has not been negocitated.")
          }
          let opc = first & 0x0F
          guard let opcode = OpCode(rawValue: opc) else {
              // "If an unknown opcode is received, the receiving endpoint MUST _Fail the WebSocket Connection_."
              // http://tools.ietf.org/html/rfc6455#section-5.2 ( Page 29 )
              throw Error.UnknownOpCode("\(opc)")
          }
          let isControlFrame = opcode == .Ping || opcode == .Pong || opcode == .Close
          if isControlFrame && !frame.fin {
              // Control frames must not be fragmented
              // https://tools.ietf.org/html/rfc6455#section-5.5 ( Page 35 )
              throw Error.ProtocolError("Control frames must not be framgemted.")
          }
          frame.opcode = opcode
          let second = try socket.read()
          guard second & 0x80 != 0 else {
              // "...a client MUST mask all frames that it sends to the server."
              // http://tools.ietf.org/html/rfc6455#section-5.1
              throw Error.UnMaskedFrame("A client must mask all frames that it sends to the server.")
          }
          var len = UInt64(second & 0x7F)
          if isControlFrame && len > 125 {
              // RFC 6455 5.5: all control frames MUST have a payload length of 125 bytes or less.
              throw Error.ProtocolError("Control frame payload greater than 125 octets.")
          }
          if len == 0x7E {
              len = try readBigEndianUInt(2)
          } else if len == 0x7F {
              len = try readBigEndianUInt(8)
              // RFC 6455 5.2: the most significant bit of the 64-bit length MUST be 0.
              guard len >> 63 == 0 else {
                  throw Error.ProtocolError("The most significant bit of a 64-bit payload length must be 0.")
              }
          }
          let mask = [try socket.read(), try socket.read(), try socket.read(), try socket.read()]
          for i in 0..<len {
              let maskedByte = try socket.read()
              frame.payload.append(maskedByte ^ mask[Int(i % 4)])
          }
          return frame
      }

      /// Reads `byteCount` bytes from the socket as an unsigned integer in network
      /// (big-endian) byte order. The result is already in host representation.
      private func readBigEndianUInt(byteCount: Int) throws -> UInt64 {
          var value: UInt64 = 0
          for _ in 0..<byteCount {
              let byte = try socket.read()
              value = value << 8 | UInt64(byte)
          }
          return value
      }

      /// Derived from the wrapped socket, consistent with `==` comparing sockets.
      public var hashValue: Int {
          get {
              return socket.hashValue
          }
      }
  }
  /// Equality for `WebSocketSession`: two sessions are equal exactly when their
  /// underlying sockets compare equal with `Socket`'s own `==`. This matches
  /// `hashValue`, which is taken from the socket.
  public func ==(webSocketSession1: WebSocketSession, webSocketSession2: WebSocketSession) -> Bool {
      let lhsSocket = webSocketSession1.socket
      let rhsSocket = webSocketSession2.socket
      return lhsSocket == rhsSocket
  }