Example MLI problem

Problem name

Classification of iris dataset

Problem Statement

Given the iris dataset, you need to train a model to predict the Species of an iris plant from its sepal and petal measurements.

Note that the above problem statement is deliberately short and does not explain the problem in detail. It is provided as a sample only.

Training dataset

file: /data/training/iris_train.csv

Columns

Sepal.Length
Sepal.Width
Petal.Length
Petal.Width
Species

Output format

You need to write a CSV file, /code/iris_prediction.csv, with the following data. The first line should contain the column names id and Species. The next n lines should contain the predicted value of Species for each id in the test data.

Example

id,Species
3,setosa
8,setosa
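As a minimal sketch of producing a file in this format, the snippet below writes a small data frame and reads it back to check the header. The ids and labels are made up for illustration, and a temporary file stands in for /code/iris_prediction.csv so the snippet can run outside the evaluation environment.

```r
# Hypothetical ids and predicted labels; in a real submission these come from your model.
predictions <- data.frame(
  id = c(3, 8),
  Species = c("setosa", "setosa")
)

# The evaluation environment expects /code/iris_prediction.csv;
# a temporary file is used here so the sketch runs anywhere.
out_file <- tempfile(fileext = ".csv")
write.csv(predictions, file = out_file, quote = FALSE, row.names = FALSE)

# Read the file back and confirm the header and the row count.
written <- readLines(out_file)
check <- read.csv(out_file)
stopifnot(identical(names(check), c("id", "Species")))
stopifnot(nrow(check) == nrow(predictions))
```

Passing row.names = FALSE matters here: without it, write.csv adds an unnamed first column of row names, and the header no longer matches the expected id,Species.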

Datasets

Training dataset

This dataset is publicly available to the user, who can also download it. Download training dataset

Test dataset

This dataset will not be publicly available to the user. It will be available in the evaluation environment only. Download test dataset
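Because the test dataset is only available in the evaluation environment, one way to estimate accuracy beforehand is to hold out part of the training data. The sketch below is an illustration under assumptions, not part of the platform: it uses R's built-in iris data as a stand-in for iris_train.csv, and a simple nearest-centroid classifier written in base R as a stand-in for the real model. The names holdout_idx, predict_centroid and valid_accuracy are hypothetical.

```r
set.seed(42)  # make the split reproducible

# Stand-in for read.csv("/data/training/iris_train.csv"); the built-in iris
# data has the same five columns.
data_all <- iris
data_all$Species <- as.character(data_all$Species)

# Hold out 30% of the rows for local validation.
holdout_idx <- sample(nrow(data_all), size = round(0.3 * nrow(data_all)))
valid <- data_all[holdout_idx, ]
train <- data_all[-holdout_idx, ]

features <- c("Sepal.Length", "Sepal.Width", "Petal.Length", "Petal.Width")

# Nearest-centroid classifier: a deliberately simple stand-in for the real model.
centroids <- aggregate(train[, features], by = list(Species = train$Species), FUN = mean)

predict_centroid <- function(newdata) {
  x <- as.matrix(newdata[, features])
  cen <- as.matrix(centroids[, features])
  # Squared Euclidean distance from every row to every class centroid.
  d <- sapply(seq_len(nrow(cen)), function(k) rowSums(sweep(x, 2, cen[k, ])^2))
  # Pick the class whose centroid is closest.
  as.character(centroids$Species[max.col(-d, ties.method = "first")])
}

valid_accuracy <- mean(predict_centroid(valid) == valid$Species)
print(valid_accuracy)
```

The held-out accuracy is only an estimate; the score reported by the evaluation environment is computed on the hidden test data.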

Solution

library(caret)
library(klaR)

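data_train = read.csv("/data/training/iris_train.csv")
data_test = read.csv("/data/test/iris_test.csv")
# NaiveBayes expects a factor response; read.csv returns strings by default in R >= 4.0
data_train$Species = as.factor(data_train$Species)

# You can save any number of plots to the `/code/output` directory
# They will be visible in the reports
png(filename = "/code/output/pairs.png")
# Plot the training data (not the built-in iris object), colored by species
pairs(data_train[, c("Sepal.Length", "Sepal.Width", "Petal.Length", "Petal.Width")], pch = 21, bg = c("red", "green3", "blue")[unclass(data_train$Species)])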
dev.off()

model = NaiveBayes(Species~., data = data_train)

predictions = predict(model, data_test)
data_test$Species = predictions$class

write.csv(data_test[,c("id","Species")],file = "/code/iris_prediction.csv",quote = FALSE,row.names = FALSE)

Generated plots

[Figure: pairs.png, the pairs plot saved by the solution script]

Testcases

Note that the evaluation script can also generate plots.

data_train = read.csv("/data/training/iris_train.csv")
data_test = read.csv("/code/iris_prediction.csv") # File generated by the user
data_actual = read.csv("/data/test/iris_actual.csv") # Contains the actual values for Species in test data

final = merge(data_test, data_actual, by = "id")
# Exit if output file does not contain the expected data
stopifnot(nrow(final) == nrow(data_actual))

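# Use the same levels on both axes so the confusion matrix stays square even if a class is never predicted
lv = sort(unique(as.character(data_actual$Species)))
cm = as.matrix(table(Actual = factor(final$Species.y, levels = lv), Predicted = factor(final$Species.x, levels = lv)))
n = nrow(final) # Predictions with an unknown label count as wrong
correct = sum(diag(cm)) # Avoid shadowing the diag() function
accuracy = correct / n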
write(c("accuracy", accuracy),file = "/code/errormetrics",ncolumns = 2) # This is required to show the error metrics value to the user
# Fail with a non-zero exit code unless accuracy is above 0.7
stopifnot(accuracy > 0.7)
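
The accuracy computation can be tried on a small made-up example. In the sketch below, the two data frames are hypothetical stand-ins for the prediction file and the actual values; four of the five predictions match, so the accuracy is 0.8. Building both factors from the same set of levels is a choice made in this sketch so that the confusion matrix is square.

```r
# Hypothetical prediction and ground-truth tables keyed by id.
data_test <- data.frame(
  id = 1:5,
  Species = c("setosa", "setosa", "versicolor", "virginica", "virginica")
)
data_actual <- data.frame(
  id = 1:5,
  Species = c("setosa", "versicolor", "versicolor", "virginica", "virginica")
)

# After the merge, Species.x holds the prediction and Species.y the actual value.
final <- merge(data_test, data_actual, by = "id")

# Use the same levels on both axes so the confusion matrix is always square.
lv <- sort(unique(as.character(data_actual$Species)))
cm <- table(
  Actual = factor(final$Species.y, levels = lv),
  Predicted = factor(final$Species.x, levels = lv)
)

# Correct predictions sit on the diagonal of the confusion matrix.
accuracy <- sum(diag(cm)) / nrow(final)
print(cm)
print(accuracy)
```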
