Point-in-time joins with Spark

Point-in-Time (PIT) joins in Hopsworks Feature Store

In order to create a training dataset, data scientists usually have to generate information about the future by putting themselves back in time, that is, by travelling back in time.

Let’s take the case of churn prediction: we want to predict which of our users will churn within the next few months. To train such a model we need to construct a training dataset with one row per user, one column indicating whether the user churned, and X additional columns with features about the user, such as their historical purchases or interactions with the product.

Since we don’t know the future yet (it’s what we want to predict), to generate such prediction targets, we have to go back in time and determine which users churned in the last couple of months.

Snapshot based time travel

If you have looked at our other time travel notebooks, you will know that the simplest solution to this problem is to choose a single cutoff time for all our users. The time travel capabilities of the Feature Store with Apache Hudi make it easy to fetch data from a single point in time. This can work if features aren’t updated frequently. However, when the prediction target events of different users lie far apart in time and features are updated frequently, you leave a lot of information unused, hurting the accuracy of your model.
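
For illustration, such a single-cutoff query is just a time travel read against one commit timestamp. A minimal sketch, assuming churnFg already exists and cutoff is a commit timestamp string (the asOf syntax is demonstrated in full later in this notebook):

// Snapshot-based time travel sketch: one global cutoff for all users.
// `cutoff` would be a commit timestamp string such as "20210909145923".
val snapshotDf = churnFg.selectAll()
                        .asOf(cutoff)
                        .read()
snapshotDf.show()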

Individual point-in-time correct cutoffs

A better approach is to use all information up to the time the prediction target event happened. That means we still go back in time, say a maximum of 6 months, but we remember the prediction target event timestamp on a user level (row level) and find the latest values of our predictive features before this point in time. The following diagram illustrates this approach for one user:

[Diagram: per-user timeline with the prediction target events (dark red), the usable feature signals before each event (green) and the future feature signals that must be excluded (light red)]

A problem often occurring at this stage is leaking information from the future (the light red signals) into the training dataset. Using feature signals which happened after the prediction target event has to be strictly avoided. What we want to retrieve from the feature store are the green feature signals, and the green ones only.

In this case it is not as simple as performing a time travel query with a single timestamp. This challenge is solved by a point-in-time correct join instead. Point-in-time joins prevent feature leakage by recreating the state of the world at a single point in time for every entity or primary key value (the user, in our case).
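
To make the mechanics concrete, here is a minimal sketch of what a point-in-time correct join boils down to in plain Spark: for every label row, rank the matching feature rows that are not newer than the label and keep only the most recent one. The DataFrame and column names (labelsDf, featuresDf, label_ts, feature_ts) are illustrative only; the query generated by the feature store further down uses the same RANK-based pattern.

import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions._

// labelsDf:   customer_id, label_ts, label columns...
// featuresDf: customer_id, feature_ts, feature columns...
// 1. Pair every label row with all feature rows that are not newer than it.
val candidates = labelsDf.join(featuresDf, Seq("customer_id"))
  .where(col("feature_ts") <= col("label_ts"))

// 2. Per label row, keep only the most recent feature row.
val w = Window.partitionBy("customer_id", "label_ts").orderBy(col("feature_ts").desc)
val pitCorrect = candidates
  .withColumn("pit_rank", rank().over(w))
  .where(col("pit_rank") === 1)
  .drop("pit_rank")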

The Hopsworks Feature Store abstracts this complexity away: you simply tell it where to find the relevant event timestamps for your feature groups. We will go through the process in the rest of this notebook.

Event-time enabled Feature Groups

For the Feature Store to be able to perform a PIT join, we need to tell it where to find the event timestamp within each feature group. Event time is a timestamp indicating the instant at which an event happened at its source. It is therefore not an ingestion timestamp or the like, but should originate from your source systems.

To “event-time enable” a feature group, you set the event_time argument at feature group creation. We are using simple integers as timestamps, for better readability.
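
As a quick preview of what that looks like in the Scala API (the complete feature group definitions follow further down; the feature group name here is just a placeholder):

// Sketch only: event-time enabling a feature group is a single extra builder call.
val someFg = (fs.createFeatureGroup()
                .name("some_feature_group")   // placeholder name
                .version(1)
                .primaryKeys(Seq("customer_id"))
                .eventTime("ts")              // column holding the event timestamp
                .build())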

For simplicity we will create three feature groups:

  • marketing with customer id, complaints in the last 7 days, outbound activities in the last 7 days and coupons received in the last 14 days;
  • contract with contract id and contract_type;
  • churn which will contain the labels indicating whether a contract was cancelled (1) or is still active (0), together with a timestamp of when the contract was created or cancelled. Note that this feature group has a composite primary key of customer id and contract id, the latter referring to the id in the contract feature group.
import com.logicalclocks.hsfs._
import com.logicalclocks.hsfs.constructor._
import scala.collection.JavaConversions._
import collection.JavaConverters._

import org.apache.spark.sql.{ DataFrame, Row }
import org.apache.spark.sql.types._

val connection = HopsworksConnection.builder().build();
val fs = connection.getFeatureStore();
Starting Spark application
SparkSession available as 'spark'.
import com.logicalclocks.hsfs._
import com.logicalclocks.hsfs.constructor._
import scala.collection.JavaConversions._
import collection.JavaConverters._
import org.apache.spark.sql.{DataFrame, Row}
import org.apache.spark.sql.types._
connection: com.logicalclocks.hsfs.HopsworksConnection = com.logicalclocks.hsfs.HopsworksConnection@27bb0197
fs: com.logicalclocks.hsfs.FeatureStore = FeatureStore{id=67, name='demo_fs_meb10000_featurestore', projectId=119, featureGroupApi=com.logicalclocks.hsfs.metadata.FeatureGroupApi@1d123d24}
val marketingSchema = scala.collection.immutable.List(
  StructField("customer_id", IntegerType, true),
  StructField("ts", LongType, true),
  StructField("complaints_7d", IntegerType, true),
  StructField("outbound_7d", IntegerType, true),
  StructField("coupon_14d", IntegerType, true)        
)

val contractSchema = scala.collection.immutable.List(
  StructField("contract_id", IntegerType, true),
  StructField("ts", LongType, true),
  StructField("contract_type", StringType, true)
)

val churnSchema = scala.collection.immutable.List(
  StructField("customer_id", IntegerType, true),
  StructField("contract_id", IntegerType, true),
  StructField("ts", LongType, true),
  StructField("contract_churn", IntegerType, true)       
)
marketingSchema: List[org.apache.spark.sql.types.StructField] = List(StructField(customer_id,IntegerType,true), StructField(ts,LongType,true), StructField(complaints_7d,IntegerType,true), StructField(outbound_7d,IntegerType,true), StructField(coupon_14d,IntegerType,true))
contractSchema: List[org.apache.spark.sql.types.StructField] = List(StructField(contract_id,IntegerType,true), StructField(ts,LongType,true), StructField(contract_type,StringType,true))
churnSchema: List[org.apache.spark.sql.types.StructField] = List(StructField(customer_id,IntegerType,true), StructField(contract_id,IntegerType,true), StructField(ts,LongType,true), StructField(contract_churn,IntegerType,true))

We will first load the churn feature group with the initial contracts. We can assume there is a job running that inserts new rows into this feature group for every new contract.

val newContracts = Seq(
    Row(1, 100, 10010L, 0),
    Row(2, 101, 10017L, 0),
    Row(3, 102, 10035L, 0),
    Row(4, 103, 10023L, 0),
    Row(5, 104, 10546L, 0),
    Row(6, 105, 10213L, 0),
    Row(7, 106, 10056L, 0),
    Row(8, 107, 10012L, 0)
)

val newContractsDf = spark.createDataFrame(
    spark.sparkContext.parallelize(newContracts),
    StructType(churnSchema)
)
newContracts: Seq[org.apache.spark.sql.Row] = List([1,100,10010,0], [2,101,10017,0], [3,102,10035,0], [4,103,10023,0], [5,104,10546,0], [6,105,10213,0], [7,106,10056,0], [8,107,10012,0])
newContractsDf: org.apache.spark.sql.DataFrame = [customer_id: int, contract_id: int ... 2 more fields]

At the same time, some contracts will be cancelled and inserted into the feature group over time. We will perform this insertion as a separate step with a bit of time difference, so that we can later demonstrate the capabilities of PIT joins together with time-travel queries.

val churnedContracts = Seq(
    Row(1, 100, 10356L, 1),
    Row(5, 104, 10692L, 1),
    Row(6, 105, 10375L, 1),
    Row(8, 107, 10023L, 1)
)

val churnedContractsDf = spark.createDataFrame(
    spark.sparkContext.parallelize(churnedContracts),
    StructType(churnSchema)
)
churnedContracts: Seq[org.apache.spark.sql.Row] = List([1,100,10356,1], [5,104,10692,1], [6,105,10375,1], [8,107,10023,1])
churnedContractsDf: org.apache.spark.sql.DataFrame = [customer_id: int, contract_id: int ... 2 more fields]

Now let’s create some mock data for the secondary feature groups, containing the actual explanatory features used to predict the churn.

The contract feature group is a feature group that gets new contracts appended, with information about the contract itself only, such as the type of contract. Hence all timestamps are the same as the initial contract creation timestamps in the churn feature group.

val contracts = Seq(
    Row(100, 10010L, "Long-term"),
    Row(101, 10017L, "Short-term"),
    Row(102, 10035L, "Trial"),
    Row(103, 10023L, "Short-term"),
    Row(104, 10546L, "Long-term"),
    Row(105, 10213L, "Trial"),
    Row(106, 10056L, "Long-term"),
    Row(107, 10012L, "Short-term")
)

val contractsDf = spark.createDataFrame(
    spark.sparkContext.parallelize(contracts),
    StructType(contractSchema)
)
contracts: Seq[org.apache.spark.sql.Row] = List([100,10010,Long-term], [101,10017,Short-term], [102,10035,Trial], [103,10023,Short-term], [104,10546,Long-term], [105,10213,Trial], [106,10056,Long-term], [107,10012,Short-term])
contractsDf: org.apache.spark.sql.DataFrame = [contract_id: int, ts: bigint ... 1 more field]

The marketing activities feature group contains features related to outbound and inbound contacts to the customer. You can imagine these to be computed by a streaming application, updating the features every time new events arrive.

In the point-in-time join we want to get the latest of these updates at or just before the time the prediction target event happened. Contracts can appear in the training dataset twice: once from before they churned and once from when they churned. The rows which should be picked for either of the target events are marked with a comment.

At the same time, the latest state of the customer’s marketing profile should always be available in the online feature store.
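
If you want to convince yourself which rows a point-in-time join should pick, you can reproduce the selection with plain Spark once the marketingDf defined below exists. This is only a sanity-check sketch and not required for the feature store:

import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions._

// For every churn label row (new and churned contracts), keep the latest
// marketing row with ts <= the label's ts.
val labelDf = newContractsDf.union(churnedContractsDf)
  .select(col("customer_id"), col("ts").as("label_ts"))

val w = Window.partitionBy("customer_id", "label_ts").orderBy(col("ts").desc)

val picked = labelDf.join(marketingDf, Seq("customer_id"))
  .where(col("ts") <= col("label_ts"))
  .withColumn("pit_rank", rank().over(w))
  .where(col("pit_rank") === 1)
  .drop("pit_rank")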

val marketing = Seq(
    Row(1, 10010L, 0, 0, 1), // this one
    Row(1, 10174L, 3, 0, 4),
    Row(1, 10257L, 7, 0, 3),
    Row(1, 10352L, 3, 0, 5), // this one
    Row(1, 10753L, 0, 0, 0),
    Row(1, 10826L, 0, 0, 1), // online feature store
    
    Row(2, 10017L, 0, 1, 1), // this one
    Row(2, 10021L, 0, 1, 1),
    Row(2, 10034L, 0, 1, 2), // online feature store
    
    Row(3, 10035L, 1, 3, 0), // this one
    Row(3, 10275L, 1, 2, 0),
    
    Row(5, 10546L, 0, 1, 0), // this one
    Row(5, 10598L, 2, 2, 1), // this one
    Row(5, 13567L, 0, 1, 0),    
    Row(5, 16245L, 0, 1, 0), // online feature store
    
    Row(6, 10213L, 0, 0, 1), // this one
    Row(6, 10234L, 0, 5, 0), // this one
    Row(6, 10436L, 0, 0, 1), // online feature store
    
    Row(7, 10056L, 0, 0, 0), // this one
    Row(7, 10056L, 0, 1, 0),
    Row(7, 10056L, 0, 2, 1),
    Row(7, 10056L, 0, 3, 0), // online feature store
    
    Row(8, 10012L, 0, 0, 1), // this one
    Row(8, 10023L, 10, 0, 1), // this one
    Row(8, 10033L, 0, 0, 1) // online feature store
)

val marketingDf = spark.createDataFrame(
    spark.sparkContext.parallelize(marketing),
    StructType(marketingSchema)
)
marketing: Seq[org.apache.spark.sql.Row] = List([1,10010,0,0,1], [1,10174,3,0,4], [1,10257,7,0,3], [1,10352,3,0,5], [1,10753,0,0,0], [1,10826,0,0,1], [2,10017,0,1,1], [2,10021,0,1,1], [2,10034,0,1,2], [3,10035,1,3,0], [3,10275,1,2,0], [5,10546,0,1,0], [5,10598,2,2,1], [5,13567,0,1,0], [5,16245,0,1,0], [6,10213,0,0,1], [6,10234,0,5,0], [6,10436,0,0,1], [7,10056,0,0,0], [7,10056,0,1,0], [7,10056,0,2,1], [7,10056,0,3,0], [8,10012,0,0,1], [8,10023,10,0,1], [8,10033,0,0,1])
marketingDf: org.apache.spark.sql.DataFrame = [customer_id: int, ts: bigint ... 3 more fields]

Create the feature groups

We are now ready to create our three feature groups. Note the additional argument we are passing, in order to tell the Feature Store which column should be used as event_time.

val marketingFg = (fs.createFeatureGroup()
                   .name("marketing")
                   .version(1)
                   .description("Features about inbound/outbound communication with customers")
                   .onlineEnabled(true)
                   .statisticsConfig(new StatisticsConfig(false, false, false, false))
                   .primaryKeys(Seq("customer_id"))
                   .eventTime("ts")
                   .build())
warning: there was one deprecation warning (since 2.12.0); for details, enable `:setting -deprecation' or `:replay -deprecation'
marketingFg: com.logicalclocks.hsfs.FeatureGroup = com.logicalclocks.hsfs.FeatureGroup@2ac49e4a
marketingFg.save(marketingDf)
val contractsFg = (fs.createFeatureGroup()
                   .name("contracts")
                   .version(1)
                   .description("Contract information features")
                   .onlineEnabled(true)
                   .statisticsConfig(new StatisticsConfig(false, false, false, false))
                   .primaryKeys(Seq("contract_id"))
                   .eventTime("ts")
                   .build())
warning: there was one deprecation warning (since 2.12.0); for details, enable `:setting -deprecation' or `:replay -deprecation'
contractsFg: com.logicalclocks.hsfs.FeatureGroup = com.logicalclocks.hsfs.FeatureGroup@21b3a6a3
contractsFg.save(contractsDf)
val churnFg = (fs.createFeatureGroup()
                   .name("churn")
                   .version(1)
                   .description("Customer/contract information about activity of contract")
                   .onlineEnabled(true)
                   .statisticsConfig(new StatisticsConfig(false, false, false, false))
                   .primaryKeys(Seq("customer_id", "contract_id"))
                   .eventTime("ts")
                   .build())
warning: there was one deprecation warning (since 2.12.0); for details, enable `:setting -deprecation' or `:replay -deprecation'
churnFg: com.logicalclocks.hsfs.FeatureGroup = com.logicalclocks.hsfs.FeatureGroup@21cbbdb3
churnFg.save(newContractsDf)

We insert the churned contracts in a separate upsert step.

churnFg.insert(churnedContractsDf)

Constructing a Point-in-Time Join

The Feature Store HSFS API comes with a Query abstraction. Operations such as .join or .select on a feature group return such Query objects. Queries are based solely on metadata up until they are saved as training datasets or read into DataFrames.

There are two requirements to construct a point in time join:

  1. All feature groups have to be event time enabled. If there is no event timestamp available, you can fall back on creating an ingestion timestamp in your feature engineering pipeline.
  2. The label (feature) group is the feature group containing the column that will be defined as the “label” column when creating the training dataset. This label group has to be the left-most feature group in the query, meaning its timestamps are used as the reference timestamps against which the explanatory features are joined (the dark red dots in the diagram above). In our case this is the churn feature group.

So let’s construct the join:

val query = (churnFg.selectAll().join(contractsFg.select(Seq("contract_type")), Seq("contract_id"))
                                .join(marketingFg.select(Seq("complaints_7d", "outbound_7d", "coupon_14d")), Seq("customer_id")))
warning: there were four deprecation warnings (since 2.12.0); for details, enable `:setting -deprecation' or `:replay -deprecation'
query: com.logicalclocks.hsfs.constructor.Query =
WITH right_fg0 AS (SELECT *
FROM (SELECT `fg2`.`customer_id`, `fg2`.`contract_id`, `fg2`.`ts`, `fg2`.`contract_churn`, `fg0`.`contract_type`, RANK() OVER (PARTITION BY `fg2`.`contract_id`, `fg2`.`ts` ORDER BY `fg0`.`ts` DESC) pit_rank_hopsworks
FROM `fg2` `fg2`
INNER JOIN `fg0` `fg0` ON `fg2`.`contract_id` = `fg0`.`contract_id` AND `fg2`.`ts` >= `fg0`.`ts`) NA
WHERE `pit_rank_hopsworks` = 1), right_fg1 AS (SELECT *
FROM (SELECT `fg2`.`customer_id`, `fg2`.`contract_id`, `fg2`.`ts`, `fg2`.`contract_churn`, `fg1`.`complaints_7d`, `fg1`.`outbound_7d`, `fg1`.`coupon_14d`, RANK() OVER (PARTITION BY `fg2`.`customer_id`, `fg2`.`ts` ORDER BY `fg1`.`ts` DESC) pit_rank_hopsworks
FROM `fg2` `fg2`
INNER JOIN `fg1` `fg1` ON `fg2`.`customer_id` = `fg1`...

You can print and look at the constructed SQL join query. However, there is no need to go into this complexity, as it will all be handled by the feature store.

As explained above, the Query itself does not perform any join yet until you call an action such as .read on it:

query.read().show(20)
+-----------+-----------+-----+--------------+-------------+-------------+-----------+----------+
|customer_id|contract_id|   ts|contract_churn|contract_type|complaints_7d|outbound_7d|coupon_14d|
+-----------+-----------+-----+--------------+-------------+-------------+-----------+----------+
|          6|        105|10375|             1|        Trial|            0|          5|         0|
|          5|        104|10692|             1|    Long-term|            2|          2|         1|
|          5|        104|10546|             0|    Long-term|            0|          1|         0|
|          2|        101|10017|             0|   Short-term|            0|          1|         1|
|          7|        106|10056|             0|    Long-term|            0|          0|         0|
|          8|        107|10023|             1|   Short-term|           10|          0|         1|
|          1|        100|10356|             1|    Long-term|            3|          0|         5|
|          1|        100|10010|             0|    Long-term|            0|          0|         1|
|          6|        105|10213|             0|        Trial|            0|          0|         1|
|          8|        107|10012|             0|   Short-term|            0|          0|         1|
|          3|        102|10035|             0|        Trial|            1|          3|         0|
+-----------+-----------+-----+--------------+-------------+-------------+-----------+----------+

Filters on point-in-time queries

As with any other query, it is possible to apply filters to it. For example, if you are planning to train a model only for a certain contract type:

(churnFg.selectAll().join(
                        contractsFg.select(Seq("contract_type"))
                                    .filter(contractsFg.getFeature("contract_type").eq("Long-term"))
                        , Seq("contract_id"))
                     .join(marketingFg.select(Seq("complaints_7d", "outbound_7d", "coupon_14d")), Seq("customer_id")).read().show(20))
warning: there were four deprecation warnings (since 2.12.0); for details, enable `:setting -deprecation' or `:replay -deprecation'
+-----------+-----------+-----+--------------+-------------+-------------+-----------+----------+
|customer_id|contract_id|   ts|contract_churn|contract_type|complaints_7d|outbound_7d|coupon_14d|
+-----------+-----------+-----+--------------+-------------+-------------+-----------+----------+
|          5|        104|10692|             1|    Long-term|            2|          2|         1|
|          5|        104|10546|             0|    Long-term|            0|          1|         0|
|          7|        106|10056|             0|    Long-term|            0|          0|         0|
|          1|        100|10356|             1|    Long-term|            3|          0|         5|
|          1|        100|10010|             0|    Long-term|            0|          0|         1|
+-----------+-----------+-----+--------------+-------------+-------------+-----------+----------+

Combining time-travel and point-in-time join in one query

We performed a separate upsert on our churn feature group in order to be able to demonstrate time-travel, so let’s look at the commit information:

val commitDetails = churnFg.commitDetails()
commitDetails: java.util.Map[Long,java.util.Map[String,String]] = {1631199563000={committedOn=20210909145923, rowsUpdated=0, rowsDeleted=0, rowsInserted=8}, 1631199588000={committedOn=20210909145948, rowsUpdated=0, rowsDeleted=0, rowsInserted=4}}

As you can see, there are two commits: (1) the initial commit, which inserted the eight contracts, and (2) the upsert commit with the additional four churned contracts.

Now we would like to query the state of the Feature Store at the time before the churned contracts were ingested:

// getting the correct timestamp, we sort by the commit time and get the string representation of the commit time
val committedOnList = commitDetails.values().map(c => c.get("committedOn")).toList.sorted
val committedOn = committedOnList(0)
warning: there was one deprecation warning (since 2.12.0); for details, enable `:setting -deprecation' or `:replay -deprecation'
committedOnList: List[String] = List(20210909145923, 20210909145948)
committedOn: String = 20210909145923
churnFg.selectAll().asOf(committedOn).read().show(20)
+-----------+-----------+-----+--------------+
|customer_id|contract_id|   ts|contract_churn|
+-----------+-----------+-----+--------------+
|          3|        102|10035|             0|
|          2|        101|10017|             0|
|          8|        107|10012|             0|
|          6|        105|10213|             0|
|          1|        100|10010|             0|
|          4|        103|10023|             0|
|          5|        104|10546|             0|
|          7|        106|10056|             0|
+-----------+-----------+-----+--------------+

And combining this with our point-in-time join:

query.asOf(committedOn).read().show(20)
+-----------+-----------+-----+--------------+-------------+-------------+-----------+----------+
|customer_id|contract_id|   ts|contract_churn|contract_type|complaints_7d|outbound_7d|coupon_14d|
+-----------+-----------+-----+--------------+-------------+-------------+-----------+----------+
|          5|        104|10546|             0|    Long-term|            0|          1|         0|
|          2|        101|10017|             0|   Short-term|            0|          1|         1|
|          7|        106|10056|             0|    Long-term|            0|          0|         0|
|          1|        100|10010|             0|    Long-term|            0|          0|         1|
|          6|        105|10213|             0|        Trial|            0|          0|         1|
|          8|        107|10012|             0|   Short-term|            0|          0|         1|
|          3|        102|10035|             0|        Trial|            1|          3|         0|
+-----------+-----------+-----+--------------+-------------+-------------+-----------+----------+

Creating the training dataset

Creating the training dataset is now as simple as initializing the metadata with .createTrainingDataset() and subsequently persisting it and materializing the query with .save().

val td = (fs.createTrainingDataset()
          .name("churn_model")
          .version(1)
          .dataFormat(DataFormat.CSV)
          .label(Seq("contract_churn"))
          .build())
warning: there was one deprecation warning (since 2.12.0); for details, enable `:setting -deprecation' or `:replay -deprecation'
td: com.logicalclocks.hsfs.TrainingDataset = com.logicalclocks.hsfs.TrainingDataset@3c3e22a8
td.save(query)

Querying the online feature store

We can reproduce the query and get the latest feature vector for each contract from the online feature store now.

NOTE: Any filter applied at the time of creating the training dataset is not applied at serving time. It is the responsibility of the application performing the lookup to only provide contract ids which belong, for example, to a certain category. The reason is that we always want to get a vector back; applying a filter might lead to no results. A sketch of this application-side responsibility follows at the end of this notebook.

val td = fs.getTrainingDataset("churn_model")
td: com.logicalclocks.hsfs.TrainingDataset = com.logicalclocks.hsfs.TrainingDataset@6042db18
td.getServingVector(Map("customer_id" -> new Integer(1), "contract_id" -> new Integer(100)))
warning: there was one deprecation warning (since 2.12.0); for details, enable `:setting -deprecation' or `:replay -deprecation'
res19: java.util.List[Object] = [1, 100, 10356, Long-term, 0, 0, 1]
td.getServingVector(Map("customer_id" -> new Integer(2), "contract_id" -> new Integer(101)))
warning: there was one deprecation warning (since 2.12.0); for details, enable `:setting -deprecation' or `:replay -deprecation'
res20: java.util.List[Object] = [2, 101, 10017, Short-term, 0, 1, 2]
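
As mentioned in the note above, any category filter has to be applied by the application itself before looking up serving vectors. A minimal sketch of that responsibility, under the assumption that the application keeps track of which (customer_id, contract_id) pairs belong to long-term contracts:

// Hypothetical application-side bookkeeping: only look up long-term contracts.
// These pairs correspond to the "Long-term" contracts in the mock data above.
val longTermKeys = Seq((1, 100), (5, 104), (7, 106))

val servingVectors = longTermKeys.map { case (customerId, contractId) =>
  td.getServingVector(Map(
    "customer_id" -> new Integer(customerId),
    "contract_id" -> new Integer(contractId)
  ))
}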