diff --git a/notebooks/explainability-monitoring-promo-planning.ipynb b/notebooks/explainability-monitoring-promo-planning.ipynb new file mode 100644 index 0000000..7b824e3 --- /dev/null +++ b/notebooks/explainability-monitoring-promo-planning.ipynb @@ -0,0 +1,1185 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Feature Attribution Drift Monitoring with AWS SageMaker Clarity" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "This Jupyter notebook shows how to perform model bias explainability monitoring with AWS SageMaker (based on [docs](https://sagemaker-examples.readthedocs.io/en/latest/sagemaker_model_monitor/fairness_and_explainability/SageMaker-Model-Monitor-Fairness-and-Explainability.html))\n", + "\n", + "\n", + "Aamazon SageMaker Clarify explainability monitoring offers tools to provide global explanations of models and to explain the predictions of a deployed model producing inferences. Such model explanation tools can help ML modelers and developers and other internal stakeholders understand model characteristics as a whole prior to deployment and to debug predictions provided by the model once deployed. The current offering includes a scalable and efficient implementation of [**SHAP**](https://papers.nips.cc/paper/2017/hash/8a20a8621978632d76c43dfd28b67767-Abstract.html), based on the concept of the [**Shapley value**](https://en.wikipedia.org/wiki/Shapley_value) from the field of cooperative game theory that assigns each feature an importance value for a particular prediction.\n", + "\n", + "\n", + "Notebook structure:\n", + "\n", + "## Table of Contents\n", + "1. **[Configuration](#Configuration)** \n", + "1. **[Model and Data Preparation](#Model-and-Data-Preparation)**\n", + "1. **[Deploying model for ML Observability](#Deploying-model-for-ML-Observability)**\n", + "1. **[Generate traffic](#Generate-traffic)** - Provides artifictial traffic for explainability metrics\n", + "1. **[Setting up monitoring job](#Setting-up-monitoring-job)** - Creates Monitoring tasks by creating baseline and scheduling regular monitoring\n", + "1. **[Cleaning up](#Cleaning-up)** - Removes all the created resources.\n", + "\n", + "Prerequisites:\n", + "\n", + "- Existing Roles with all needed permissions (S3, SageMaker, etc.)\n", + "- Configured SageMaker Domain\n", + "- SageMaker Studio user\n", + "- S3 bucket with pretrained model and data\n", + "\n", + "One can use SageMaker Studio or Sagemaker Notebook instances (Jupyter-like environment) to run this notebook. \n", + "To do that, follow the next steps:\n", + "\n", + "1. Run the SageMaker Studio or Create new SageMaker notebook instance\n", + "1. Clone this repository (https://github.com/griddynamics/gd-ml-observability.git)\n", + "\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Imports" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": { + "tags": [] + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Note: you may need to restart the kernel to use updated packages.\n" + ] + } + ], + "source": [ + "%pip install -q --upgrade boto3" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "import copy\n", + "import time\n", + "import pandas as pd\n", + "import threading\n", + "\n", + "from datetime import datetime\n", + "\n", + "from sagemaker import get_execution_role, image_uris, Session\n", + "from sagemaker.clarify import (\n", + " DataConfig,\n", + " ModelConfig,\n", + " SHAPConfig,\n", + ")\n", + "from sagemaker.model import Model\n", + "from sagemaker.model_monitor import (\n", + " CronExpressionGenerator,\n", + " DataCaptureConfig,\n", + " ExplainabilityAnalysisConfig,\n", + " ModelExplainabilityMonitor,\n", + ")\n", + "from sagemaker.predictor import Predictor\n", + "from sagemaker.s3 import S3Downloader, S3Uploader" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Configuration" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": { + "tags": [] + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Execution Role: arn:aws:iam::125667932402:role/service-role/AmazonSageMaker-ExecutionRole-20230117T121938\n", + "AWS region: us-east-1\n", + "Demo Bucket: sagemaker-us-east-1-125667932402\n", + "S3 key: s3://sagemaker-us-east-1-125667932402/sagemaker/shap-observability-promo-planning\n", + "Capture path: s3://sagemaker-us-east-1-125667932402/sagemaker/shap-observability-promo-planning/datacapture\n", + "Report path: s3://sagemaker-us-east-1-125667932402/sagemaker/shap-observability-promo-planning/reports\n", + "Baseline results uri: s3://sagemaker-us-east-1-125667932402/sagemaker/shap-observability-promo-planning/baselining\n" + ] + } + ], + "source": [ + "role = get_execution_role()\n", + "print(f\"Execution Role: {role}\")\n", + "\n", + "sagemaker_session = Session()\n", + "sagemaker_client = sagemaker_session.sagemaker_client\n", + "sagemaker_runtime_client = sagemaker_session.sagemaker_runtime_client\n", + "\n", + "region = sagemaker_session.boto_region_name\n", + "print(f\"AWS region: {region}\")\n", + "\n", + "# A different bucket can be used, but make sure the role for this notebook has\n", + "# the s3:PutObject permissions. This is the bucket into which the data is captured\n", + "bucket = Session().default_bucket()\n", + "print(f\"Demo Bucket: {bucket}\")\n", + "prefix = \"sagemaker/shap-observability-promo-planning\"\n", + "s3_key = f\"s3://{bucket}/{prefix}\"\n", + "print(f\"S3 key: {s3_key}\")\n", + "\n", + "s3_capture_upload_path = f\"{s3_key}/datacapture\"\n", + "s3_report_path = f\"{s3_key}/reports\"\n", + "\n", + "print(f\"Capture path: {s3_capture_upload_path}\")\n", + "print(f\"Report path: {s3_report_path}\")\n", + "\n", + "baseline_results_uri = f\"{s3_key}/baselining\"\n", + "print(f\"Baseline results uri: {baseline_results_uri}\")\n", + "\n", + "endpoint_instance_count = 1\n", + "endpoint_instance_type = \"ml.m5.large\"\n", + "schedule_expression = CronExpressionGenerator.hourly()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Test bucket connectivity" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": { + "tags": [] + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Success! We are all set to proceed.\n" + ] + } + ], + "source": [ + "# Upload a test file\n", + "test_file = 'upload-test-file.txt'\n", + "with open(test_file, 'w') as f:\n", + " f.write('Hello world!\\n')\n", + "\n", + "S3Uploader.upload(test_file, f\"s3://{bucket}/test_upload\")\n", + "print(\"Success! We are all set to proceed.\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Model and Data Preparation" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "For the explainability demo promotion planning dataset is used with pre-trained XGBoost model. \n", + "We need to specify S3 URI for pre-trained model and data which will be used to compute baseline data set for shap." + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "s3_data_uri = 's3://adp-rnd-ml-datasets/promotion-planning/validation/data.csv'\n", + "s3_model_uri = 's3://adp-rnd-ml-models/promotion-planning/model/promotion-planning-train-job-2023-01-31-084806/output/model.tar.gz'" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "dataset_type = 'text/csv'\n", + "\n", + "model_dir = 'model'\n", + "model_file = f'{model_dir}/model.tar.gz'\n", + "S3Downloader.download(s3_model_uri, model_dir)\n", + "\n", + "\n", + "dataset_dir = 'data'\n", + "S3Downloader.download(s3_data_uri, dataset_dir)" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "metadata": { + "tags": [] + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "SHAPE: (57834, 111)\n" + ] + }, + { + "data": { + "text/plain": [ + "['OUTPUT_LABEL',\n", + " 'amt',\n", + " 'oft',\n", + " 'amount_365_days_lag',\n", + " 'off_365_days_lag',\n", + " 'black_friday',\n", + " 'business_day',\n", + " 'cyber_monday',\n", + " 'day_of_week',\n", + " 'day_of_month']" + ] + }, + "execution_count": 12, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "df = pd.read_csv(f'{dataset_dir}/data.csv')\n", + "print('SHAPE:', df.shape)\n", + "\n", + "all_headers = df.columns.tolist()\n", + "label_header = all_headers[0]\n", + "all_headers[:10]" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + " For purposes of this demo we won't be using just a fraction of all available data." + ] + }, + { + "cell_type": "code", + "execution_count": 24, + "metadata": { + "tags": [] + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Test shape: (578, 111)\n", + "Validation shape: (573, 111)\n" + ] + } + ], + "source": [ + "fraction = 0.01\n", + "test_idx = df.sample(frac=fraction).index\n", + "test_data = df[df.index.isin(test_idx)]\n", + "print('Test shape:', test_data.shape)\n", + "test_dataset_dir = 'test'\n", + "test_dataset = f'{test_dataset_dir}/test.csv'\n", + "test_data.drop(label_header, axis=1).to_csv(test_dataset, index=False, header=False)\n", + "\n", + "\n", + "val_data = df[~df.index.isin(test_idx)].sample(frac=fraction)\n", + "print('Validation shape:', val_data.shape)\n", + "validation_dataset_dir = 'validation'\n", + "validation_dataset = f'{validation_dataset_dir}/validation.csv'\n", + "val_data.to_csv(validation_dataset, index=False, header=True)" + ] + }, + { + "cell_type": "code", + "execution_count": 25, + "metadata": { + "tags": [] + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Model file has been uploaded to s3://sagemaker-us-east-1-125667932402/sagemaker/shap-observability-promo-planning/model.tar.gz\n" + ] + } + ], + "source": [ + "model_url = S3Uploader.upload(model_file, s3_key)\n", + "print(f\"Model file has been uploaded to {model_url}\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Deploying model for ML Observability" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Setting up the pre-trained model" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "metadata": { + "tags": [] + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Model name: shap-observability-promo-planning-2023-03-06-1005\n", + "Endpoint name: shap-observability-promo-planning-2023-03-06-1005\n", + "XGBoost image uri: 683313688378.dkr.ecr.us-east-1.amazonaws.com/sagemaker-xgboost:0.90-1-cpu-py3\n" + ] + } + ], + "source": [ + "model_name = f\"shap-observability-promo-planning-{datetime.utcnow():%Y-%m-%d-%H%M}\"\n", + "print(\"Model name: \", model_name)\n", + "endpoint_name = f\"shap-observability-promo-planning-{datetime.utcnow():%Y-%m-%d-%H%M}\"\n", + "print(\"Endpoint name: \", endpoint_name)\n", + "\n", + "image_uri = image_uris.retrieve(\"xgboost\", region, '0.90-1')\n", + "print(f\"XGBoost image uri: {image_uri}\")\n", + "model = Model(\n", + " role=role,\n", + " name=model_name,\n", + " image_uri=image_uri,\n", + " model_data=model_url,\n", + " sagemaker_session=sagemaker_session,\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Setting up data capture config to be able to monitor model bias based on the stored data" + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "data_capture_config = DataCaptureConfig(\n", + " enable_capture=True,\n", + " sampling_percentage=100,\n", + " destination_s3_uri=s3_capture_upload_path,\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Deploying the model to the endpoint that will be monitored" + ] + }, + { + "cell_type": "code", + "execution_count": 17, + "metadata": { + "tags": [] + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Deploying model shap-observability-promo-planning-2023-03-06-1005 to endpoint shap-observability-promo-planning-2023-03-06-1005\n", + "------!" + ] + } + ], + "source": [ + "print(f\"Deploying model {model_name} to endpoint {endpoint_name}\")\n", + "model.deploy(\n", + " initial_instance_count=endpoint_instance_count,\n", + " instance_type=endpoint_instance_type,\n", + " endpoint_name=endpoint_name,\n", + " data_capture_config=data_capture_config,\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Generate traffic" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "If there is no traffic, the monitoring jobs are marked as Failed since there is no data to process. \n", + "So for this example we generate traffic artificial traffic based on test dataset." + ] + }, + { + "cell_type": "code", + "execution_count": 38, + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "\n", + "class WorkerThread(threading.Thread):\n", + " def __init__(self, do_run, *args, **kwargs):\n", + " super(WorkerThread, self).__init__(*args, **kwargs)\n", + " self.__do_run = do_run\n", + " self.__terminate_event = threading.Event()\n", + "\n", + " def terminate(self):\n", + " self.__terminate_event.set()\n", + "\n", + " def run(self):\n", + " while not self.__terminate_event.is_set():\n", + " self.__do_run(self.__terminate_event)\n", + "\n", + "\n", + "def invoke_endpoint(terminate_event):\n", + " with open(test_dataset, \"r\") as f:\n", + " i = 0\n", + " for row in f:\n", + " payload = row.rstrip(\"\\n\")\n", + " response = sagemaker_runtime_client.invoke_endpoint(\n", + " EndpointName=endpoint_name,\n", + " ContentType=\"text/csv\",\n", + " Body=payload,\n", + " InferenceId=str(i), # unique ID per row\n", + " )\n", + " i += 1\n", + " response[\"Body\"].read()\n", + " time.sleep(1)\n", + " if terminate_event.is_set():\n", + " break\n", + "\n", + "\n", + "# Keep invoking the endpoint with test data\n", + "invoke_endpoint_thread = WorkerThread(do_run=invoke_endpoint)\n", + "invoke_endpoint_thread.start()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Setting up monitoring job" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Creating model explainability monitor" + ] + }, + { + "cell_type": "code", + "execution_count": 19, + "metadata": { + "tags": [] + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Explainability baseline s3 uri: s3://sagemaker-us-east-1-125667932402/sagemaker/shap-observability-promo-planning/baselining/model_explainability\n" + ] + } + ], + "source": [ + "model_explainability_monitor = ModelExplainabilityMonitor(\n", + " role=role,\n", + " sagemaker_session=sagemaker_session,\n", + " max_runtime_in_seconds=1800,\n", + ")\n", + "\n", + "\n", + "model_explainability_baselining_job_result_uri = f\"{baseline_results_uri}/model_explainability\"\n", + "print(f'Explainability baseline s3 uri: {model_explainability_baselining_job_result_uri}')" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Creating model and data config" + ] + }, + { + "cell_type": "code", + "execution_count": 26, + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "model_config = ModelConfig(\n", + " model_name=model_name,\n", + " instance_count=endpoint_instance_count,\n", + " instance_type=endpoint_instance_type,\n", + " content_type=dataset_type,\n", + " accept_type=dataset_type,\n", + ")\n", + "\n", + "model_explainability_data_config = DataConfig(\n", + " s3_data_input_path=validation_dataset,\n", + " s3_output_path=model_explainability_baselining_job_result_uri,\n", + " label=label_header,\n", + " headers=all_headers,\n", + " dataset_type=dataset_type,\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Create Shap baseline" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "In order to create explainability baseline, we need to provide a baseline dataset for Kernel Shap algorithm.\n", + "Number of samples determines the size of generated synthetic dataset (if not provided clarify will choose a value based on number of features). There are three ways to aggregate shap values: `mean_abs` (mean of absolute SHAP values), `median` (median of SHAP values for all instances) and `mean_sq` (mean of squared SHAP values)." + ] + }, + { + "cell_type": "code", + "execution_count": 27, + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "test_dataframe = pd.read_csv(test_dataset, header=None)\n", + "shap_baseline = test_dataframe.sample(frac=fraction).values.tolist()\n", + "\n", + "\n", + "shap_config = SHAPConfig(\n", + " baseline=shap_baseline,\n", + " num_samples=50,\n", + " agg_method=\"mean_abs\",\n", + " save_local_shap_values=False,\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### Run explainability baselining job" + ] + }, + { + "cell_type": "code", + "execution_count": 28, + "metadata": { + "tags": [] + }, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "INFO:sagemaker:Creating processing-job with name baseline-suggestion-job-2023-03-06-10-17-13-690\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "ModelExplainabilityMonitor baselining job: baseline-suggestion-job-2023-03-06-10-17-13-690\n" + ] + } + ], + "source": [ + "model_explainability_monitor.suggest_baseline(\n", + " data_config=model_explainability_data_config,\n", + " model_config=model_config,\n", + " explainability_config=shap_config,\n", + ")\n", + "latest_baselining_job_name = model_explainability_monitor.latest_baselining_job_name\n", + "print(f\"ModelExplainabilityMonitor baselining job: {latest_baselining_job_name}\")\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Wait for baselining job and review the constaints" + ] + }, + { + "cell_type": "code", + "execution_count": 29, + "metadata": { + "tags": [] + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + ".......................................................................................................!" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "INFO:botocore.credentials:Found credentials from IAM Role: BaseNotebookInstanceEc2InstanceRole\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "ModelExplainabilityMonitor suggested constraints: s3://sagemaker-us-east-1-125667932402/sagemaker/shap-observability-promo-planning/baselining/model_explainability/analysis.json\n", + "{\n", + " \"version\": \"1.0\",\n", + " \"explanations\": {\n", + " \"kernel_shap\": {\n", + " \"label0\": {\n", + " \"global_shap_values\": {\n", + " \"amt\": 0.10948012214978607,\n", + " \"oft\": 0.10853867704455387,\n", + " \"amount_365_days_lag\": 0.10292516679605697,\n", + " \"off_365_days_lag\": 0.11545188903392774,\n", + " \"black_friday\": 0.09804559208200446,\n", + " \"business_day\": 0.11297093042708263,\n", + " \"cyber_monday\": 0.11255411381564072,\n", + " \"day_of_week\": 0.1071939282195742,\n", + " \"day_of_month\": 0.11620003500504866,\n", + " \"month\": 0.10880362477807314,\n", + " \"is_holiday\": 0.10026392279858574,\n", + " \"list_price\": 0.506711290138479,\n", + " \"list_price_7_days_lag\": 0.11800608684558543,\n", + " \"list_price_30_days_lag\": 0.21218059951008095,\n", + " \"list_price_365_days_lag\": 0.09798197500285986,\n", + " \"post_holiday\": 0.1098804984022009,\n", + " \"pre_holiday\": 0.10016967365175429,\n", + " \"profit_365_days_lag\": 0.13291067261012388,\n", + " \"quantity_365_days_lag\": 0.14637940625479695,\n", + " \"promo_price\": 0.21770283722827494,\n", + " \"promo_price_7_days_lag\": 0.11190892386092048,\n", + " \"promo_price_30_days_lag\": 0.09842779868694215,\n", + " \"promo_price_365_days_lag\": 0.09502923450685478,\n", + " \"purchase_price\": 0.15528941556558437,\n", + " \"ratio_list_price_to_purchase_price\": 0.12373419983583203,\n", + " \"ratio_list_price_to_purchase_price_7_days_lag\": 0.10299942812377823,\n", + " \"ratio_list_price_to_purchase_price_30_days_lag\": 0.09802656115694341,\n", + " \"ratio_list_price_to_purchase_price_365_days_lag\": 0.10960291324109075,\n", + " \"ratio_promo_price_to_list_price\": 0.23390822838115832,\n", + " \"ratio_promo_price_to_list_price_7_days_lag\": 0.11964902540624262,\n", + " \"ratio_promo_price_to_list_price_30_days_lag\": 0.208234085642571,\n", + " \"ratio_promo_price_to_list_price_365_days_lag\": 0.10114835002583304,\n", + " \"ratio_promo_price_to_purchase_price\": 0.11991915780945882,\n", + " \"ratio_promo_price_to_purchase_price_7_days_lag\": 0.09458811055155714,\n", + " \"ratio_promo_price_to_purchase_price_30_days_lag\": 0.09958696575882661,\n", + " \"ratio_promo_price_to_purchase_price_365_days_lag\": 0.09744384475847295,\n", + " \"super_bowl\": 0.10334844558545071,\n", + " \"season_of_year_one_hot_0\": 0.10322065371209216,\n", + " \"season_of_year_one_hot_1\": 0.0978278709937913,\n", + " \"season_of_year_one_hot_2\": 0.10885422587158512,\n", + " \"season_of_year_one_hot_3\": 0.09546615791048331,\n", + " \"occasion_one_hot_0\": 0.1086362024160909,\n", + " \"occasion_one_hot_1\": 0.10174865677631961,\n", + " \"occasion_one_hot_2\": 0.11505962294717757,\n", + " \"occasion_one_hot_3\": 0.091129324919371,\n", + " \"occasion_one_hot_4\": 0.1086840188270564,\n", + " \"occasion_one_hot_5\": 0.11460496065820619,\n", + " \"occasion_one_hot_6\": 0.10497980358354382,\n", + " \"occasion_one_hot_7\": 0.09815689919396274,\n", + " \"dress_length_one_hot_0\": 0.1002487162101105,\n", + " \"dress_length_one_hot_1\": 0.10383538205895343,\n", + " \"dress_length_one_hot_2\": 0.10312540511246755,\n", + " \"dress_length_one_hot_3\": 0.1043003394041985,\n", + " \"dress_types_one_hot_0\": 0.10400076944028507,\n", + " \"dress_types_one_hot_1\": 0.10138649853273446,\n", + " \"dress_types_one_hot_2\": 0.10930426805410717,\n", + " \"dress_types_one_hot_3\": 0.12152324177662033,\n", + " \"dress_types_one_hot_4\": 0.10714419452367374,\n", + " \"dress_types_one_hot_5\": 0.10404574293893364,\n", + " \"dress_types_one_hot_6\": 0.11036050004825468,\n", + " \"dress_types_one_hot_7\": 0.14217532011890302,\n", + " \"dress_types_one_hot_8\": 0.10985604939356246,\n", + " \"dress_types_one_hot_9\": 0.10875167657463633,\n", + " \"dress_types_one_hot_10\": 0.09929935068210678,\n", + " \"dress_types_one_hot_11\": 0.10816445625172055,\n", + " \"material_one_hot_0\": 0.1169215728312945,\n", + " \"material_one_hot_1\": 0.0966299357681652,\n", + " \"material_one_hot_2\": 0.11408043362923098,\n", + " \"material_one_hot_3\": 0.09676312781703811,\n", + " \"material_one_hot_4\": 0.10583165954531816,\n", + " \"material_one_hot_5\": 0.09682183355584198,\n", + " \"material_one_hot_6\": 0.11162891116148937,\n", + " \"material_one_hot_7\": 0.0999889880261619,\n", + " \"material_one_hot_8\": 0.1008431062402366,\n", + " \"material_one_hot_9\": 0.09252033469402106,\n", + " \"material_one_hot_10\": 0.10140153803449072,\n", + " \"neckline_style_one_hot_0\": 0.10034076070054586,\n", + " \"neckline_style_one_hot_1\": 0.10170723693646244,\n", + " \"neckline_style_one_hot_2\": 0.10712473879055205,\n", + " \"neckline_style_one_hot_3\": 0.10804204168326723,\n", + " \"neckline_style_one_hot_4\": 0.09997906266711504,\n", + " \"neckline_style_one_hot_5\": 0.09219191529749457,\n", + " \"neckline_style_one_hot_6\": 0.09906187353225106,\n", + " \"neckline_style_one_hot_7\": 0.10349327712023981,\n", + " \"color_category_one_hot_0\": 0.17467884345743942,\n", + " \"color_category_one_hot_1\": 0.13175136646218338,\n", + " \"color_category_one_hot_2\": 0.10429479313053264,\n", + " \"color_category_one_hot_3\": 0.10172490686121644,\n", + " \"color_category_one_hot_4\": 0.10297211578117726,\n", + " \"color_category_one_hot_5\": 0.10149597782975311,\n", + " \"color_category_one_hot_6\": 0.10139388201224835,\n", + " \"color_category_one_hot_7\": 0.10880969566521541,\n", + " \"color_category_one_hot_8\": 0.1136614344496582,\n", + " \"color_category_one_hot_9\": 0.09186830208395018,\n", + " \"color_category_one_hot_10\": 0.09771373919144166,\n", + " \"color_category_one_hot_11\": 0.11334586331334309,\n", + " \"color_category_one_hot_12\": 0.10017991258039771,\n", + " \"color_category_one_hot_13\": 0.11312960236694848,\n", + " \"color_category_one_hot_14\": 0.11541271978929242,\n", + " \"color_category_one_hot_15\": 0.10327216093405615,\n", + " \"size_one_hot_0\": 0.10785212146689158,\n", + " \"size_one_hot_1\": 0.11141512747936261,\n", + " \"size_one_hot_2\": 0.10979325922577386,\n", + " \"size_one_hot_3\": 0.10691508607399734,\n", + " \"size_one_hot_4\": 0.1067305058871434,\n", + " \"size_one_hot_5\": 0.0971575940917513,\n", + " \"size_one_hot_6\": 0.11754653617217176,\n", + " \"size_one_hot_7\": 0.3089827303796554,\n", + " \"size_one_hot_8\": 0.11105645757546824,\n", + " \"size_one_hot_9\": 0.7021859316571338\n", + " },\n", + " \"expected_value\": 0.42680683235327405\n", + " }\n", + " }\n", + " }\n", + "}\n" + ] + } + ], + "source": [ + "model_explainability_monitor.latest_baselining_job.wait(logs=False)\n", + "model_explainability_constraints = model_explainability_monitor.suggested_constraints()\n", + "if model_explainability_constraints is not None:\n", + " print(\n", + " \"ModelExplainabilityMonitor suggested constraints: \"\n", + " f\"{model_explainability_constraints.file_s3_uri}\"\n", + " )\n", + " print(S3Downloader.read_file(model_explainability_constraints.file_s3_uri))\n", + "\n", + "model_explainability_analysis_config = None\n", + "if not model_explainability_monitor.latest_baselining_job:\n", + " # Remove label because only features are required for the analysis\n", + " headers_without_label_header = copy.deepcopy(all_headers)\n", + " headers_without_label_header.remove(label_header)\n", + " model_explainability_analysis_config = ExplainabilityAnalysisConfig(\n", + " explainability_config=shap_config,\n", + " model_config=model_config,\n", + " headers=headers_without_label_header,\n", + " ) \n", + " " + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Create Monitoring schedule" + ] + }, + { + "cell_type": "code", + "execution_count": 30, + "metadata": { + "tags": [] + }, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "INFO:sagemaker.model_monitor.clarify_model_monitoring:Uploading analysis config to {s3_uri}.\n", + "INFO:sagemaker.model_monitor.model_monitoring:Creating Monitoring Schedule with name: monitoring-schedule-2023-03-06-10-26-10-817\n" + ] + } + ], + "source": [ + "model_explainability_monitor.create_monitoring_schedule(\n", + " output_s3_uri=s3_report_path,\n", + " endpoint_input=endpoint_name,\n", + " schedule_cron_expression=schedule_expression,\n", + " analysis_config=model_explainability_analysis_config\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "You can find created monitoring schedule in SageMaker Studio by navigating to Deployments -> Endpoints section in sidebar, choosing you endpoint and opening `Model explainability` tab.\n", + "![Schedule](images/sm-shap-monitoring-schedule.png)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### Wait for monitoring job to start first execution" + ] + }, + { + "cell_type": "code", + "execution_count": 31, + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "def wait_for_execution_to_start(model_monitor):\n", + " print(\n", + " \"A hourly schedule was created above and it will kick off executions ON the hour (plus 0 - 20 min buffer).\"\n", + " )\n", + "\n", + " print(\"Waiting for the first execution to happen\", end=\"\")\n", + " schedule_desc = model_monitor.describe_schedule()\n", + " while \"LastMonitoringExecutionSummary\" not in schedule_desc:\n", + " schedule_desc = model_monitor.describe_schedule()\n", + " print(\".\", end=\"\", flush=True)\n", + " time.sleep(60)\n", + " print()\n", + " print(\"Done! Execution has been created\")\n", + "\n", + " print(\"Now waiting for execution to start\", end=\"\")\n", + " while schedule_desc[\"LastMonitoringExecutionSummary\"][\"MonitoringExecutionStatus\"] in \"Pending\":\n", + " schedule_desc = model_monitor.describe_schedule()\n", + " print(\".\", end=\"\", flush=True)\n", + " time.sleep(10)\n", + "\n", + " print()\n", + " print(\"Done! Execution has started\")\n", + "\n", + "\n", + "\n", + "# Waits for the schedule to have last execution in a terminal status.\n", + "def wait_for_execution_to_finish(model_monitor):\n", + " schedule_desc = model_monitor.describe_schedule()\n", + " execution_summary = schedule_desc.get(\"LastMonitoringExecutionSummary\")\n", + " if execution_summary is not None:\n", + " print(\"Waiting for execution to finish\", end=\"\")\n", + " while execution_summary[\"MonitoringExecutionStatus\"] not in [\n", + " \"Completed\",\n", + " \"CompletedWithViolations\",\n", + " \"Failed\",\n", + " \"Stopped\",\n", + " ]:\n", + " print(\".\", end=\"\", flush=True)\n", + " time.sleep(60)\n", + " schedule_desc = model_monitor.describe_schedule()\n", + " execution_summary = schedule_desc[\"LastMonitoringExecutionSummary\"]\n", + " print()\n", + " print(\"Done! Execution has finished\")\n", + " else:\n", + " print(\"Last execution not found\")\n" + ] + }, + { + "cell_type": "code", + "execution_count": 36, + "metadata": { + "tags": [] + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "A hourly schedule was created above and it will kick off executions ON the hour (plus 0 - 20 min buffer).\n", + "Waiting for the first execution to happen\n", + "Done! Execution has been created\n", + "Now waiting for execution to start\n", + "Done! Execution has started\n" + ] + } + ], + "source": [ + "wait_for_execution_to_start(model_explainability_monitor)" + ] + }, + { + "cell_type": "code", + "execution_count": 43, + "metadata": { + "tags": [] + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Waiting for execution to finish..\n", + "Done! Execution has finished\n" + ] + } + ], + "source": [ + "wait_for_execution_to_finish(model_explainability_monitor)" + ] + }, + { + "cell_type": "code", + "execution_count": 44, + "metadata": { + "tags": [] + }, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "INFO:botocore.credentials:Found credentials from IAM Role: BaseNotebookInstanceEc2InstanceRole\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Report URI: s3://sagemaker-us-east-1-125667932402/sagemaker/shap-observability-promo-planning/reports/shap-observability-promo-planning-2023-03-06-1005/monitoring-schedule-2023-03-06-10-26-10-817/2023/03/06/17\n", + "Found Report Files:\n", + "s3://sagemaker-us-east-1-125667932402/sagemaker/shap-observability-promo-planning/reports/shap-observability-promo-planning-2023-03-06-1005/monitoring-schedule-2023-03-06-10-26-10-817/2023/03/06/17/analysis.json\n", + " s3://sagemaker-us-east-1-125667932402/sagemaker/shap-observability-promo-planning/reports/shap-observability-promo-planning-2023-03-06-1005/monitoring-schedule-2023-03-06-10-26-10-817/2023/03/06/17/report.html\n", + " s3://sagemaker-us-east-1-125667932402/sagemaker/shap-observability-promo-planning/reports/shap-observability-promo-planning-2023-03-06-1005/monitoring-schedule-2023-03-06-10-26-10-817/2023/03/06/17/report.ipynb\n", + " s3://sagemaker-us-east-1-125667932402/sagemaker/shap-observability-promo-planning/reports/shap-observability-promo-planning-2023-03-06-1005/monitoring-schedule-2023-03-06-10-26-10-817/2023/03/06/17/report.pdf\n" + ] + } + ], + "source": [ + "\n", + "schedule_desc = model_explainability_monitor.describe_schedule()\n", + "execution_summary = schedule_desc.get(\"LastMonitoringExecutionSummary\")\n", + "if execution_summary and execution_summary[\"MonitoringExecutionStatus\"] in [\n", + " \"Completed\",\n", + " \"CompletedWithViolations\",\n", + "]:\n", + " last_model_explainability_monitor_execution = model_explainability_monitor.list_executions()[-1]\n", + " last_model_explainability_monitor_execution_report_uri = (\n", + " last_model_explainability_monitor_execution.output.destination\n", + " )\n", + " print(f\"Report URI: {last_model_explainability_monitor_execution_report_uri}\")\n", + " last_model_explainability_monitor_execution_report_files = sorted(\n", + " S3Downloader.list(last_model_explainability_monitor_execution_report_uri)\n", + " )\n", + " print(\"Found Report Files:\")\n", + " print(\"\\n \".join(last_model_explainability_monitor_execution_report_files))\n", + "else:\n", + " last_model_explainability_monitor_execution = None\n", + " print(\n", + " \"====STOP==== \\n No completed executions to inspect further. Please wait till an execution completes or investigate previously reported failures.\"\n", + " )" + ] + }, + { + "cell_type": "code", + "execution_count": 46, + "metadata": { + "tags": [] + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "Could not retrieve constraints file at location 's3://sagemaker-us-east-1-125667932402/sagemaker/shap-observability-promo-planning/reports/shap-observability-promo-planning-2023-03-06-1005/monitoring-schedule-2023-03-06-10-26-10-817/2023/03/06/17/constraint_violations.json'. To manually retrieve ConstraintViolations object from a given uri, use 'my_model_monitor.constraints(my_s3_uri)' or 'ConstraintViolations.from_s3_uri(my_s3_uri)'\n" + ] + } + ], + "source": [ + "if last_model_explainability_monitor_execution:\n", + " model_explainability_violations = (\n", + " last_model_explainability_monitor_execution.constraint_violations()\n", + " )\n", + " if model_explainability_violations:\n", + " print(model_explainability_violations.body_dict)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Now the feature importance chart is diplayed in SageMaker Studio\n", + "![Chart](images/sm-shap-chart.png)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "You can also view previous Explainability monitoring jobs in `Monitoring job history` tab.\n", + "![Monitoring Jobs History](images/sm-shap-monitoring-job-history.png)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Sagemaker also provides reports, which are saved to S3 to specified `s3_report_path` in multiple formats including HTML and PDF.\n", + "![Report](images/sm-shap-explainability-report.png)" + ] + }, + { + "cell_type": "code", + "execution_count": 47, + "metadata": { + "tags": [] + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Report path: s3://sagemaker-us-east-1-125667932402/sagemaker/shap-observability-promo-planning/reports\n" + ] + } + ], + "source": [ + "print(f\"Report path: {s3_report_path}\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Cleaning up" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "If there is no plan to use the endpoint further, it should be deleted to avoid incurring additional charges. Note that deleting endpoint does not delete the data that was captured during the model invocations." + ] + }, + { + "cell_type": "code", + "execution_count": 48, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "Stopping Monitoring Schedule with name: monitoring-schedule-2023-03-06-10-26-10-817\n", + "\n", + "Deleting Monitoring Schedule with name: monitoring-schedule-2023-03-06-10-26-10-817\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "INFO:sagemaker.model_monitor.clarify_model_monitoring:Deleting Model Explainability Job Definition with name: model-explainability-job-definition-2023-03-06-10-26-10-817\n" + ] + } + ], + "source": [ + "\n", + "model_explainability_monitor.stop_monitoring_schedule()\n", + "model_explainability_monitor.delete_monitoring_schedule()" + ] + }, + { + "cell_type": "code", + "execution_count": 50, + "metadata": {}, + "outputs": [], + "source": [ + "invoke_endpoint_thread.terminate()\n", + "\n", + "predictor = Predictor(endpoint_name, sagemaker_session=sagemaker_session)\n", + "predictor.delete_endpoint()\n", + "predictor.delete_model()" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "observability", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.9.16" + }, + "vscode": { + "interpreter": { + "hash": "af6718834c1df5537d44ab7efc014a903af882c9f942fc4839abc2d07dc9a719" + } + } + }, + "nbformat": 4, + "nbformat_minor": 4 +} diff --git a/notebooks/images/sm-shap-chart.png b/notebooks/images/sm-shap-chart.png new file mode 100644 index 0000000..f3db30e Binary files /dev/null and b/notebooks/images/sm-shap-chart.png differ diff --git a/notebooks/images/sm-shap-explainability-report.png b/notebooks/images/sm-shap-explainability-report.png new file mode 100644 index 0000000..011a8db Binary files /dev/null and b/notebooks/images/sm-shap-explainability-report.png differ diff --git a/notebooks/images/sm-shap-monitoring-job-history.png b/notebooks/images/sm-shap-monitoring-job-history.png new file mode 100644 index 0000000..c9bb80a Binary files /dev/null and b/notebooks/images/sm-shap-monitoring-job-history.png differ diff --git a/notebooks/images/sm-shap-monitoring-schedule.png b/notebooks/images/sm-shap-monitoring-schedule.png new file mode 100644 index 0000000..d806723 Binary files /dev/null and b/notebooks/images/sm-shap-monitoring-schedule.png differ