diff --git a/src/clj_ml/data.clj b/src/clj_ml/data.clj index 61e2707..e085648 100644 --- a/src/clj_ml/data.clj +++ b/src/clj_ml/data.clj @@ -234,11 +234,17 @@ [^Instances ds] (Instances. ds 0)) -(defn dataset-get-class +(defn dataset-class-index "Returns the index of the class attribute for this dataset" [^Instances dataset] (.classIndex dataset)) +(defn dataset-class-name + "Returns the name of the class attribute in keyword form. Returns nil if not set." + [^Instances dataset] + (when (> (dataset-class-index dataset) -1) + (keyword-name (.classAttribute dataset)))) + (defn dataset-nominal? "Returns boolean indicating if the class attribute is nominal" [^Instances dataset] diff --git a/test/clj_ml/data_test.clj b/test/clj_ml/data_test.clj index 43d34c1..c5af8b9 100644 --- a/test/clj_ml/data_test.clj +++ b/test/clj_ml/data_test.clj @@ -28,9 +28,9 @@ (deftest dataset-make-dataset-with-default-class (let [ds (clj-ml.data/make-dataset :test [:a :b {:c [:d :e]}] [] {:class :c}) ds2 (clj-ml.data/make-dataset :test [:a :b {:c [:d :e]}] [] {:class 2})] - (is (= (clj-ml.data/dataset-get-class ds) - 2)) - (is (= (clj-ml.data/dataset-get-class ds2) + (is (= (clj-ml.data/dataset-class-name ds) + :c)) + (is (= (clj-ml.data/dataset-class-index ds2) 2)))) @@ -162,3 +162,10 @@ (is (= 0 (dataset-count headers))) (is (= "test" (dataset-name headers))) (is (= [:a {:b [:foo :bar]}] (dataset-format headers)))))) + + +(deftest dataset-class-helpers + (let [ds (make-dataset "test" [:a {:b [:foo :bar]}] [[1 :foo] [2 :bar]])] + (is (= nil (dataset-class-name ds))) + (dataset-set-class ds :b) + (is (= :b (dataset-class-name ds)))))