//
// Socket+TLS.swift
// Swifter
//
// Copyright © 2016 Damian Kołakowski. All rights reserved.
//
  7. import Foundation
  /// Errors thrown by the TLS record and handshake parsers in this file.
  ///
  /// Each case carries a human-readable description of what was wrong with
  /// the incoming bytes.
  public enum TLSError: Error {
      case UnknownTLSRecordType(String), UnknownHandshakeType(String), InvalidData(String)
  }
  /// A value that can report its big-endian (network byte order) representation.
  ///
  /// This is the constraint on `nextGeneric2(_:)`, which applies `bigEndian` to a
  /// value assembled from raw socket bytes.
  public protocol HasBigEndian {
      /// This value with its bytes arranged in big-endian order.
      var bigEndian: Self { get }
  }
  /// Reads exactly `n` bytes from `socket`, issuing one `read()` call per byte.
  ///
  /// - Parameters:
  ///   - socket: The socket to pull bytes from.
  ///   - n: How many bytes to read; a negative value traps when the range is formed.
  /// - Returns: The bytes in the order they were read.
  /// - Throws: Whatever `socket.read()` throws. Bytes read before the failure are discarded.
  public func nextBytes(_ socket: Socket, _ n: Int) throws -> [UInt8] {
      return try (0..<n).map { _ in try socket.read() }
  }
  /// Reads `MemoryLayout<T>.size` bytes from `socket` and decodes them as a
  /// big-endian (network byte order) value of type `T`.
  ///
  /// The bytes are copied into a `T` with `UnsafeRawBufferPointer.load(as:)`.
  /// This replaces reading the `[UInt8]` storage through a `UnsafePointer<T>`
  /// obtained via `OpaquePointer`, which accesses memory bound to `UInt8` as a
  /// different type.
  ///
  /// - Note: Only meaningful for trivial (POD) types such as fixed-width
  ///   integers; `T` must not contain references.
  /// - Throws: Whatever `socket.read()` throws.
  public func nextGeneric2<T: HasBigEndian>(_ socket: Socket) throws -> T {
      let bytes = try nextBytes(socket, MemoryLayout<T>.size)
      // load(as:) performs a typed copy out of raw memory without rebinding it.
      let raw = bytes.withUnsafeBytes { $0.load(as: T.self) }
      return raw.bigEndian
  }
  /// Reads two bytes from `socket` and decodes them as a big-endian
  /// (most significant byte first) `UInt16`.
  ///
  /// The value is assembled with shifts, so the result does not depend on the
  /// host's byte order and no raw-pointer type punning is involved.
  ///
  /// - Throws: Whatever `socket.read()` throws.
  public func nextUInt16(_ socket: Socket) throws -> UInt16 {
      let high = try UInt16(socket.read())
      let low = try UInt16(socket.read())
      return high << 8 | low
  }
  /// A forward-only cursor over a slice of bytes, used by `readClientHello`
  /// to walk a handshake body field by field.
  ///
  /// Every accessor consumes input. A multi-byte read that runs out of input
  /// returns `nil`, and any bytes it consumed before reaching the end are not
  /// put back.
  public struct DataIterator {

      private var iterator: IndexingIterator<ArraySlice<UInt8>>

      /// Creates a cursor positioned at the first element of `slice`.
      public init(_ slice: ArraySlice<UInt8>) {
          self.iterator = slice.makeIterator()
      }

      /// Consumes the next `n` bytes.
      ///
      /// - Parameter n: Number of bytes to take.
      /// - Returns: The bytes in order (empty when `n == 0`), or `nil` when
      ///   fewer than `n` bytes remain or `n` is negative.
      public mutating func next(_ n: Int) -> [UInt8]? {
          // Reject negative counts up front; forming `0..<n` would trap on them.
          guard n >= 0 else { return nil }
          var result = [UInt8]()
          for _ in 0..<n {
              guard let byte = iterator.next() else { return nil }
              result.append(byte)
          }
          return result
      }

      /// Consumes a single byte, or returns `nil` at end of input.
      public mutating func nextByte() -> UInt8? {
          return iterator.next()
      }

      /// Consumes two bytes and decodes them as a big-endian `UInt16`.
      ///
      /// Decoding uses shifts rather than reinterpreting array memory, so it is
      /// independent of host byte order and buffer alignment.
      /// - Returns: The decoded value, or `nil` if fewer than two bytes remain.
      public mutating func nextUInt16() -> UInt16? {
          guard let high = nextByte(), let low = nextByte() else { return nil }
          return UInt16(high) << 8 | UInt16(low)
      }
  }
  extension Socket {

      /// Accepts a client connection and parses the first TLS record it sends.
      ///
      /// This is an incomplete TLS bootstrap:
      /// - For a handshake record carrying a ClientHello, the hello is parsed
      ///   and the result discarded.
      /// - Other handshake types print `"default"`.
      /// - Other record types print `"TODO"`.
      ///
      /// Nothing is sent back to the client. The socket is returned as accepted.
      ///
      /// - Returns: The accepted client socket.
      /// - Throws: `TLSError` when the record or handshake header is malformed,
      ///   plus anything `acceptClientSocket()` or `read()` throw.
      public func acceptTLSClientSocket() throws -> Socket {
          let socket = try self.acceptClientSocket()
          // NOTE(review): if any parse below throws, the accepted socket is not closed here.
          let record = try readRecord(socket)
          switch record.type {
          case .HANDSHAKE:
              let handshake = try readHandshake(socket)
              switch handshake.type {
              case .CLIENT_HELLO:
                  _ = try readClientHello(handshake.message)
              default:
                  print("default")
              }
              print("handshake")
          case .CHANGE_CIPHER_SPEC, .ALERT, .APPLICATION_DATA:
              print("TODO")
          }
          return socket
      }

      /// Header of a TLS record: content type, protocol version and payload length.
      public struct Record {
          /// Record content types, keyed by their wire value.
          public enum Typo: UInt8 { case CHANGE_CIPHER_SPEC = 20, ALERT = 21, HANDSHAKE = 22, APPLICATION_DATA = 23 }
          public var type: Typo
          /// Version field as read (decoded big-endian).
          public var version: UInt16
          /// Length field as read (decoded big-endian).
          public var length: UInt16
      }

      /// Reads a 5-byte TLS record header from `socket`.
      ///
      /// The header is a 1-byte content type followed by big-endian 16-bit version
      /// and length fields. Only the header is consumed; the payload is left unread.
      ///
      /// - Throws: `TLSError.UnknownTLSRecordType` for an unrecognized content type,
      ///   plus anything `socket.read()` throws.
      public func readRecord(_ socket: Socket) throws -> Record {
          let type = try socket.read()
          guard let validType = Record.Typo(rawValue: type) else {
              throw TLSError.UnknownTLSRecordType("Unknown record type: \(type)")
          }
          let version = try nextUInt16(socket)
          let length = try nextUInt16(socket)
          return Record(type: validType, version: version, length: length)
      }

      /// A handshake message: its type and raw body bytes.
      public struct Handshake {
          /// Handshake message types, keyed by their wire value.
          public enum Typo: UInt8 {
              case HELLO_REQUEST = 0, CLIENT_HELLO = 1, SERVER_HELLO = 2, FINISHED = 20
              case CERTIFICATE = 11, SERVER_KEY_EXCHANGE = 12, CERTIFICATE_REQUEST = 13
              case SERVER_DONE = 14, CERTIFICATE_VERIFY = 15, CLIENT_KEY_EXCHANGE = 16
          }
          public var type = Typo.HELLO_REQUEST
          public var message = [UInt8]()
      }

      /// Reads one handshake message from `socket`.
      ///
      /// The message is a 1-byte type, a 24-bit big-endian body length, then
      /// exactly that many body bytes.
      ///
      /// - Throws: `TLSError.UnknownHandshakeType` for an unrecognized type byte,
      ///   plus anything `socket.read()` throws.
      public func readHandshake(_ socket: Socket) throws -> Handshake {
          let type = try socket.read()
          guard let validType = Handshake.Typo(rawValue: type) else {
              throw TLSError.UnknownHandshakeType("Unknown handshake type: \(type)")
          }
          // 24-bit big-endian length, assembled with shifts (host-endian independent).
          let lengthBytes = try nextBytes(socket, 3)
          let length = Int(lengthBytes[0]) << 16 | Int(lengthBytes[1]) << 8 | Int(lengthBytes[2])
          // NOTE(review): length comes from the peer and can be up to 2^24 - 1 bytes; it is not bounded here.
          var handshake = Handshake()
          handshake.type = validType
          handshake.message = try nextBytes(socket, length)
          return handshake
      }

      /// Decoded fields of a ClientHello handshake body.
      public struct ClientHello {
          public var version: UInt16 = 0
          /// The 32 bytes that follow the version field.
          public var random = [UInt8]()
          public var sessionId = [UInt8]()
          public var cipherSuites = [UInt16]()
          public var compressionMethods = [UInt8]()
          /// Extensions in wire order as (extension id, raw extension body).
          public var extensions = [(id: UInt16, data: [UInt8])]()
      }

      /// Parses the body of a ClientHello handshake message.
      ///
      /// Fields, in order:
      /// - version (2 bytes)
      /// - random (32 bytes)
      /// - session id, with a 1-byte length prefix
      /// - cipher suite list, with a 2-byte byte-length prefix (2 bytes per suite)
      /// - compression methods, with a 1-byte length prefix
      /// - an optional extensions block, with a 2-byte length prefix, holding
      ///   (id, length, data) entries
      ///
      /// - Parameter data: The handshake body, e.g. `Handshake.message`.
      /// - Returns: The decoded fields. `extensions` is empty when the body ends
      ///   right after the compression methods.
      /// - Throws: `TLSError.InvalidData` when a field is missing or truncated.
      public func readClientHello(_ data: [UInt8]) throws -> ClientHello {
          var iterator = DataIterator(data[0..<data.count])
          guard let version = iterator.nextUInt16() else { throw TLSError.InvalidData("No version field.") }
          guard let random = iterator.next(32) else { throw TLSError.InvalidData("No random field.") }
          guard let sessionIdLen = iterator.nextByte() else { throw TLSError.InvalidData("No Session Id Length field.") }
          guard let sessionId = iterator.next(Int(sessionIdLen)) else { throw TLSError.InvalidData("No Session Id field.") }

          // This prefix is the byte length of the suite list, not a count; each suite is 2 bytes.
          guard let cipherSuitesLength = iterator.nextUInt16() else {
              throw TLSError.InvalidData("No Cipher Suites Count field.")
          }
          guard cipherSuitesLength % 2 == 0 else {
              throw TLSError.InvalidData("Cipher Suites length \(cipherSuitesLength) is not a multiple of 2.")
          }
          var cipherSuites = [UInt16]()
          for _ in 0..<cipherSuitesLength / 2 {
              guard let cipherSuiteId = iterator.nextUInt16() else { throw TLSError.InvalidData("No Cipher Suite Id field.") }
              cipherSuites.append(cipherSuiteId)
          }

          guard let compressionMethodsCount = iterator.nextByte() else {
              throw TLSError.InvalidData("No Compression Methods Count field.")
          }
          guard let compressionMethods = iterator.next(Int(compressionMethodsCount)) else {
              throw TLSError.InvalidData("No Compression Method field.")
          }

          // RFC 5246 §7.4.1.2: the extensions block is optional. Its presence is
          // detected by whether any bytes follow the compression methods.
          var extensions = [(id: UInt16, data: [UInt8])]()
          if let lengthHigh = iterator.nextByte() {
              guard let lengthLow = iterator.nextByte() else {
                  throw TLSError.InvalidData("No Extension Length field.")
              }
              let extensionsLength = Int(lengthHigh) << 8 | Int(lengthLow)
              guard let extensionsData = iterator.next(extensionsLength) else {
                  throw TLSError.InvalidData("No Extension Data field.")
              }
              var extensionDataIterator = DataIterator(extensionsData[0..<extensionsData.count])
              while let extensionId = extensionDataIterator.nextUInt16() {
                  guard let extensionDataLength = extensionDataIterator.nextUInt16() else {
                      throw TLSError.InvalidData("No Extension Data Length field for extension \(extensionId).")
                  }
                  guard let extensionData = extensionDataIterator.next(Int(extensionDataLength)) else {
                      throw TLSError.InvalidData("No Extension Data field for extension \(extensionId).")
                  }
                  extensions.append((id: extensionId, data: extensionData))
              }
          }

          return ClientHello(version: version, random: random, sessionId: sessionId,
                             cipherSuites: cipherSuites, compressionMethods: compressionMethods,
                             extensions: extensions)
      }
  }