I am having trouble using an h2o model (in MOJO format) on a Spark cluster, but only when I try to run the predictions in parallel on the executors; everything works fine if I collect the dataframe and run them on the driver.
Since the dataframe I am predicting on has > 100 features, I am using the following function to convert dataframe rows into h2o's RowData format (from here):
import org.apache.spark.sql.{DataFrame, Row}
import hex.genmodel.easy.RowData

def rowToRowData(df: DataFrame, row: Row): RowData = {
  val rowAsMap = row.getValuesMap[Any](df.schema.fieldNames)
  val rowData = rowAsMap.foldLeft(new RowData()) { case (rd, (k, v)) =>
    if (v != null) { rd.put(k, v.toString) }  // RowData values must be Strings; nulls are simply skipped
    rd
  }
  rowData
}
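For context, on a toy dataframe the conversion looks like this (the two columns are made up purely for illustration, and this assumes import spark.implicits._ is in scope, as the toDF calls further down also require):

val sample = Seq(("a1", 1.5), ("a2", 2.0)).toDF("id", "x1")  // toy frame; the real df has > 100 columns
val firstRowData = rowToRowData(sample, sample.head)
// RowData extends java.util.HashMap[String, Object], so firstRowData now holds
// "id" -> "a1" and "x1" -> "1.5", with every value converted to a String
println(firstRowData)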
Then I load the MOJO model and create an EasyPredictModelWrapper:
import hex.genmodel.MojoModel
import hex.genmodel.easy.EasyPredictModelWrapper

val mojo = MojoModel.load("/path/to/mojo.zip")
val easyModel = new EasyPredictModelWrapper(mojo)
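For completeness, the wrapper can also be constructed through its Config object, which as far as I know is where options such as mapping unseen categorical levels to NA are set; I mention it only in case the wrapper configuration turns out to matter:

val configuredModel = new EasyPredictModelWrapper(
  new EasyPredictModelWrapper.Config()
    .setModel(mojo)
    .setConvertUnknownCategoricalLevelsToNa(true)  // optional: treat unseen categorical levels as NA
)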
Now, I can make predictions on my dataframe (df) by mapping over the rows if I collect it first, so the following works:
val predictions = df.collect().map { r =>
  val rData = rowToRowData(df, r)  // convert row to RowData using the function above
  val prediction = easyModel.predictBinomial(rData).label
  (r.getAs[String]("id"), prediction.toInt)
}
.toSeq
.toDF("id", "prediction")
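I only need the label, but for what it's worth the prediction object also exposes the class probabilities, so the inside of the map could equally be written as:

val p = easyModel.predictBinomial(rData)  // BinomialModelPrediction
val label = p.label                       // predicted class as a String
val probs = p.classProbabilities          // per-class probabilities, in case they are ever needed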
However, I wish to do this in parallel on the cluster since the final df will be too large to collect on the driver. But if I try to run the same code without collecting first:
val predictions = df.map { r =>
  val rData = rowToRowData(df, r)
  val prediction = easyModel.predictBinomial(rData).label
  (r.getAs[String]("id"), prediction.toInt)
}
.toDF("id", "prediction")
I get the following errors:
18/01/03 11:34:59 WARN TaskSetManager: Lost task 0.0 in stage 118.0 (TID 9914, 213.248.241.182, executor 0): java.lang.ClassCastException: cannot assign instance of scala.collection.immutable.List$SerializationProxy to field org.apache.spark.rdd.RDD.org$apache$spark$rdd$RDD$$dependencies_ of type scala.collection.Seq in instance of org.apache.spark.rdd.MapPartitionsRDD
at java.io.ObjectStreamClass$FieldReflector.setObjFieldValues(ObjectStreamClass.java:2133)
at java.io.ObjectStreamClass.setObjFieldValues(ObjectStreamClass.java:1305)
at java.io.ObjectInputStream.defaultReadFields(ObjectInputStream.java:2024)
at java.io.ObjectInputStream.readSerialData(ObjectInputStream.java:1942)
at java.io.ObjectInputStream.readOrdinaryObject(ObjectInputStream.java:1808)
at java.io.ObjectInputStream.readObject0(ObjectInputStream.java:1353)
at java.io.ObjectInputStream.defaultReadFields(ObjectInputStream.java:2018)
at java.io.ObjectInputStream.readSerialData(ObjectInputStream.java:1942)
at java.io.ObjectInputStream.readOrdinaryObject(ObjectInputStream.java:1808)
at java.io.ObjectInputStream.readObject0(ObjectInputStream.java:1353)
at java.io.ObjectInputStream.readObject(ObjectInputStream.java:373)
at org.apache.spark.serializer.JavaDeserializationStream.readObject(JavaSerializer.scala:75)
at org.apache.spark.serializer.JavaSerializerInstance.deserialize(JavaSerializer.scala:114)
at org.apache.spark.scheduler.ResultTask.runTask(ResultTask.scala:80)
at org.apache.spark.scheduler.Task.run(Task.scala:108)
at org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:335)
at java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1142)
at java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:617)
at java.lang.Thread.run(Thread.java:745)
So it looks like a datatype mismatch. I have tried converting the dataframe to an RDD first (i.e. df.rdd.map, which gives the same errors), using df.mapPartitions, and placing the rowToRowData function code inside the map, but nothing has worked so far.
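In case it helps, the mapPartitions attempt looked roughly like this (a sketch from memory, so the details may not be exact; it failed with the same stack trace):

val predictions = df.mapPartitions { rows =>
  rows.map { r =>
    val rData = rowToRowData(df, r)
    (r.getAs[String]("id"), easyModel.predictBinomial(rData).label.toInt)
  }
}
.toDF("id", "prediction")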
Any ideas on the best way to achieve this?