Getting Started with Intel’s PyTorch Extension for Arc GPUs on Windows

pytorch
windows
image-classification
arc-gpu
getting-started
This tutorial provides a step-by-step guide to setting up Intel’s PyTorch extension on Windows to train models with Arc GPUs.
Author

Christian Mills

Published

September 9, 2023

Introduction

In this tutorial, I’ll guide you through setting up Intel’s PyTorch extension on Windows to train models with their Arc GPUs. The extension provides Intel’s latest feature optimizations and hardware support before they get added to PyTorch. Most importantly for our case, it now includes experimental support for Intel’s Arc GPUs and optimizations to take advantage of their Xe Matrix Extensions (XMX).

The XMX engines are dedicated hardware for performing matrix operations like those in deep-learning workloads. Intel’s PyTorch extension allows us to leverage this hardware with minimal changes to existing PyTorch code.

To illustrate this, we’ll adapt the training code from my beginner-level PyTorch tutorial, where we fine-tune an image classification model from the timm library for hand gesture recognition. By the end of this tutorial, you’ll know all the steps required to set up Windows for training PyTorch models on Arc GPUs.

Note

The current setup process is for version 2.0.110+xpu of Intel’s PyTorch extension.

Enable Resizable BAR in BIOS

If you have an Arc GPU, one of the first things you should do is enable Resizable BAR. Resizable BAR allows a computer’s processor to access the graphics card’s entire memory instead of in small chunks. The Arc GPUs currently require this feature to perform as intended. You can enable the feature in your motherboard’s BIOS.

Here are links on how to do this for some of the popular motherboard manufacturers:

With Resizable BAR enabled, let’s ensure we have the latest drivers for our Arc GPU.

Install Drivers

We can download the latest Arc GPU drivers from Intel’s website at the link below:

The latest driver version available was 31.0.101.4676 at the time of writing. Click the Download button under Available Downloads to download the installer.

Once downloaded, double-click the installer executable and follow the prompts to install the drivers.

You don’t need to agree to join the Intel Computing Improvement program to install the drivers.

Once the installation completes, click the Reboot Recommended button to reboot the computer.

We can continue with the next step once we’re back in Windows.

Install Visual Studio

The oneAPI Base Toolkit, which Intel’s extension depends on, requires Visual Studio with the Desktop development with C++ workload for some of its packages to function. We can download the installer for the free Community version of Visual Studio at the link below:

There is a compatibility issue with the latest version of oneAPI (2023.2) and the latest version of Visual Studio 2022 that causes the compilation process for the extension to fail. You can download a version of Visual Studio 2019 from the link below if you want to compile the extension from its source code.

Once downloaded, double-click the installer executable and click Continue in the popup window.

Select the Desktop development with C++ workload under Desktop & Mobile and click Install in the bottom-right corner.

Once the installation finishes, Visual Studio will launch and prompt you to sign in. We can either skip that step or exit Visual Studio completely.

Next, Visual Studio will prompt us to personalize the theme and development settings. We don’t need to use Visual Studio directly, so we stick with the defaults and exit the application once finished.

Install oneAPI Base Toolkit

With the prerequisites satisfied, we can install the oneAPI Base Toolkit. The required packages in the toolkit take up approximately 13 GB of disk space.

Download the oneAPI Toolkit Installer

We’ll install the toolkit using the offline Windows installer available at the link below:

The download page has a form to register with an email address. We can skip this by clicking the Continue without signing up text below the Sign Up & Download button.

Once downloaded, double-click the installer executable and click the Extract button in the popup window.

Once extracted, click the Continue button in the installer window.

Tick the checkbox on the next page to accept the license agreement and select the Custom Installation option. We can save a decent amount of space by only installing the necessary components.

For Intel’s PyTorch extension to function, we need the following components:

The next page allows us to integrate the oneAPI toolkit with installations of Visual Studio. We don’t need to make any changes here.

As with the Intel Computing Improvement program for the Graphics drivers, we don’t need to consent to the Intel Software Improvement program to use the toolkit. Go ahead and click the Install button to start the installation.

The installation process will likely take several minutes.

Click the Finish button in the popup window once the toolkit successfully installs and exit the installer window.

Disable Integrated Graphics

I encountered the following error when I first attempted to use Intel’s PyTorch extension on Windows:

RuntimeError: Native API failed. Native API returns: -997 (Command failed to enqueue/execute) -997 (Command failed to enqueue/execute)

I could only resolve the issue by deactivating the iGPU in the Windows Device Manager. Future versions of the extension will likely eliminate this bug, but I’ll include the steps to turn off the iGPUs for now.

Type Device Manager into the Windows search bar and click Open.

You will see the following popup message if you are not using an Administrator account. Click OK to continue.

In the Device Manager window, open the Display adapters dropdown. There should be at least two options: the Arc GPU and the iGPU included with the CPU. Double-click the iGPU listing to open its properties window.

Non-Administrator users must click the Change settings button to enable changes to the iGPU device properties.

Next, select the Driver tab and click the Disable Device button.

Click Yes in the popup window to confirm the changes.

Note

You will need to repeat this step when you install new graphics drivers in the future, assuming future releases of Intel’s extension do not resolve the issue.

Set Up a Python Environment

Now, we can create a Python environment to run the training code. We’ll install a patched version of PyTorch needed for Intel’s extension, the extension itself, and the other dependencies for the training code.

Install Mamba Package Manager

We’ll use the Mamba package manager to create the Python environment. I have a dedicated tutorial for setting up Mamba on Windows:

Open a command prompt window with the mamba environment active and navigate to a folder to store the training notebooks. For convenience, here is the command to activate the mamba environment from any command prompt window:

%USERPROFILE%\mambaforge\Scripts\activate

Create a Python Environment

Next, we’ll create a Python environment and activate it. The current version of the extension supports Python 3.11, so we’ll use that.

mamba create --name pytorch-arc python=3.11 -y
mamba activate pytorch-arc

Activate the oneAPI Environment

Run the following command to activate the oneAPI environment for this command prompt session:

call "C:\Program Files (x86)\Intel\oneAPI\setvars.bat"

We must reactivate the oneAPI environment whenever we open a new command prompt. It does not carry over from one terminal window to another.
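Since the oneAPI environment only lasts for the current command prompt session, it can help to confirm it is active before launching Python tooling. The sketch below is one way to check from Python. It assumes that setvars.bat exports a SETVARS_COMPLETED variable and that ONEAPI_ROOT points at the install folder. Neither name comes from the steps above, so verify them by running `set` in your own session.

```python
import os


def oneapi_env_active(env=None):
    """Return True if the environment looks like setvars.bat has been run.

    Assumption: setvars.bat sets SETVARS_COMPLETED=1 for the session.
    """
    env = os.environ if env is None else env
    return env.get("SETVARS_COMPLETED") == "1"


if __name__ == "__main__":
    if oneapi_env_active():
        # ONEAPI_ROOT is also an assumed variable name; fall back to a placeholder.
        root = os.environ.get("ONEAPI_ROOT", "(ONEAPI_ROOT not set)")
        print("oneAPI environment appears active:", root)
    else:
        print("oneAPI environment not detected; run setvars.bat in this command prompt first.")
```

Running this from the same command prompt used for the remaining steps gives a quick sanity check before installing packages or starting Jupyter.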

Install Training Code Dependencies

After that, we’ll install the training code dependencies. You can learn about these dependencies (here).

pip install jupyter matplotlib pandas pillow timm torcheval torchtnt tqdm
pip install cjm_pandas_utils cjm_psl_utils cjm_pil_utils cjm_pytorch_utils

Install Prerequisite Packages

The package for Intel’s PyTorch extension requires a couple of conda packages:

mamba install pkg-config libuv -y

Install PyTorch and Intel’s PyTorch extension

The following command will install the patched version of PyTorch and the extension itself:

python -m pip install -U torch==2.0.0a0 intel_extension_for_pytorch==2.0.110+gitba7f6c1 -f https://developer.intel.com/ipex-whl-stable-xpu
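To confirm that pip installed both packages into the active environment, a short check like the following may be useful. It is a minimal sketch that uses only the standard library and the distribution names from the pip command above. The helper name installed_versions is mine, not part of either package.

```python
from importlib import metadata


def installed_versions(packages):
    """Map each distribution name to its installed version, or None if missing."""
    versions = {}
    for name in packages:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = None
    return versions


# Distribution names taken from the pip command above.
for name, version in installed_versions(["torch", "intel_extension_for_pytorch"]).items():
    print(f"{name}: {version or 'not installed'}")
```

If either package prints as not installed, the most likely cause is that the pip command ran in a different environment than pytorch-arc.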

Intel does not currently have a pre-compiled version of TorchVision available for Windows. We can compile a patched version of TorchVision from the source code using the instructions below:

Install Additional Dependencies for TorchVision

mamba install libpng -y

Download the Compilation Batch File

curl -o compile_bundle.bat https://raw.githubusercontent.com/intel/intel-extension-for-pytorch/release/xpu/2.0.110/scripts/compile_bundle.bat

Run the Batch File

compile_bundle.bat "C:\Program Files (x86)\Intel\oneAPI\compiler\latest" "C:\Program Files (x86)\Intel\oneAPI\mkl\2023.2.0"

However, I did not notice any difference in performance between the standard version of TorchVision and the version I compiled from source. The only thing of note is that pip complained when I installed the patched version of PyTorch alongside the standard torchvision package.

Compiling PyTorch, TorchVision, and Intel’s extension can take a long time. Therefore, I would stick with the standard version of TorchVision and ignore the warning from pip.

Set oneDNN Memory Layout

We can improve training speed by setting an environment variable for the oneDNN memory layout. In the commands below, set applies the value to the current command-line session, and setx saves it for future sessions.

set IPEX_XPU_ONEDNN_LAYOUT=1
setx IPEX_XPU_ONEDNN_LAYOUT 1
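Because setx only affects command prompts opened afterwards while set only affects the current one, it is easy to end up in a session where the variable is missing. The sketch below shows one way to check from Python, for example in the first cell of the notebook. The helper name onednn_layout_enabled is hypothetical.

```python
import os


def onednn_layout_enabled(env=None):
    """Return True when IPEX_XPU_ONEDNN_LAYOUT is set to "1"."""
    env = os.environ if env is None else env
    return env.get("IPEX_XPU_ONEDNN_LAYOUT", "").strip() == "1"


print("oneDNN layout enabled:", onednn_layout_enabled())
```

The variable is read from the process environment, so the check reflects whatever session launched Python or Jupyter.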

Modify PyTorch Code

It’s finally time to train a model. The Jupyter Notebooks with the original and modified training code are available on GitHub at the links below.

You can also download the notebooks to the current directory by running the following commands:

curl -o pytorch-timm-image-classifier-training-windows-without-hf-datasets.ipynb https://raw.githubusercontent.com/cj-mills/pytorch-timm-gesture-recognition-tutorial-code/main/notebooks/pytorch-timm-image-classifier-training.ipynb
curl -o intel-arc-pytorch-timm-image-classifier-training-windows-without-hf-datasets.ipynb https://raw.githubusercontent.com/cj-mills/pytorch-timm-gesture-recognition-tutorial-code/main/notebooks/intel-arc-pytorch-timm-image-classifier-training-windows.ipynb
curl -o windows_utils_without_hf.py https://raw.githubusercontent.com/cj-mills/pytorch-timm-gesture-recognition-tutorial-code/main/notebooks/windows_utils_without_hf.py
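If you prefer to script the downloads in Python, keep in mind that a github.com URL containing /blob/ returns the HTML page for a file, while the raw.githubusercontent.com form returns the file itself. The sketch below converts one form to the other before downloading. The helper names github_raw_url and download are mine, and the URL handling assumes the usual owner/repo/blob/branch/path layout.

```python
import urllib.request
from urllib.parse import urlsplit


def github_raw_url(url):
    """Convert a github.com '/blob/' page URL to its raw.githubusercontent.com form.

    URLs that are not GitHub blob pages are returned unchanged.
    """
    parts = urlsplit(url)
    segments = parts.path.strip("/").split("/")
    # Expected shape: <owner>/<repo>/blob/<branch>/<path...>
    if parts.netloc == "github.com" and len(segments) >= 5 and segments[2] == "blob":
        owner, repo, _, *rest = segments
        return f"https://raw.githubusercontent.com/{owner}/{repo}/{'/'.join(rest)}"
    return url


def download(url, filename):
    """Save the file at `url` (converted to a raw URL if needed) as `filename`."""
    urllib.request.urlretrieve(github_raw_url(url), filename)
```

For example, download(some_blob_url, "windows_utils_without_hf.py") would save the Python source instead of an HTML page.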

Once downloaded, run the following command to launch the Jupyter Notebook Environment:

jupyter notebook

Import PyTorch Extension

We import Intel’s PyTorch extension with the following code:

import torch
import intel_extension_for_pytorch as ipex

print(f'PyTorch Version: {torch.__version__}')
print(f'Intel PyTorch Extension Version: {ipex.__version__}')

Update PyTorch Imports

We don’t want to re-import torch after the extension, so we’ll remove that line from the Import PyTorch dependencies section.

Modified:

# Import PyTorch dependencies
import torch.nn as nn
from torch.amp import autocast
from torch.cuda.amp import GradScaler
from torchvision import transforms
import torchvision.transforms.functional as TF
from torch.utils.data import Dataset, DataLoader
from torchtnt.utils import get_module_summary
from torcheval.metrics import MulticlassAccuracy

Original:

# Import PyTorch dependencies
import torch
import torch.nn as nn
from torch.amp import autocast
from torch.cuda.amp import GradScaler
from torchvision import transforms
import torchvision.transforms.functional as TF
from torch.utils.data import Dataset, DataLoader
from torcheval.tools import get_module_summary
from torcheval.metrics import MulticlassAccuracy

Verify Arc GPU Availability

We can double-check that the extension can use the Arc GPU by getting the properties of the available xpu devices.

def get_public_properties(obj):
    return {
        prop: getattr(obj, prop)
        for prop in dir(obj)
        if not prop.startswith("__") and not callable(getattr(obj, prop))
    }

xpu_device_count = torch.xpu.device_count()
dict_properties_list = [get_public_properties(torch.xpu.get_device_properties(i)) for i in range(xpu_device_count)]
pd.DataFrame(dict_properties_list)

  | dev_type | gpu_eu_count | max_compute_units | max_num_sub_groups | max_work_group_size | name | platform_name | sub_group_sizes | support_fp64 | total_memory
0 | gpu | 512 | 512 | 128 | 1024 | Intel(R) Arc(TM) A770 Graphics | Intel(R) Level-Zero | [8, 16, 32] | False | 16704737280

In this case, the A770 is the only device listed since we deactivated the integrated graphics on the CPU.
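To see what get_public_properties keeps and what it filters out without needing an Arc GPU, we can run it on a stand-in object. FakeDeviceProperties below is a hypothetical class for illustration only. It is not what torch.xpu.get_device_properties returns.

```python
class FakeDeviceProperties:
    """Hypothetical stand-in for a device-properties object."""
    name = "Example GPU"
    total_memory = 1024

    def describe(self):  # callable, so the helper filters it out
        return f"{self.name} ({self.total_memory} bytes)"


def get_public_properties(obj):
    # Keep attributes that are neither dunder names nor callables
    return {
        prop: getattr(obj, prop)
        for prop in dir(obj)
        if not prop.startswith("__") and not callable(getattr(obj, prop))
    }


props = get_public_properties(FakeDeviceProperties())
print(props)  # {'name': 'Example GPU', 'total_memory': 1024}
```

Only plain data attributes survive, which is why each device becomes a single clean row when the dictionaries are passed to pd.DataFrame.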

Update the Device Name

Next, we’ll manually set the device name to xpu.

Modified:

device = 'xpu'
dtype = torch.float32
device, dtype

Original:

device = get_torch_device()
dtype = torch.float32
device, dtype
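Hard-coding the device name works when the Arc GPU is known to be present. If the same notebook might also run on a machine without the extension, a fallback along these lines could be used instead. This is a sketch, not the tutorial's code: pick_device is a hypothetical helper, and it assumes torch.xpu.is_available() is exposed once the extension is imported.

```python
def pick_device(preferred="xpu"):
    """Return `preferred` if PyTorch reports an XPU device, otherwise "cpu"."""
    try:
        import torch
        # Importing the extension registers the xpu backend with PyTorch.
        import intel_extension_for_pytorch  # noqa: F401
    except (ImportError, OSError):
        # Missing package or a failed native library load: fall back to the CPU.
        return "cpu"
    xpu = getattr(torch, "xpu", None)
    if xpu is not None and xpu.is_available():
        return preferred
    return "cpu"


device = pick_device()
print(device)
```

The rest of the training code can then use device unchanged, at the cost of silently training on the CPU if the GPU is not detected.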

Optimize the model and optimizer Objects

Before we run the train_loop function, we’ll use Intel’s PyTorch extension to apply optimizations to the model and optimizer objects. We’ll also cast the model to the bfloat16 data type so we can train using mixed precision. Intel’s PyTorch extension currently supports only the bfloat16 data type for mixed-precision training.

Modified:

# Learning rate for the model
lr = 1e-3

# Number of training epochs
epochs = 3

# AdamW optimizer; includes weight decay for regularization
optimizer = torch.optim.AdamW(model.parameters(), lr=lr, eps=1e-5)

# Optimize the model and optimizer objects
model, optimizer = ipex.optimize(model, optimizer=optimizer, dtype=torch.bfloat16)

# Learning rate scheduler; adjusts the learning rate during training
lr_scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer, 
                                                   max_lr=lr, 
                                                   total_steps=epochs*len(train_dataloader))

# Performance metric: Multiclass Accuracy
metric = MulticlassAccuracy()

# Check for CUDA-capable GPU availability
use_grad_scaler = torch.cuda.is_available()

Original:

# Learning rate for the model
lr = 1e-3

# Number of training epochs
epochs = 3

# AdamW optimizer; includes weight decay for regularization
optimizer = torch.optim.AdamW(model.parameters(), lr=lr, eps=1e-5)

# Learning rate scheduler; adjusts the learning rate during training
lr_scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer, 
                                                   max_lr=lr, 
                                                   total_steps=epochs*len(train_dataloader))

# Performance metric: Multiclass Accuracy
metric = MulticlassAccuracy()

# Check for CUDA-capable GPU availability
use_grad_scaler = torch.cuda.is_available()
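The total_steps value passed to OneCycleLR is the number of batches per epoch multiplied by the number of epochs, which implies the scheduler is stepped once per batch. The sketch below works through that arithmetic with made-up numbers. The function name and the sample counts are illustrative and do not come from the tutorial's dataset.

```python
import math


def one_cycle_total_steps(num_samples, batch_size, epochs, drop_last=False):
    """Number of scheduler steps when stepping once per batch."""
    if drop_last:
        # The final partial batch is discarded
        batches_per_epoch = num_samples // batch_size
    else:
        # The final partial batch still counts as one step
        batches_per_epoch = math.ceil(num_samples / batch_size)
    return epochs * batches_per_epoch


# Hypothetical numbers for illustration only
print(one_cycle_total_steps(num_samples=1000, batch_size=32, epochs=3))  # 96
```

If the scheduler is stepped more times than total_steps, OneCycleLR raises an error, so this value needs to match the actual loop length.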

Train the Model

That’s it for the required changes to the training code. We can now run the train_loop function. With the A770 and the dataset on an SSD, training takes between twelve and twelve and a half minutes to complete.

Update the Inference Code

Since we cast the model to bfloat16, we must ensure the input data uses the same type. We can update the inference code to use the autocast context manager, as shown below:

Modified:

# Make a prediction with the model
with torch.no_grad():
    with torch.xpu.amp.autocast(enabled=True, dtype=torch.bfloat16, cache_enabled=False):
        pred = model(img_tensor)

Original:

# Make a prediction with the model
with torch.no_grad():
    pred = model(img_tensor)
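To get a feel for what an autocast context does to data types without an Arc GPU, the sketch below uses PyTorch's generic torch.autocast on the CPU as a stand-in for torch.xpu.amp.autocast. This is an assumption-laden illustration, not the extension's code path: the tiny linear layer and the helper name are hypothetical, and the function returns None if PyTorch is not installed.

```python
def autocast_output_dtype():
    """Return the output dtype of a small linear layer under bfloat16 autocast.

    Returns None when PyTorch is not installed.
    """
    try:
        import torch
    except ImportError:
        return None

    model = torch.nn.Linear(4, 2)  # float32 weights
    x = torch.randn(1, 4)          # float32 input
    with torch.no_grad():
        # CPU autocast is only a stand-in here for torch.xpu.amp.autocast
        with torch.autocast(device_type="cpu", dtype=torch.bfloat16):
            out = model(x)
    return out.dtype


print(autocast_output_dtype())
```

The float32 input is cast automatically inside the context, which is why the inference code does not need to convert img_tensor by hand.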

Conclusion

In this tutorial, we set up Intel’s PyTorch extension for the Windows OS and trained an image classification model using an Arc GPU. The exact setup steps may change with new versions, so check the documentation for the latest version to see if there are any changes. I’ll try to keep this tutorial updated with any significant changes to the process.

