From d9ec8e9c3520f153c297a7068da65296845d7fe3 Mon Sep 17 00:00:00 2001 From: Ronan Date: Thu, 13 Feb 2020 22:50:52 -0800 Subject: [PATCH] knn error handling --- Arrival-iOS2/AppData.swift | 24 +++++++++----- .../KNearestNeighborsClassifier.swift | 33 ++++++++++++++----- 2 files changed, 40 insertions(+), 17 deletions(-) diff --git a/Arrival-iOS2/AppData.swift b/Arrival-iOS2/AppData.swift index b50a7c5..286f1f1 100644 --- a/Arrival-iOS2/AppData.swift +++ b/Arrival-iOS2/AppData.swift @@ -17,6 +17,7 @@ import JavaScriptCore import FirebasePerformance import FirebaseAnalytics import FirebaseRemoteConfig +import FirebaseCrashlytics let appDelegate = UIApplication.shared.delegate as! AppDelegate let context = appDelegate.persistentContainer.viewContext @@ -291,15 +292,22 @@ class AppData: NSObject, ObservableObject,CLLocationManagerDelegate { nNeighbors = 1 print(nNeighbors, "knn neighbors") - let knn = KNearestNeighborsClassifier(data: trainingData, labels: labels, nNeighbors: nNeighbors) - if (self.fromStation.abbr != "load") { - print("knn predicting", self.fromStation.abbr) - let predictionLabels = knn.predict([tripToDouble(day: day, hour: hour, fromStation: self.fromStation.abbr)]) - print(predictionLabels, "knn prediction labels") - predictionType = predictionLabels.map({ self.stationFromInt(label: $0) }) - print(predictionType, "knn prediction type") - priorities[predictionType[0]] = (JSON(priorities)[predictionType[0]].intValue + 100) * 10 + do { + let knn = try KNearestNeighborsClassifier(data: trainingData, labels: labels, nNeighbors: nNeighbors) + if (self.fromStation.abbr != "load") { + print("knn predicting", self.fromStation.abbr) + let predictionLabels = try knn.predict([tripToDouble(day: day, hour: hour, fromStation: self.fromStation.abbr)]) + print(predictionLabels, "knn prediction labels") + predictionType = predictionLabels.map({ self.stationFromInt(label: $0) }) + print(predictionType, "knn prediction type") + priorities[predictionType[0]] = (JSON(priorities)[predictionType[0]].intValue + 100) * 10 + } + } catch { + print("crash", error) + Crashlytics.crashlytics().record(error: error) + } + } do { try context.save() diff --git a/Arrival-iOS2/KNearestNeighborsClassifier.swift b/Arrival-iOS2/KNearestNeighborsClassifier.swift index 9291770..c3b00c6 100644 --- a/Arrival-iOS2/KNearestNeighborsClassifier.swift +++ b/Arrival-iOS2/KNearestNeighborsClassifier.swift @@ -8,32 +8,47 @@ import Darwin import Foundation - +enum knnError: Error { + case cantFindMajority + case expectedNN(nNeighbors: Int, dataCount: Int) + case expectedDataCount(dataCount: Int, labelsCount: Int) + +} public class KNearestNeighborsClassifier { private let data: [[Double]] private let labels: [Int] private let nNeighbors: Int - public init(data: [[Double]], labels: [Int], nNeighbors: Int = 3) { + public init(data: [[Double]], labels: [Int], nNeighbors: Int = 3) throws { self.data = data self.labels = labels self.nNeighbors = nNeighbors guard nNeighbors <= data.count else { - fatalError("Expected `nNeighbors` (\(nNeighbors)) <= `data.count` (\(data.count))") + throw(knnError.expectedNN(nNeighbors: nNeighbors, dataCount: data.count)) } guard data.count == labels.count else { - fatalError("Expected `data.count` (\(data.count)) == `labels.count` (\(labels.count))") + throw(knnError.expectedDataCount(dataCount: data.count, labelsCount: labels.count)) } } - public func predict(_ xTests: [[Double]]) -> [Int] { - return xTests.map({ + public func predict(_ xTests: [[Double]]) throws -> [Int] { + do { + return try xTests.map({ let knn = kNearestNeighbors($0) - return kNearestNeighborsMajority(knn) + do { + let result = try kNearestNeighborsMajority(knn) + return result + } catch { + throw(error) + } + }) + } catch { + throw(error) + } } private func distance(_ xTrain: [Double], _ xTest: [Double]) -> Double { @@ -56,7 +71,7 @@ public class KNearestNeighborsClassifier { return Array(kNearestNeighborsSorted) } - private func kNearestNeighborsMajority(_ knn: [(key: Double, value: Int)]) -> Int { + private func kNearestNeighborsMajority(_ knn: [(key: Double, value: Int)]) throws -> Int { var labels = [Int : Int]() for neighbor in knn { @@ -69,6 +84,6 @@ public class KNearestNeighborsClassifier { } } - fatalError("Cannot find the majority.") + throw(knnError.cantFindMajority) } }