diff --git a/fraud/1_return.py b/fraud/1_return.py
index 140a3c4..a4a98d8 100644
--- a/fraud/1_return.py
+++ b/fraud/1_return.py
@@ -20,7 +20,7 @@ class Transaction:
     # Computed properties
     clean_memo: str
-    is_nsf: bool
+    is_nsf: bool  # non-sufficient funds
diff --git a/fraud_case_study/datasources.py b/fraud_case_study/datasources.py
new file mode 100644
index 0000000..7383ee5
--- /dev/null
+++ b/fraud_case_study/datasources.py
@@ -0,0 +1,5 @@
+from chalk.sql import SnowflakeSource
+from chalk.streams import KafkaSource
+
+kafka = KafkaSource(name="txns_data_stream", topic="transactions")
+snowflake = SnowflakeSource(name="user_db")
diff --git a/fraud_case_study/features.py b/fraud_case_study/features.py
new file mode 100644
index 0000000..9b47152
--- /dev/null
+++ b/fraud_case_study/features.py
@@ -0,0 +1,39 @@
+from chalk.features import _, features, DataFrame, FeatureTime
+from chalk.streams import Windowed, windowed
+
+
+@features
+class User:
+    id: int
+    first_name: str
+    last_name: str
+    email: str
+    address: str
+    country_of_residence: str
+
+    # these features are aggregated over the last 7, 30, and 90 days
+    avg_txn_amount: Windowed[float] = windowed("7d", "30d", "90d")
+    num_overdrafts: Windowed[int] = windowed("7d", "30d", "90d")
+
+    risk_score: float
+
+    # transactions consists of all Transaction rows that are joined to User
+    # by transaction.user_id
+    transactions: DataFrame["Transaction"]
+
+
+@features
+class Transaction:
+    # these features are loaded directly from the kafka data source
+    id: int
+    user_id: User.id
+    ts: FeatureTime
+    vendor: str
+    description: str
+    amount: float
+    country: str
+    is_overdraft: bool
+
+    # the transaction happened abroad if transaction.country differs
+    # from transaction.user.country_of_residence
+    in_foreign_country: bool = _.country != _.user.country_of_residence
diff --git a/fraud_case_study/kafka_resolver.py b/fraud_case_study/kafka_resolver.py
new file mode 100644
index 0000000..86b79ac
--- /dev/null
+++ b/fraud_case_study/kafka_resolver.py
@@ -0,0 +1,43 @@
+from datetime import datetime
+
+from pydantic import BaseModel
+
+from chalk.features import Features
+from chalk.streams import stream
+from datasources import kafka
+from features import Transaction
+
+
+# Pydantic models define the schema of the messages on the stream.
+class TransactionMessage(BaseModel):
+    id: int
+    user_id: int
+    timestamp: datetime
+    vendor: str
+    description: str
+    amount: float
+    country: str
+    is_overdraft: bool
+
+
+@stream(source=kafka)
+def stream_resolver(message: TransactionMessage) -> Features[
+    Transaction.id,
+    Transaction.user_id,
+    Transaction.ts,
+    Transaction.vendor,
+    Transaction.description,
+    Transaction.amount,
+    Transaction.country,
+    Transaction.is_overdraft,
+]:
+    return Transaction(
+        id=message.id,
+        user_id=message.user_id,
+        ts=message.timestamp,
+        vendor=message.vendor,
+        description=message.description,
+        amount=message.amount,
+        country=message.country,
+        is_overdraft=message.is_overdraft,
+    )
diff --git a/fraud_case_study/other_resolvers.py b/fraud_case_study/other_resolvers.py
new file mode 100644
index 0000000..b124562
--- /dev/null
+++ b/fraud_case_study/other_resolvers.py
@@ -0,0 +1,45 @@
+from chalk import online, DataFrame
+from features import Transaction, User
+from risk import riskclient
+
+
+@online
+def get_avg_txn_amount(txns: DataFrame[Transaction]) -> DataFrame[User.id, User.avg_txn_amount]:
+    # we define a simple aggregation to calculate the average transaction amount
+    # using SQL syntax (https://docs.chalk.ai/docs/aggregations#using-sql);
+    # the time filter is pushed down based on the window definition of the feature
+    return f"""
+        select
+            user_id as id,
+            avg(amount) as avg_txn_amount
+        from {txns}
+        group by 1
+    """
+
+
+@online
+def get_num_overdrafts(txns: DataFrame[Transaction]) -> DataFrame[User.id, User.num_overdrafts]:
+    # we define a simple aggregation to calculate the number of overdrafts
+    # using SQL syntax (https://docs.chalk.ai/docs/aggregations#using-sql);
+    # the time filter is pushed down based on the window definition of the feature
+    return f"""
+        select
+            user_id as id,
+            count(*) as num_overdrafts
+        from {txns}
+        where is_overdraft = true
+        group by 1
+    """
+
+
+@online
+def get_risk_score(
+    first_name: User.first_name,
+    last_name: User.last_name,
+    email: User.email,
+    address: User.address,
+) -> User.risk_score:
+    # we call our internal Risk API to fetch the user's latest calculated
+    # risk score based on their personal information
+    client = riskclient.RiskClient()
+    return client.get_risk_score(first_name, last_name, email, address)
diff --git a/fraud_case_study/users.chalk.sql b/fraud_case_study/users.chalk.sql
new file mode 100644
index 0000000..86276eb
--- /dev/null
+++ b/fraud_case_study/users.chalk.sql
@@ -0,0 +1,5 @@
+-- resolves: User
+-- source: user_db
+
+select id, email, first_name, last_name, address, country_of_residence
+from user_db
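
A minimal usage sketch (not part of the diff): once this is deployed, the features above could be fetched with Chalk's Python client roughly as follows. The user id 1234 is illustrative, and the ["30d"] selector assumes Chalk's bracket syntax for picking one window of a windowed feature.

    from chalk.client import ChalkClient
    from features import User

    client = ChalkClient()

    # query the online store for one user's risk score and
    # 30-day average transaction amount
    result = client.query(
        input={User.id: 1234},
        output=[User.risk_score, User.avg_txn_amount["30d"]],
    )
    print(result.get_feature_value(User.risk_score))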