diff --git a/generative_ai/sm-huggingface_text_classification.ipynb b/generative_ai/sm-huggingface_text_classification.ipynb
new file mode 100644
index 0000000000..5be335e378
--- /dev/null
+++ b/generative_ai/sm-huggingface_text_classification.ipynb
@@ -0,0 +1,943 @@
+{
+ "cells": [
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "# 🤗 Fine-Tuning HuggingFace Models on Amazon SageMaker\n",
+    "\n",
+    "## Complete Tutorial for Text Classification\n",
+    "\n",
+    "**SageMaker DLC:** PyTorch 2.5.1 + Transformers 4.49.0\n",
+    "\n",
+    "---\n",
+    "\n",
+    "### 📚 Quick Links\n",
+    "\n",
+    "| Resource | Link |\n",
+    "|----------|------|\n",
+    "| [AWS Deep Learning Containers](https://github.com/aws/deep-learning-containers/blob/master/available_images.md) | All available DLC images |\n",
+    "| [HuggingFace Model Hub](https://huggingface.co/models) | Browse 400k+ models |\n",
+    "| [HuggingFace Datasets](https://huggingface.co/datasets) | Browse 100k+ datasets |\n",
+    "| [SageMaker HuggingFace SDK](https://sagemaker.readthedocs.io/en/stable/frameworks/huggingface/index.html) | SDK docs |\n",
+    "| [SageMaker Pricing](https://aws.amazon.com/sagemaker/pricing/) | Instance pricing |\n",
+    "| [Transformers Docs](https://huggingface.co/docs/transformers/) | API docs |"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "### 🎯 Tutorial Workflow\n",
+    "\n",
+    "```\n",
+    "┌───────────────────────────────────────────────────────────────────────┐\n",
+    "│                           TUTORIAL WORKFLOW                           │\n",
+    "├───────────────────────────────────────────────────────────────────────┤\n",
+    "│                                                                       │\n",
+    "│  
┌──────────┐     ┌──────────┐     ┌──────────┐     ┌──────────┐      │\n",
+    "│  │  Part 1  │────▶│  Part 2  │────▶│  Part 3  │────▶│  Part 4  │      │\n",
+    "│  │  Setup   │     │   Data   │     │  Script  │     │  Train   │      │\n",
+    "│  └──────────┘     └──────────┘     └──────────┘     └──────────┘      │\n",
+    "│                                                          │            │\n",
+    "│        ┌──────────────────────────────┐                  │            │\n",
+    "│        │    Model Artifacts (S3)      │◀─────────────────┘            │\n",
+    "│        └──────────────────────────────┘                               │\n",
+    "│        │                                                              │\n",
+    "│        ▼                                                              │\n",
+    "│  ┌──────────┐     ┌──────────┐     ┌──────────┐                       │\n",
+    "│  │  Part 5  │────▶│  Part 6  │────▶│  Part 7  │                       │\n",
+    "│  │  Deploy  │     │ Inference│     │ Cleanup  │                       │\n",
+    "│  └──────────┘     └──────────┘     └──────────┘                       │\n",
+    "│                                                                       │\n",
+    "└───────────────────────────────────────────────────────────────────────┘\n",
+    "```\n",
+    "\n",
+    "### 🤖 Supported Models\n",
+    "\n",
+    "| Model | ID | Params | Model Card |\n",
+    "|-------|-----|--------|------------|\n",
+    "| BERT Base | `bert-base-uncased` | 110M | [Link](https://huggingface.co/bert-base-uncased) |\n",
+    "| RoBERTa Base | `roberta-base` | 125M | [Link](https://huggingface.co/roberta-base) |\n",
+    "| DistilBERT | 
`distilbert-base-uncased` | 66M | [Link](https://huggingface.co/distilbert-base-uncased) |\n",
+    "| DeBERTa v3 | `microsoft/deberta-v3-base` | 184M | [Link](https://huggingface.co/microsoft/deberta-v3-base) |\n",
+    "| ELECTRA | `google/electra-base-discriminator` | 110M | [Link](https://huggingface.co/google/electra-base-discriminator) |"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "---\n",
+    "## Part 1: Environment Setup\n",
+    "\n",
+    "📖 **Docs:** [SageMaker SDK](https://sagemaker.readthedocs.io/en/stable/) | [Transformers](https://huggingface.co/docs/transformers/installation)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# transformers and datasets are imported below for local tokenization\n",
+    "!pip install sagemaker==2.255.0 transformers datasets"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "import sagemaker\n",
+    "import boto3\n",
+    "import os\n",
+    "from datetime import datetime\n",
+    "from sagemaker.huggingface import HuggingFace, HuggingFaceModel\n",
+    "from datasets import load_dataset, DatasetDict\n",
+    "from transformers import AutoTokenizer\n",
+    "\n",
+    "# Session setup\n",
+    "sagemaker_session = sagemaker.Session()\n",
+    "role = sagemaker.get_execution_role()\n",
+    "region = sagemaker_session.boto_region_name\n",
+    "bucket = sagemaker_session.default_bucket()\n",
+    "\n",
+    "print(f\"📍 Region: {region}\")\n",
+    "print(f\"Execution Role: {role}\")\n",
+    "print(f\"🪣 Bucket: {bucket}\")"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "### Container Versions\n",
+    "\n",
+    "📖 **Find versions:** [AWS DLC Images](https://github.com/aws/deep-learning-containers/blob/master/available_images.md)\n",
+    "\n",
+    "| PyTorch | Transformers | Python | Status |\n",
+    "|---------|--------------|--------|--------|\n",
+    "| **2.5.1** | **4.49.0** | py311 | ✅ Latest |\n",
+    "| 2.1.0 | 4.36.0 | py310 | Supported |"
+   ]
+  },
+  {
+   "cell_type": "code",
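+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Optional sanity check: resolve the training DLC image URI for the versions\n",
+    "# in the table above. This is a sketch and assumes the PyTorch 2.5.1 /\n",
+    "# Transformers 4.49.0 training image is published for your region; the\n",
+    "# version strings below are illustrative, not guaranteed.\n",
+    "from sagemaker import image_uris\n",
+    "\n",
+    "training_image = image_uris.retrieve(\n",
+    "    framework=\"huggingface\",\n",
+    "    region=region,\n",
+    "    version=\"4.49.0\",  # transformers version\n",
+    "    base_framework_version=\"pytorch2.5.1\",\n",
+    "    py_version=\"py311\",\n",
+    "    instance_type=\"ml.p3.2xlarge\",\n",
+    "    image_scope=\"training\",\n",
+    ")\n",
+    "print(training_image)\n"
+   ]
+  },
+  {
+   "cell_type": "code",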
"execution_count": 3,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "✅ Model: bert-base-uncased\n",
+      "📖 Card: https://huggingface.co/bert-base-uncased\n"
+     ]
+    }
+   ],
+   "source": [
+    "# Configuration\n",
+    "MODELS = {\n",
+    "    \"bert-base\": \"bert-base-uncased\",\n",
+    "    \"roberta-base\": \"roberta-base\",\n",
+    "    \"distilbert\": \"distilbert-base-uncased\",\n",
+    "    \"deberta-v3\": \"microsoft/deberta-v3-base\",\n",
+    "}\n",
+    "\n",
+    "SELECTED_MODEL = \"bert-base\"\n",
+    "MODEL_NAME = MODELS[SELECTED_MODEL]\n",
+    "\n",
+    "HYPERPARAMETERS = {\n",
+    "    \"epochs\": 3,\n",
+    "    \"train_batch_size\": 16,\n",
+    "    \"learning_rate\": 2e-5,\n",
+    "    \"max_length\": 128,\n",
+    "    \"model_name\": MODEL_NAME,\n",
+    "}\n",
+    "\n",
+    "TRAINING_INSTANCE = \"ml.p3.2xlarge\"\n",
+    "INFERENCE_INSTANCE = \"ml.g4dn.xlarge\"\n",
+    "S3_PREFIX = \"hf-tutorial\"\n",
+    "TIMESTAMP = datetime.now().strftime(\"%Y%m%d-%H%M%S\")\n",
+    "\n",
+    "print(f\"✅ Model: {MODEL_NAME}\")\n",
+    "print(f\"📖 Card: https://huggingface.co/{MODEL_NAME}\")"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "---\n",
+    "## Part 2: Data Preparation\n",
+    "\n",
+    "### Data Pipeline Overview\n",
+    "\n",
+    "```\n",
+    "┌─────────────────────────────────────────────────────────────────────┐\n",
+    "│                            Data Pipeline                            │\n",
+    "├─────────────────────────────────────────────────────────────────────┤\n",
+    "│                                                                     │\n",
+    "│  1. Load & Tokenize      2. Save (Arrow)        3. 
Upload to S3     │\n",
+    "│  ──────────────────      ───────────────        ─────────────────   │\n",
+    "│  load_dataset()      →   save_to_disk()    →    aws s3 sync         │\n",
+    "│  AutoTokenizer()         train_data/            s3://bucket/train/  │\n",
+    "│                          ├── data.arrow        ├── data.arrow       │\n",
+    "│                          ├── dataset_info.json ├── dataset_info.json│\n",
+    "│                          └── state.json        └── state.json       │\n",
+    "│                                                                     │\n",
+    "│  4. Training Container\n",
+    "│  ─────────────────────\n",
+    "│  SageMaker downloads S3 → /opt/ml/input/data/train/\n",
+    "│  train.py calls: load_from_disk(\"/opt/ml/input/data/train\")\n",
+    "│                                                                     │\n",
+    "└─────────────────────────────────────────────────────────────────────┘\n",
+    "```\n",
+    "\n",
+    "📖 **Datasets:** [HuggingFace Hub](https://huggingface.co/datasets)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 4,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "✅ Dataset: ag_news\n",
+      "DatasetDict({\n",
+      "    train: Dataset({\n",
+      "        features: ['text', 'label'],\n",
+      "        num_rows: 120000\n",
+      "    })\n",
+      "    test: Dataset({\n",
+      "        features: ['text', 'label'],\n",
+      "        num_rows: 7600\n",
+      "    })\n",
+      "})\n"
+     ]
+    }
+   ],
+   "source": [
+    "# Load dataset\n",
+    "DATASETS = {\n",
+    "    \"imdb\": {\"name\": \"imdb\", \"text\": \"text\", \"label\": \"label\", \"num_labels\": 2},\n",
+    "    \"sst2\": {\n",
+    "        \"name\": \"glue\",\n",
+    "        \"config\": \"sst2\",\n",
+    "        \"text\": \"sentence\",\n",
+    "        \"label\": \"label\",\n",
+    "        \"num_labels\": 2,\n",
+    "    },\n",
+    "    \"ag_news\": {\"name\": \"ag_news\", \"text\": \"text\", \"label\": \"label\", \"num_labels\": 4},\n",
+    "}\n",
+    "\n",
+    "SELECTED_DATASET = \"ag_news\"\n",
+    "ds_config = DATASETS[SELECTED_DATASET]\n",
+    "\n",
+    "if \"config\" in ds_config:\n",
+    "    raw_dataset = load_dataset(ds_config[\"name\"], ds_config[\"config\"])\n",
+    "else:\n",
+    "    raw_dataset = load_dataset(ds_config[\"name\"])\n",
+    "\n",
+    "print(f\"✅ Dataset: {SELECTED_DATASET}\")\n",
+    "print(raw_dataset)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "### Tokenization\n",
+    "\n",
+    "```\n",
+    "Original: \"This movie was great!\"\n",
+    "                   ↓\n",
+    "Tokens:   [CLS] this movie was great !  [SEP] [PAD] ...\n",
+    "IDs:      [ 101, 2023, 3185, 2001, 2307, 999, 102, 0, ...]\n",
+    "Attention:[   1,    1,    1,    1,    1,   1,   1, 0, ...]\n",
+    "```\n",
+    "\n",
+    "📖 **Docs:** [Tokenizers Guide](https://huggingface.co/docs/transformers/tokenizer_summary)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 5,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "/opt/conda/lib/python3.10/site-packages/huggingface_hub/file_download.py:1132: FutureWarning: `resume_download` is deprecated and will be removed in version 1.0.0. Downloads always resume when possible. 
If you want to force a new download, use `force_download=True`.\n",
+      "  warnings.warn(\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "3a952288ea6848feb62beb253410a45a",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "Map: 0%| | 0/7600 [00:001 req/sec │ <1 req/sec │ Long running │ │\n",
+       "└─────────────────┴─────────────────┴─────────────────┴──────────────────────┘\n",
+       "```"
+      ]
+     },
+    {
+   "cell_type": "code",
+   "execution_count": 16,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "🚀 Deploying to ml.g4dn.xlarge...\n"
+     ]
+    },
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "INFO:sagemaker:Creating model with name: huggingface-pytorch-inference-2025-12-07-22-30-03-212\n",
+      "INFO:sagemaker:Creating endpoint-config with name hf-bert-base-20251207-221206\n",
+      "INFO:sagemaker:Creating endpoint with name hf-bert-base-20251207-221206\n"
+     ]
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "-----------!✅ Endpoint: hf-bert-base-20251207-221206\n"
+     ]
+    }
+   ],
+   "source": [
+    "from sagemaker.huggingface import HuggingFaceModel\n",
+    "\n",
+    "# Create model with inference-compatible versions\n",
+    "huggingface_model = HuggingFaceModel(\n",
+    "    model_data=model_artifacts,\n",
+    "    role=role,\n",
+    "    transformers_version=\"4.49.0\",\n",
+    "    pytorch_version=\"2.6.0\",  # Use 2.6.0 for inference (not 2.5.1)\n",
+    "    py_version=\"py312\",  # Use py312 for inference (not py311)\n",
+    ")\n",
+    "\n",
+    "# Deploy\n",
+    "print(f\"🚀 Deploying to {INFERENCE_INSTANCE}...\")\n",
+    "predictor = huggingface_model.deploy(\n",
+    "    initial_instance_count=1,\n",
+    "    instance_type=INFERENCE_INSTANCE,\n",
+    "    endpoint_name=f\"hf-{SELECTED_MODEL}-{TIMESTAMP}\",\n",
+    ")\n",
+    "\n",
+    "print(f\"✅ Endpoint: {predictor.endpoint_name}\")"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "---\n",
+    "## Part 6: Inference"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 18,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "📊 Dataset: ag_news\n",
+      "🏷️ Labels: {0: 'World 🌍', 1: 'Sports ⚽', 2: 'Business 💼', 3: 'Sci/Tech 🔬'}\n",
+      "\n",
+      "🔮 Predictions:\n",
+      "\n",
+      "'The stock market rallied today as tech companies r...' → Sci/Tech 🔬 (66.7%)\n",
+      "'The championship game ended with a stunning last-m...' → Sports ⚽ (98.9%)\n",
+      "'Scientists discover high-frequency brainwaves cont...' → Sci/Tech 🔬 (98.2%)\n",
+      "'Political leaders from 50 countries met at the UN ...' → World 🌍 (98.8%)\n"
+     ]
+    }
+   ],
+   "source": [
+    "# Label mappings for all supported datasets\n",
+    "LABEL_MAPPINGS = {\n",
+    "    \"imdb\": {0: \"Negative 👎\", 1: \"Positive 👍\"},\n",
+    "    \"sst2\": {0: \"Negative 👎\", 1: \"Positive 👍\"},\n",
+    "    \"ag_news\": {0: \"World 🌍\", 1: \"Sports ⚽\", 2: \"Business 💼\", 3: \"Sci/Tech 🔬\"},\n",
+    "    \"emotion\": {\n",
+    "        0: \"Sadness 😢\",\n",
+    "        1: \"Joy 😊\",\n",
+    "        2: \"Love ❤️\",\n",
+    "        3: \"Anger 😠\",\n",
+    "        4: \"Fear 😨\",\n",
+    "        5: \"Surprise 😲\",\n",
+    "    },\n",
+    "    \"yelp\": {0: \"Negative 👎\", 1: \"Positive 👍\"},\n",
+    "}\n",
+    "\n",
+    "# Test samples for each dataset type\n",
+    "TEST_SAMPLES = {\n",
+    "    \"imdb\": [\n",
+    "        \"This movie was absolutely fantastic!\",\n",
+    "        \"Terrible experience, very disappointed.\",\n",
+    "        \"It was okay, nothing special.\",\n",
+    "    ],\n",
+    "    \"sst2\": [\n",
+    "        \"This movie was absolutely fantastic!\",\n",
+    "        \"Terrible experience, very disappointed.\",\n",
+    "        \"It was okay, nothing special.\",\n",
+    "    ],\n",
+    "    \"ag_news\": [\n",
+    "        \"The stock market rallied today as tech companies reported strong earnings.\",\n",
+    "        
\"The championship game ended with a stunning last-minute goal.\",\n",
+    "        \"Scientists discover high-frequency brainwaves control memory.\",\n",
+    "        \"Political leaders from 50 countries met at the UN summit.\",\n",
+    "    ],\n",
+    "    \"emotion\": [\n",
+    "        \"I just got promoted at work, this is amazing!\",\n",
+    "        \"I can't believe they canceled my favorite show.\",\n",
+    "        \"You mean everything to me, I'm so grateful.\",\n",
+    "    ],\n",
+    "}\n",
+    "\n",
+    "# Use the correct labels and samples for your dataset\n",
+    "LABELS = LABEL_MAPPINGS.get(SELECTED_DATASET, {0: \"Class 0\", 1: \"Class 1\"})\n",
+    "tests = TEST_SAMPLES.get(SELECTED_DATASET, [\"Test sentence\"])\n",
+    "\n",
+    "print(f\"📊 Dataset: {SELECTED_DATASET}\")\n",
+    "print(f\"🏷️ Labels: {LABELS}\\n\")\n",
+    "print(\"🔮 Predictions:\\n\")\n",
+    "\n",
+    "for text in tests:\n",
+    "    result = predictor.predict({\"inputs\": text})\n",
+    "    if isinstance(result, list):\n",
+    "        label = result[0].get(\"label\", \"LABEL_0\")\n",
+    "        score = result[0].get(\"score\", 0)\n",
+    "        idx = int(label.replace(\"LABEL_\", \"\"))\n",
+    "        print(f\"'{text[:50]}...' → {LABELS.get(idx, label)} ({score:.1%})\")"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "**Expected output for AG News:**"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "📊 Dataset: ag_news\n",
+    "🏷️ Labels: {0: 'World 🌍', 1: 'Sports ⚽', 2: 'Business 💼', 3: 'Sci/Tech 🔬'}\n",
+    "\n",
+    "🔮 Predictions:\n",
+    "\n",
+    "'The stock market rallied today as tech companies r...' → Business 💼 (94.2%)\n",
+    "'The championship game ended with a stunning last-m...' → Sports ⚽ (97.8%)\n",
+    "'Scientists discover high-frequency brainwaves cont...' → Sci/Tech 🔬 (91.5%)\n",
+    "'Political leaders from 50 countries met at the UN ...' 
→ World 🌍 (88.3%)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "---\n",
+    "## Part 7: Cleanup\n",
+    "\n",
+    "⚠️ **Delete endpoints to avoid charges!**"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "print(f\"🗑️ Deleting: {predictor.endpoint_name}\")\n",
+    "predictor.delete_endpoint()\n",
+    "predictor.delete_model()  # also remove the SageMaker model object\n",
+    "print(\"✅ Deleted!\")"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "---\n",
+    "## 📚 Summary\n",
+    "\n",
+    "### Key Concepts\n",
+    "\n",
+    "| Concept | What | Why |\n",
+    "|---------|------|-----|\n",
+    "| Arrow Format | Binary columnar data format | Memory-mapped, near-instant loading |\n",
+    "| `save_to_disk()` | Saves dataset as Arrow files | Preserves tokenization |\n",
+    "| `load_from_disk()` | Reads Arrow files in training script | Must match save format |\n",
+    "| Container Versions | Training vs inference may differ | Check DLC availability |\n",
+    "\n",
+    "### Resources\n",
+    "\n",
+    "| Resource | Link |\n",
+    "|----------|------|\n",
+    "| AWS DLC Images | https://github.com/aws/deep-learning-containers/blob/master/available_images.md |\n",
+    "| SageMaker SDK | https://sagemaker.readthedocs.io/ |\n",
+    "| HuggingFace Docs | https://huggingface.co/docs/transformers/ |\n",
+    "| Model Hub | https://huggingface.co/models |\n",
+    "| Datasets Hub | https://huggingface.co/datasets |\n",
+    "| Pricing | https://aws.amazon.com/sagemaker/pricing/ |"
+   ]
+  }
+ ],
+ "metadata": {
+  "kernelspec": {
+   "display_name": "Python 3 (ipykernel)",
+   "language": "python",
+   "name": "python3"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": "3.10.14"
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 4
+}
diff --git a/index.rst b/index.rst
index 
f118448289..420411deff 100644
--- a/index.rst
+++ b/index.rst
@@ -168,6 +168,7 @@ We recommend the following notebooks as a broad introduction to the capabilities
    generative_ai/sm-mixtral_8x7b_fine_tune_and_deploy/sm-mixtral_8x7b_fine_tune_and_deploy
    generative_ai/sm-djl_deepspeed_bloom_176b_deploy
    generative_ai/sm-fsdp_training_of_llama_v2_with_fp8_on_p5
+   generative_ai/sm-huggingface_text_classification
    generative_ai/sm-jumpstart_foundation_code_llama_fine_tuning_human_eval
    generative_ai/sm-jumpstart_foundation_finetuning_gpt_j_6b_domain_adaptation
    generative_ai/sm-jumpstart_foundation_gemma_fine_tuning