More efficient RDD.count() implementation


As I have written my own data source, I also wrote a custom RDD[Row]
implementation to provide getPartitions and compute overrides.
This works very well, but while doing some performance analysis I noticed
that for any given pipeline fit operation, a fair amount of time is spent
in the RDD.count method.
Its default implementation in RDD.scala walks the entire iterator, which
in my case is counterproductive because I already know how many rows
there are in the RDD, and in any partition returned by getPartitions.
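For context, a custom RDD[Row] of that shape might look roughly like the following sketch. MyReader, MyPartition, and the reader's method names are assumptions for illustration, not the actual code:

```scala
import org.apache.spark.{Partition, SparkContext, TaskContext}
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.Row

// `MyReader` is a hypothetical stand-in for the custom data-source reader.
trait MyReader {
  def numPartitions: Int
  def rows(partitionIndex: Int): Iterator[Row]
  def RowCount: Long // total row count, known up front
}

// One Partition object per slice of the source.
case class MyPartition(index: Int) extends Partition

// Custom RDD[Row] with no parent dependencies (Nil).
class MyRDD(sc: SparkContext, reader: MyReader) extends RDD[Row](sc, Nil) {

  override protected def getPartitions: Array[Partition] =
    (0 until reader.numPartitions).map(MyPartition(_)).toArray[Partition]

  override def compute(split: Partition, context: TaskContext): Iterator[Row] =
    reader.rows(split.index) // iterator over this partition's rows
}
```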
As an initial attempt, I declared the following override in my custom RDD:

   override def count(): Long = { reader.RowCount }

but this never gets called, which upon further inspection makes perfect
sense: internally, the code creates new RDDs for every partition it has
to work on. And this is where I'm a bit stuck, because I have no idea
how to override that creation.
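The reason the override never fires can be illustrated with how transformations behave in general: each one wraps the source RDD in a new RDD (e.g. a MapPartitionsRDD), and count() is then invoked on the outermost wrapper, which still uses the default iterate-everything implementation. A minimal sketch, where customRdd stands for the custom RDD[Row] (not the actual ML pipeline internals):

```scala
// Any transformation produces a new RDD that wraps the custom one.
val mapped = customRdd.map(identity) // a MapPartitionsRDD[Row], not the custom class

// count() resolves on the wrapper, so the overridden count() in the
// custom RDD is never reached; the default implementation iterates
// every row of every partition.
val n = mapped.count()
```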

Here is a call stack for a GBTRegressor run; it's quite similar for
RandomForestRegressor or DecisionTreeRegressor.


Any suggestion would be much appreciated.

