HttpHandlers+WebSockets.swift 4.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117
  1. //
  2. // HttpHandlers+WebSockets.swift
  3. // Swifter
  4. //
  5. // Copyright © 2014-2016 Damian Kołakowski. All rights reserved.
  6. //
  7. import Foundation
  extension HttpHandlers {

      /// Returns a handler that performs the server side of the RFC 6455 opening handshake and then
      /// reads WebSocket frames from the upgraded connection.
      ///
      /// - Parameter text: Invoked with every complete text message (fragmented messages are
      ///   reassembled first), decoded with `String.fromUInt8`. `nil` ignores text messages.
      /// - Parameter binary: Invoked with every complete binary message (reassembled likewise).
      ///   `nil` ignores binary messages.
      /// - Returns: A handler answering `.BadRequest` when the `Upgrade`, `Connection` or
      ///   `Sec-WebSocket-Key` headers are invalid, and `.SwitchProtocols` otherwise. The session
      ///   closure stops reading when a Close frame arrives, when a frame cannot be read, or when the
      ///   peer violates the fragmentation rules.
      public class func websocket(text:(String -> Void)?, _ binary:([UInt8] -> Void)?) -> (HttpRequest -> HttpResponse) {
          // True when the comma-separated header `value` contains `token` (which must be lowercase).
          // Header values are case-insensitive and `Connection` is a token list: Firefox, for example,
          // sends `Connection: keep-alive, Upgrade`, which an exact string comparison would reject.
          func header(value: String?, hasToken token: String) -> Bool {
              guard let value = value else { return false }
              return value.componentsSeparatedByString(",").contains { part in
                  part.stringByTrimmingCharactersInSet(NSCharacterSet.whitespaceCharacterSet()).lowercaseString == token
              }
          }
          return { r in
              guard header(r.headers["upgrade"], hasToken: "websocket") else {
                  return .BadRequest(.Text("Invalid value of 'Upgrade' header: \(r.headers["upgrade"] ?? "")"))
              }
              guard header(r.headers["connection"], hasToken: "upgrade") else {
                  return .BadRequest(.Text("Invalid value of 'Connection' header: \(r.headers["connection"] ?? "")"))
              }
              guard let secWebSocketKey = r.headers["sec-websocket-key"] else {
                  return .BadRequest(.Text("Missing 'Sec-WebSocket-Key' header."))
              }
              let protocolSessionClosure: (Socket -> Void) = { socket in
                  let session = WebSocketSession(socket)

                  // Hands a complete message to the matching user callback, if one was supplied.
                  let deliver: (WebSocketSession.OpCode, [UInt8]) -> Void = { opcode, payload in
                      switch opcode {
                      case .TEXT: text?(String.fromUInt8(payload))
                      case .BINARY: binary?(payload)
                      default: break
                      }
                  }

                  // Opcode and accumulated payload of a fragmented message in progress
                  // (RFC 6455 section 5.4). `nil` means no fragmented message is open.
                  var fragmentOpcode: WebSocketSession.OpCode? = nil
                  var fragmentPayload = [UInt8]()

                  while let frame = try? session.readFrame(socket) {
                      switch frame.opcode {
                      case .TEXT, .BINARY:
                          // A new data message may not start while another one is still fragmented.
                          guard fragmentOpcode == nil else { return }
                          if frame.fin {
                              deliver(frame.opcode, frame.payload)
                          } else {
                              fragmentOpcode = frame.opcode
                              fragmentPayload = frame.payload
                          }
                      case .CONTINUE:
                          // A continuation frame without an unfinished data message is a protocol error.
                          guard let opcode = fragmentOpcode else { return }
                          fragmentPayload.appendContentsOf(frame.payload)
                          if frame.fin {
                              deliver(opcode, fragmentPayload)
                              fragmentOpcode = nil
                              fragmentPayload = []
                          }
                      case .CLOSE:
                          // The peer sends no data frames after a Close frame, so stop reading
                          // instead of blocking until the socket fails.
                          return
                      case .PING, .PONG:
                          // Control frames may be interleaved with fragments; they do not disturb
                          // reassembly. TODO: answer PING with PONG once WebSocketSession can write.
                          break
                      }
                  }
              }
              let secWebSocketAccept = String.encodeToBase64((secWebSocketKey + "258EAFA5-E914-47DA-95CA-C5AB0DC85B11").SHA1())
              let headers = [ "Upgrade": "WebSocket", "Connection": "Upgrade", "Sec-WebSocket-Accept": secWebSocketAccept]
              return HttpResponse.SwitchProtocols(headers, protocolSessionClosure)
          }
      }

      /// Server-side reader of RFC 6455 frames on an upgraded connection.
      public class WebSocketSession {
          /// Errors thrown by `readFrame` when the peer violates the framing rules.
          public enum Error: ErrorType { case UnknownOpCode(String), UnMaskedFrame }
          /// Frame opcodes defined by RFC 6455 section 5.2.
          public enum OpCode { case CONTINUE, CLOSE, PING, PONG, TEXT, BINARY }
          /// One decoded (already unmasked) frame.
          public class Frame {
              public var opcode = OpCode.CLOSE
              public var fin = false
              public var payload = [UInt8]()
          }
          private let socket: Socket
          init(_ socket: Socket) {
              self.socket = socket
          }

          /// Not implemented yet: currently a no-op, nothing is sent to the peer.
          public func writeText(text: String) -> Void {
          }

          /// Not implemented yet: currently a no-op, nothing is sent to the peer.
          public func writeBinary(binary: [UInt8]) -> Void {
          }

          /// Reads one complete frame from `socket` and returns it with its payload unmasked.
          ///
          /// - Throws: `Error.UnknownOpCode` for a reserved opcode, `Error.UnMaskedFrame` when the
          ///   MASK bit is clear (clients must mask every frame), and any error thrown by `socket.read()`.
          public func readFrame(socket: Socket) throws -> Frame {
              let frm = Frame()
              let fst = try socket.read()
              frm.fin = (fst & 0x80) != 0
              let opc = fst & 0x0F
              switch opc {
              case 0x00: frm.opcode = OpCode.CONTINUE
              case 0x01: frm.opcode = OpCode.TEXT
              case 0x02: frm.opcode = OpCode.BINARY
              case 0x08: frm.opcode = OpCode.CLOSE
              case 0x09: frm.opcode = OpCode.PING
              case 0x0A: frm.opcode = OpCode.PONG
              // "If an unknown opcode is received, the receiving endpoint MUST _Fail the WebSocket Connection_."
              // http://tools.ietf.org/html/rfc6455#section-5.2 ( Page 29 )
              default : throw Error.UnknownOpCode("\(opc)")
              }
              let sec = try socket.read()
              // The MASK flag is the most significant bit of the second byte; the low 7 bits hold
              // the payload length (or the 126/127 markers for extended lengths).
              let msk = (sec & 0x80) != 0
              guard msk else {
                  // "...a client MUST mask all frames that it sends to the server..."
                  // http://tools.ietf.org/html/rfc6455#section-5.1
                  throw Error.UnMaskedFrame
              }
              var len = UInt64(sec & 0x7F)
              if len == 0x7E {
                  // 126: the real length follows as a 16-bit unsigned integer in network byte order.
                  len = try readBigEndianLength(socket, byteCount: 2)
              } else if len == 0x7F {
                  // 127: the real length follows as a 64-bit unsigned integer in network byte order.
                  len = try readBigEndianLength(socket, byteCount: 8)
              }
              let mask = [try socket.read(), try socket.read(), try socket.read(), try socket.read()]
              for i in 0..<len {
                  frm.payload.append(try socket.read() ^ mask[Int(i % 4)])
              }
              return frm
          }

          /// Reads `byteCount` bytes from `socket` and combines them, most significant byte first.
          private func readBigEndianLength(socket: Socket, byteCount: Int) throws -> UInt64 {
              var value: UInt64 = 0
              for _ in 0..<byteCount {
                  value = (value << 8) | UInt64(try socket.read())
              }
              return value
          }
      }
  }