Procházet zdrojové kódy

SQLite statements uses binding to pass the values.

Damian Kołakowski před 10 roky
rodič
revize
26b7fb54d6

+ 11 - 14
Sources/Reflection.swift

@@ -48,17 +48,17 @@ public extension DatabaseReflectionProtocol {
         return ("\(mirror.subjectType)", fields)
     }
     
-    public func schemeWithValuesAsString() -> (String, [String: String?]) {
+    public func schemeWithValuesAsString() -> (String, [(String, String?)]) {
         let (name, fields) = schemeWithValuesMethod2()
-        var map = [String: String?]()
+        var map = [(String, String?)]()
         for (key, value) in fields {
             // TODO - Replace this by extending all supported types by a protocol.
             // Example: 'extenstion Int: DatabaseConvertible { convert() -> something ( not necessary String type ) }'
-            if let intValue    = value as? Int    { map[key] = String(intValue) }
-            if let int32Value  = value as? Int32  { map[key] = String(int32Value) }
-            if let int64Value  = value as? Int64  { map[key] = String(int64Value) }
-            if let doubleValue = value as? Double { map[key] = String(doubleValue) }
-            if let stringValue = value as? String { map[key] = stringValue }
+            if let intValue    = value as? Int    { map.append((key, String(intValue))) }
+            if let int32Value  = value as? Int32  { map.append((key, String(int32Value))) }
+            if let int64Value  = value as? Int64  { map.append((key, String(int64Value))) }
+            if let doubleValue = value as? Double { map.append((key, String(doubleValue))) }
+            if let stringValue = value as? String { map.append((key, stringValue)) }
         }
         return (name, map)
     }
@@ -86,13 +86,10 @@ public extension DatabaseReflectionProtocol {
             throw SQLiteError.OpenFailed("Database connection is not opened.")
         }
         let (name, fields) = schemeWithValuesAsString()
-        let create = "CREATE TABLE IF NOT EXISTS \(name) (" + fields.keys.map { "\($0) TEXT" }.joinWithSeparator(", ")  + ");"
-        try database.exec(create)
-        // TODO - Replace this with the binding to avoid SQL injection.
-        let ordered = fields.keys.reduce([(String, String)]()) { $0 + [($1, "\"\(fields[$1])\"")] }
-        let names = ordered.map({ $0.0 }).joinWithSeparator(", ")
-        let values = ordered.map({ $0.1 }).joinWithSeparator(", ")
-        try database.exec("INSERT INTO \(name)(" + names + ") VALUES(" + values  + ");" )
+        try database.exec("CREATE TABLE IF NOT EXISTS \(name) (" + fields.map { "\($0.0) TEXT" }.joinWithSeparator(", ")  + ");")
+        let names = fields.map { "\($0.0)" }.joinWithSeparator(", ")
+        let values = Array(count: fields.count, repeatedValue: "?").joinWithSeparator(", ")
+        try database.exec("INSERT INTO \(name)(" + names + ") VALUES(" + values  + ");", fields.map { $0.1 })
     }
     
 }

+ 43 - 23
Sources/SQLite.swift

@@ -10,6 +10,7 @@ import Foundation
 public enum SQLiteError: ErrorType {
     case OpenFailed(String?)
     case ExecFailed(String?)
+    case BindFailed(String?)
 }
 
 public class SQLite {
@@ -35,31 +36,46 @@ public class SQLite {
         try exec(sql, nil)
     }
     
-    public func exec(sql: String, _ stepCallback: ([String: String] -> Void)?) throws {
-        var errorMessagePointer = UnsafeMutablePointer<Int8>()
-        var execCContext = ExecCContext(callback: stepCallback)
-        let execResult = sql.withCString {
-            sqlite3_exec(databaseConnection, $0, { (context, count, values, names) -> Int32 in
-                var content = [String: String]()
-                for i in 0..<count {
-                    if let name = String.fromCString(names.advancedBy(Int(i)).memory),
-                        let value = String.fromCString(values.advancedBy(Int(i)).memory) {
-                            content[name] = value
+    public func exec(sql: String, _ params: [String?]? = nil, _ step: ([String: String?] -> Void)? = nil) throws {
+        var statement = COpaquePointer()
+        let prepareResult = sql.withCString { sqlite3_prepare_v2(databaseConnection, $0, Int32(sql.utf8.count), &statement, nil) }
+        guard prepareResult == SQLITE_OK else {
+            throw SQLiteError.ExecFailed(String.fromCString(sqlite3_errmsg(databaseConnection)))
+        }
+        for (index, value) in (params ?? [String?]()).enumerate() {
+            let bindResult = value?.withCString({ sqlite3_bind_text(statement, index + 1, $0, -1 /* take zero terminator. */) { _ in } })
+                    ?? sqlite3_bind_null(statement, index + 1)
+            guard bindResult == SQLITE_OK else {
+                throw SQLiteError.BindFailed(String.fromCString(sqlite3_errmsg(databaseConnection)))
+            }
+        }
+        while true {
+            let stepResult = sqlite3_step(statement)
+            switch stepResult {
+            case SQLITE_ROW:
+                var content = [String: String?]()
+                for i in 0..<sqlite3_column_count(statement) {
+                    if let name = String.fromCString(UnsafePointer<CChar>(sqlite3_column_name(statement, i))) {
+                        let pointer = sqlite3_column_text(statement, i)
+                        if pointer == nil {
+                            content[name] = nil
+                        } else {
+                            content[name] = String.fromCString(UnsafePointer<CChar>(pointer))
+                        }
                     }
                 }
-                if let callback = UnsafeMutablePointer<ExecCContext>(context).memory.callback {
-                    callback(content)
-                }
-                return SQLITE_OK
-            }, &execCContext, &errorMessagePointer)
-        }
-        guard execResult == SQLITE_OK else {
-            let errorDetails = String.fromCString(errorMessagePointer)
-            sqlite3_free(errorMessagePointer)
-            throw SQLiteError.ExecFailed(errorDetails)
+                step?(content)
+            case SQLITE_DONE:
+                return
+            case SQLITE_ERROR:
+                throw SQLiteError.ExecFailed("sqlite3_step() returned SQLITE_ERROR.")
+            default:
+                throw SQLiteError.ExecFailed("Unknown result for sqlite3_step(): \(stepResult)")
+            }
         }
     }
     
+    
     public func enumerate(sql: String) throws -> StatmentSequence {
         var statement = COpaquePointer()
         let prepareResult = sql.withCString { sqlite3_prepare_v2(databaseConnection, $0, Int32(sql.utf8.count), &statement, nil) }
@@ -78,9 +94,13 @@ public class SQLite {
             case SQLITE_ROW:
                 var content = [String: String]()
                 for i in 0..<sqlite3_column_count(statement) {
-                    if let name = String.fromCString(UnsafePointer<CChar>(sqlite3_column_name(statement, i))),
-                        let value = String.fromCString(UnsafePointer<CChar>(sqlite3_column_text(statement, i))) {
-                            content[name] = value
+                    if let name = String.fromCString(UnsafePointer<CChar>(sqlite3_column_name(statement, i))) {
+                        let pointer = sqlite3_column_text(statement, i)
+                        if pointer == nil {
+                            content[name] = nil
+                        } else {
+                            content[name] = String.fromCString(UnsafePointer<CChar>(pointer))
+                        }
                     }
                 }
                 return content

+ 1 - 1
SwifterTestsCommon/SwifterTestsSQLite.swift

@@ -53,7 +53,7 @@ class SwifterTestsSQLite: XCTestCase {
             }
             XCTAssert(counter == 2, "Database should have two rows.")
             
-            try database.close()
+            database.close()
         } catch {
             XCTAssert(false, "Database manipulation should not throw any exceptions.")
         }