From 7512429ccba78dd4f1958b16794d48be81ba46b9 Mon Sep 17 00:00:00 2001 From: Mengni Wang Date: Tue, 15 Jul 2025 09:42:20 +0300 Subject: [PATCH 1/4] Add dlrm_v2 CPU FP8 QDQ example Signed-off-by: Mengni Wang --- .../dlrm_v2/fp8_quant/cpu/README.md | 37 ++ .../fp8_quant/cpu/data_process/__init__.py | 0 .../cpu/data_process/dlrm_dataloader.py | 163 +++++++ .../cpu/data_process/multi_hot_criteo.py | 340 ++++++++++++++ .../dlrm_v2/fp8_quant/cpu/dlrm_model.py | 285 ++++++++++++ .../dlrm_v2/fp8_quant/cpu/main.py | 417 ++++++++++++++++++ .../dlrm_v2/fp8_quant/cpu/requirements.txt | 6 + .../dlrm_v2/fp8_quant/cpu/setup.sh | 6 + 8 files changed, 1254 insertions(+) create mode 100644 examples/3.x_api/pytorch/recommendation/dlrm_v2/fp8_quant/cpu/README.md create mode 100644 examples/3.x_api/pytorch/recommendation/dlrm_v2/fp8_quant/cpu/data_process/__init__.py create mode 100644 examples/3.x_api/pytorch/recommendation/dlrm_v2/fp8_quant/cpu/data_process/dlrm_dataloader.py create mode 100644 examples/3.x_api/pytorch/recommendation/dlrm_v2/fp8_quant/cpu/data_process/multi_hot_criteo.py create mode 100644 examples/3.x_api/pytorch/recommendation/dlrm_v2/fp8_quant/cpu/dlrm_model.py create mode 100644 examples/3.x_api/pytorch/recommendation/dlrm_v2/fp8_quant/cpu/main.py create mode 100644 examples/3.x_api/pytorch/recommendation/dlrm_v2/fp8_quant/cpu/requirements.txt create mode 100644 examples/3.x_api/pytorch/recommendation/dlrm_v2/fp8_quant/cpu/setup.sh diff --git a/examples/3.x_api/pytorch/recommendation/dlrm_v2/fp8_quant/cpu/README.md b/examples/3.x_api/pytorch/recommendation/dlrm_v2/fp8_quant/cpu/README.md new file mode 100644 index 00000000000..1a5043492c1 --- /dev/null +++ b/examples/3.x_api/pytorch/recommendation/dlrm_v2/fp8_quant/cpu/README.md @@ -0,0 +1,37 @@ +Step-by-Step +============ + +This document describes the step-by-step instructions for FP8 quantization for [DLRM v2](https://github.com/facebookresearch/dlrm/tree/main/torchrec_dlrm) with IntelĀ® Neural Compressor. + + +# Prerequisite + +### 1. Environment + +```shell +bash steup.sh +pip install -r requirements.txt +``` + +### 2. Prepare Dataset + +You can download preprocessed dataset by following +https://github.com/mlcommons/inference/tree/master/recommendation/dlrm_v2/pytorch#download-preprocessed-dataset + + +### 3. Prepare pretrained model + +You can download and unzip checkpoint by following +https://github.com/mlcommons/inference/tree/master/recommendation/dlrm_v2/pytorch#downloading-model-weights + + +# Run with CPU + +```shell +TORCHINDUCTOR_FREEZING=1 python main.py --model_path /path/to/model_weights --data_path /path/to/dataset --calib --quant --accuracy +``` +or only do quantization after calibration is done +```shell +TORCHINDUCTOR_FREEZING=1 python main.py --model_path /path/to/model_weights --data_path /path/to/dataset --quant --accuracy +``` + diff --git a/examples/3.x_api/pytorch/recommendation/dlrm_v2/fp8_quant/cpu/data_process/__init__.py b/examples/3.x_api/pytorch/recommendation/dlrm_v2/fp8_quant/cpu/data_process/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/examples/3.x_api/pytorch/recommendation/dlrm_v2/fp8_quant/cpu/data_process/dlrm_dataloader.py b/examples/3.x_api/pytorch/recommendation/dlrm_v2/fp8_quant/cpu/data_process/dlrm_dataloader.py new file mode 100644 index 00000000000..ef819f5a43d --- /dev/null +++ b/examples/3.x_api/pytorch/recommendation/dlrm_v2/fp8_quant/cpu/data_process/dlrm_dataloader.py @@ -0,0 +1,163 @@ +# +# -*- coding: utf-8 -*- +# +# Copyright (c) 2025 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import argparse +import os +from typing import List + +from torch import distributed as dist +from torch.utils.data import DataLoader +from torchrec.datasets.criteo import ( + CAT_FEATURE_COUNT, + DAYS, + DEFAULT_CAT_NAMES, + DEFAULT_INT_NAMES, + InMemoryBinaryCriteoIterDataPipe, +) +from torchrec.datasets.random import RandomRecDataset + +# OSS import +try: + # pyre-ignore[21] + # @manual=//ai_codesign/benchmarks/dlrm/torchrec_dlrm/data:multi_hot_criteo + from data.multi_hot_criteo import MultiHotCriteoIterDataPipe + +except ImportError: + pass + +# internal import +try: + from .multi_hot_criteo import MultiHotCriteoIterDataPipe # noqa F811 +except ImportError: + pass + +STAGES = ["train", "val", "test"] + + +def _get_random_dataloader( + args: argparse.Namespace, + stage: str, +) -> DataLoader: + attr = f"limit_{stage}_batches" + num_batches = getattr(args, attr) + if stage in ["val", "test"] and args.test_batch_size is not None: + batch_size = args.test_batch_size + else: + batch_size = args.batch_size + return DataLoader( + RandomRecDataset( + keys=DEFAULT_CAT_NAMES, + batch_size=batch_size, + hash_size=args.num_embeddings, + hash_sizes=( + args.num_embeddings_per_feature + if hasattr(args, "num_embeddings_per_feature") + else None + ), + manual_seed=args.seed if hasattr(args, "seed") else None, + ids_per_feature=1, + num_dense=len(DEFAULT_INT_NAMES), + num_batches=num_batches, + ), + batch_size=None, + batch_sampler=None, + pin_memory=args.pin_memory, + num_workers=0, + ) + + +def _get_in_memory_dataloader( + args: argparse.Namespace, + stage: str, +) -> DataLoader: + dir_path = args.data_path + sparse_part = "sparse_multi_hot.npz" + datapipe = MultiHotCriteoIterDataPipe + + if stage == "train": + stage_files: List[List[str]] = [ + [os.path.join(dir_path, f"day_{i}_dense.npy") for i in range(DAYS - 1)], + [os.path.join(dir_path, f"day_{i}_{sparse_part}") for i in range(DAYS - 1)], + [os.path.join(dir_path, f"day_{i}_labels.npy") for i in range(DAYS - 1)], + ] + elif stage in ["val", "test"]: + stage_files: List[List[str]] = [ + [os.path.join(dir_path, f"day_{DAYS-1}_dense.npy")], + [os.path.join(dir_path, f"day_{DAYS-1}_{sparse_part}")], + [os.path.join(dir_path, f"day_{DAYS-1}_labels.npy")], + ] + if stage in ["val", "test"] and args.test_batch_size is not None: + batch_size = args.test_batch_size + else: + batch_size = args.batch_size + dataloader = DataLoader( + datapipe( + stage, + *stage_files, # pyre-ignore[6] + batch_size=batch_size, + rank=0, # dist.get_rank(), + world_size=1, # dist.get_world_size(), + drop_last=args.drop_last_training_batch if stage == "train" else False, + shuffle_batches=args.shuffle_batches, + shuffle_training_set=args.shuffle_training_set, + shuffle_training_set_random_seed=args.seed, + mmap_mode=args.mmap_mode, + hashes=( + args.num_embeddings_per_feature + if args.num_embeddings is None + else ([args.num_embeddings] * CAT_FEATURE_COUNT) + ), + ), + batch_size=None, + pin_memory=args.pin_memory, + collate_fn=lambda x: x, + ) + return dataloader + + +def get_dataloader(args: argparse.Namespace, backend: str, stage: str) -> DataLoader: + """ + Gets desired dataloader from dlrm_main command line options. Currently, this + function is able to return either a DataLoader wrapped around a RandomRecDataset or + a Dataloader wrapped around an InMemoryBinaryCriteoIterDataPipe. + + Args: + args (argparse.Namespace): Command line options supplied to dlrm_main.py's main + function. + backend (str): "nccl" or "gloo". + stage (str): "train", "val", or "test". + + Returns: + dataloader (DataLoader): PyTorch dataloader for the specified options. + + """ + stage = stage.lower() + if stage not in STAGES: + raise ValueError(f"Supplied stage was {stage}. Must be one of {STAGES}.") + + args.pin_memory = ( + (backend == "nccl") if not hasattr(args, "pin_memory") else args.pin_memory + ) + + return _get_in_memory_dataloader(args, stage) diff --git a/examples/3.x_api/pytorch/recommendation/dlrm_v2/fp8_quant/cpu/data_process/multi_hot_criteo.py b/examples/3.x_api/pytorch/recommendation/dlrm_v2/fp8_quant/cpu/data_process/multi_hot_criteo.py new file mode 100644 index 00000000000..6395329b6da --- /dev/null +++ b/examples/3.x_api/pytorch/recommendation/dlrm_v2/fp8_quant/cpu/data_process/multi_hot_criteo.py @@ -0,0 +1,340 @@ +# +# -*- coding: utf-8 -*- +# +# Copyright (c) 2025 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import zipfile +from typing import Dict, Iterator, List, Optional + +import numpy as np +import torch +from iopath.common.file_io import PathManager, PathManagerFactory +from pyre_extensions import none_throws +from torch.utils.data import IterableDataset +from torchrec.datasets.criteo import ( + CAT_FEATURE_COUNT, + DEFAULT_CAT_NAMES, +) +from torchrec.datasets.utils import Batch, PATH_MANAGER_KEY +from torchrec.sparse.jagged_tensor import KeyedJaggedTensor + + +class MultiHotCriteoIterDataPipe(IterableDataset): + """ + Datapipe designed to operate over the MLPerf DLRM v2 synthetic multi-hot dataset. + This dataset can be created by following the steps in + torchrec_dlrm/scripts/materialize_synthetic_multihot_dataset.py. + Each rank reads only the data for the portion of the dataset it is responsible for. + + Args: + stage (str): "train", "val", or "test". + dense_paths (List[str]): List of path strings to dense npy files. + sparse_paths (List[str]): List of path strings to multi-hot sparse npz files. + labels_paths (List[str]): List of path strings to labels npy files. + batch_size (int): batch size. + rank (int): rank. + world_size (int): world size. + drop_last (Optional[bool]): Whether to drop the last batch if it is incomplete. + shuffle_batches (bool): Whether to shuffle batches + shuffle_training_set (bool): Whether to shuffle all samples in the dataset. + shuffle_training_set_random_seed (int): The random generator seed used when + shuffling the training set. + hashes (Optional[int]): List of max categorical feature value for each feature. + Length of this list should be CAT_FEATURE_COUNT. + path_manager_key (str): Path manager key used to load from different + filesystems. + + Example:: + + datapipe = MultiHotCriteoIterDataPipe( + dense_paths=["day_0_dense.npy"], + sparse_paths=["day_0_sparse_multi_hot.npz"], + labels_paths=["day_0_labels.npy"], + batch_size=1024, + rank=torch.distributed.get_rank(), + world_size=torch.distributed.get_world_size(), + ) + batch = next(iter(datapipe)) + """ + + def __init__( + self, + stage: str, + dense_paths: List[str], + sparse_paths: List[str], + labels_paths: List[str], + batch_size: int, + rank: int, + world_size: int, + drop_last: Optional[bool] = False, + shuffle_batches: bool = False, + shuffle_training_set: bool = False, + shuffle_training_set_random_seed: int = 0, + mmap_mode: bool = False, + hashes: Optional[List[int]] = None, + path_manager_key: str = PATH_MANAGER_KEY, + ) -> None: + self.stage = stage + self.dense_paths = dense_paths + self.sparse_paths = sparse_paths + self.labels_paths = labels_paths + self.batch_size = batch_size + self.rank = rank + self.world_size = world_size + self.drop_last = drop_last + self.shuffle_batches = shuffle_batches + self.shuffle_training_set = shuffle_training_set + np.random.seed(shuffle_training_set_random_seed) + self.mmap_mode = mmap_mode + # hashes are not used because they were already applied in the + # script that generates the multi-hot dataset. + self.hashes: np.ndarray = np.array(hashes).reshape((CAT_FEATURE_COUNT, 1)) + self.path_manager_key = path_manager_key + self.path_manager: PathManager = PathManagerFactory().get(path_manager_key) + + if shuffle_training_set and stage == "train": + # Currently not implemented for the materialized multi-hot dataset. + self._shuffle_and_load_data_for_rank() + else: + m = "r" if mmap_mode else None + self.dense_arrs: List[np.ndarray] = [ + np.load(f, mmap_mode=m) for f in self.dense_paths + ] + self.labels_arrs: List[np.ndarray] = [ + np.load(f, mmap_mode=m) for f in self.labels_paths + ] + self.sparse_arrs: List = [] + for sparse_path in self.sparse_paths: + multi_hot_ids_l = [] + for feat_id_num in range(CAT_FEATURE_COUNT): + multi_hot_ft_ids = self._load_from_npz( + sparse_path, f"{feat_id_num}.npy" + ) + multi_hot_ids_l.append(multi_hot_ft_ids) + self.sparse_arrs.append(multi_hot_ids_l) + len_d0 = len(self.dense_arrs[0]) + second_half_start_index = int(len_d0 // 2 + len_d0 % 2) + if stage == "val": + self.dense_arrs[0] = self.dense_arrs[0][:second_half_start_index, :] + self.labels_arrs[0] = self.labels_arrs[0][:second_half_start_index, :] + self.sparse_arrs[0] = [ + feats[:second_half_start_index, :] for feats in self.sparse_arrs[0] + ] + elif stage == "test": + self.dense_arrs[0] = self.dense_arrs[0][second_half_start_index:, :] + self.labels_arrs[0] = self.labels_arrs[0][second_half_start_index:, :] + self.sparse_arrs[0] = [ + feats[second_half_start_index:, :] for feats in self.sparse_arrs[0] + ] + # When mmap_mode is enabled, sparse features are hashed when + # samples are batched in def __iter__. Otherwise, the dataset has been + # preloaded with sparse features hashed in the preload stage, here: + # if not self.mmap_mode and self.hashes is not None: + # for k, _ in enumerate(self.sparse_arrs): + # self.sparse_arrs[k] = [ + # feat % hash + # for (feat, hash) in zip(self.sparse_arrs[k], self.hashes) + # ] + + self.num_rows_per_file: List[int] = list(map(len, self.dense_arrs)) + total_rows = sum(self.num_rows_per_file) + self.num_full_batches: int = ( + total_rows // batch_size // self.world_size * self.world_size + ) + self.last_batch_sizes: np.ndarray = np.array( + [0 for _ in range(self.world_size)] + ) + remainder = total_rows % (self.world_size * batch_size) + if not self.drop_last and 0 < remainder: + if remainder < self.world_size: + self.num_full_batches -= self.world_size + self.last_batch_sizes += batch_size + else: + self.last_batch_sizes += remainder // self.world_size + self.last_batch_sizes[: remainder % self.world_size] += 1 + + self.multi_hot_sizes: List[int] = [ + multi_hot_feat.shape[-1] for multi_hot_feat in self.sparse_arrs[0] + ] + + # These values are the same for the KeyedJaggedTensors in all batches, so they + # are computed once here. This avoids extra work from the KeyedJaggedTensor sync + # functions. + self.keys: List[str] = DEFAULT_CAT_NAMES + self.index_per_key: Dict[str, int] = { + key: i for (i, key) in enumerate(self.keys) + } + + def _load_from_npz(self, fname, npy_name): + # figure out offset of .npy in .npz + zf = zipfile.ZipFile(fname) + info = zf.NameToInfo[npy_name] + assert info.compress_type == 0 + zf.fp.seek(info.header_offset + len(info.FileHeader()) + 20) + # read .npy header + zf.open(npy_name, "r") + version = np.lib.format.read_magic(zf.fp) + shape, fortran_order, dtype = np.lib.format._read_array_header(zf.fp, version) + assert ( + dtype == "int32" + ), f"sparse multi-hot dtype is {dtype} but should be int32" + offset = zf.fp.tell() + # create memmap + return np.memmap( + zf.filename, + dtype=dtype, + shape=shape, + order="F" if fortran_order else "C", + mode="r", + offset=offset, + ) + + def _np_arrays_to_batch( + self, + dense: np.ndarray, + sparse: List[np.ndarray], + labels: np.ndarray, + ) -> Batch: + if self.shuffle_batches: + # Shuffle all 3 in unison + shuffler = np.random.permutation(len(dense)) + sparse = [multi_hot_ft[shuffler, :] for multi_hot_ft in sparse] + dense = dense[shuffler] + labels = labels[shuffler] + + batch_size = len(dense) + lengths = torch.ones((CAT_FEATURE_COUNT * batch_size), dtype=torch.int32) + for k, multi_hot_size in enumerate(self.multi_hot_sizes): + lengths[k * batch_size : (k + 1) * batch_size] = multi_hot_size + offsets = torch.cumsum(torch.concat((torch.tensor([0]), lengths)), dim=0) + length_per_key = [ + batch_size * multi_hot_size for multi_hot_size in self.multi_hot_sizes + ] + offset_per_key = torch.cumsum( + torch.concat((torch.tensor([0]), torch.tensor(length_per_key))), dim=0 + ) + values = torch.concat( + [torch.from_numpy(feat.copy()).flatten() for feat in sparse] + ) + return Batch( + dense_features=torch.from_numpy(dense.copy()), + sparse_features=KeyedJaggedTensor( + keys=self.keys, + values=values, + lengths=lengths, + offsets=offsets, + stride=batch_size, + length_per_key=length_per_key, + offset_per_key=offset_per_key.tolist(), + index_per_key=self.index_per_key, + ), + labels=torch.from_numpy(labels.reshape(-1).copy()), + ) + + def __iter__(self) -> Iterator[Batch]: + # Invariant: buffer never contains more than batch_size rows. + buffer: Optional[List[np.ndarray]] = None + + def append_to_buffer( + dense: np.ndarray, + sparse: List[np.ndarray], + labels: np.ndarray, + ) -> None: + nonlocal buffer + if buffer is None: + buffer = [dense, sparse, labels] + else: + buffer[0] = np.concatenate((buffer[0], dense)) + buffer[1] = [np.concatenate((b, s)) for b, s in zip(buffer[1], sparse)] + buffer[2] = np.concatenate((buffer[2], labels)) + + # Maintain a buffer that can contain up to batch_size rows. Fill buffer as + # much as possible on each iteration. Only return a new batch when batch_size + # rows are filled. + file_idx = 0 + row_idx = 0 + batch_idx = 0 + buffer_row_count = 0 + cur_batch_size = ( + self.batch_size if self.num_full_batches > 0 else self.last_batch_sizes[0] + ) + while ( + batch_idx + < self.num_full_batches + (self.last_batch_sizes[0] > 0) * self.world_size + ): + if buffer_row_count == cur_batch_size or file_idx == len(self.dense_arrs): + if batch_idx % self.world_size == self.rank: + yield self._np_arrays_to_batch(*none_throws(buffer)) + buffer = None + buffer_row_count = 0 + batch_idx += 1 + if 0 <= batch_idx - self.num_full_batches < self.world_size and ( + self.last_batch_sizes[0] > 0 + ): + cur_batch_size = self.last_batch_sizes[ + batch_idx - self.num_full_batches + ] + else: + rows_to_get = min( + cur_batch_size - buffer_row_count, + self.num_rows_per_file[file_idx] - row_idx, + ) + buffer_row_count += rows_to_get + slice_ = slice(row_idx, row_idx + rows_to_get) + + if batch_idx % self.world_size == self.rank: + dense_inputs = self.dense_arrs[file_idx][slice_, :] + sparse_inputs = [ + feats[slice_, :] for feats in self.sparse_arrs[file_idx] + ] + target_labels = self.labels_arrs[file_idx][slice_, :] + + # if self.mmap_mode and self.hashes is not None: + # sparse_inputs = [ + # feats % hash + # for (feats, hash) in zip(sparse_inputs, self.hashes) + # ] + + append_to_buffer( + dense_inputs, + sparse_inputs, + target_labels, + ) + row_idx += rows_to_get + + if row_idx >= self.num_rows_per_file[file_idx]: + file_idx += 1 + row_idx = 0 + + def __len__(self) -> int: + return self.num_full_batches // self.world_size + (self.last_batch_sizes[0] > 0) + + def load_batch(self, sample_list=None) -> Batch: + if sample_list is None: + sample_list = list(range(self.batch_size)) + dense = self.dense_arrs[0][sample_list, :] + sparse = [ + arr[sample_list, :] % self.hashes[i] + for i, arr in enumerate(self.sparse_arrs[0]) + ] + labels = self.labels_arrs[0][sample_list, :] + return self._np_arrays_to_batch(dense, sparse, labels) diff --git a/examples/3.x_api/pytorch/recommendation/dlrm_v2/fp8_quant/cpu/dlrm_model.py b/examples/3.x_api/pytorch/recommendation/dlrm_v2/fp8_quant/cpu/dlrm_model.py new file mode 100644 index 00000000000..eed24e47ec5 --- /dev/null +++ b/examples/3.x_api/pytorch/recommendation/dlrm_v2/fp8_quant/cpu/dlrm_model.py @@ -0,0 +1,285 @@ +# +# -*- coding: utf-8 -*- +# +# Copyright (c) 2025 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import torch + +from torchrec.models.dlrm import SparseArch, InteractionDCNArch, DLRM_DCN +from torchrec.modules.embedding_modules import EmbeddingBagCollection + +from typing import List, Optional +import numpy as np + + +def _calculate_fan_in_and_fan_out(shape): + # numpy array version + dimensions = len(shape) + assert ( + dimensions >= 2 + ), "Fan in and fan out can not be computed for tensor with fewer than 2 dimensions" + + num_input_fmaps = shape[1] + num_output_fmaps = shape[0] + receptive_field_size = 1 + if len(shape) > 2: + # math.prod is not always available, accumulate the product manually + # we could use functools.reduce but that is not supported by TorchScript + for s in shape[2:]: + receptive_field_size *= s + fan_in = num_input_fmaps * receptive_field_size + fan_out = num_output_fmaps * receptive_field_size + + return fan_in, fan_out + + +def _calculate_correct_fan(shape, mode): + mode = mode.lower() + valid_modes = ["fan_in", "fan_out"] + if mode not in valid_modes: + raise ValueError( + "Mode {} not supported, please use one of {}".format(mode, valid_modes) + ) + + fan_in, fan_out = _calculate_fan_in_and_fan_out(shape) + return fan_in if mode == "fan_in" else fan_out + + +def calculate_gain(nonlinearity, param=None): + linear_fns = [ + "linear", + "conv1d", + "conv2d", + "conv3d", + "conv_transpose1d", + "conv_transpose2d", + "conv_transpose3d", + ] + if nonlinearity in linear_fns or nonlinearity == "sigmoid": + return 1 + elif nonlinearity == "tanh": + return 5.0 / 3 + elif nonlinearity == "relu": + return np.sqrt(2.0) + elif nonlinearity == "leaky_relu": + if param is None: + negative_slope = 0.01 + elif ( + not isinstance(param, bool) + and isinstance(param, int) + or isinstance(param, float) + ): + # True/False are instances of int, hence check above + negative_slope = param + else: + raise ValueError("negative_slope {} not a valid number".format(param)) + return np.sqrt(2.0 / (1 + negative_slope**2)) + elif nonlinearity == "selu": + return ( + 3.0 / 4 + ) # Value found empirically (https://github.com/pytorch/pytorch/pull/50664) + else: + raise ValueError("Unsupported nonlinearity {}".format(nonlinearity)) + + +def xavier_norm_(shape: tuple, gain: float = 1.0): + fan_in, fan_out = _calculate_fan_in_and_fan_out(shape) + std = gain * np.sqrt(2.0 / float(fan_in + fan_out)) + mean = 0.0 + d = np.random.normal(mean, std, size=shape).astype(np.float32) + return d + + +def kaiming_uniform_( + shape: tuple, a: float = 0, mode: str = "fan_in", nonlinearity: str = "leaky_relu" +): + assert 0 not in shape, "Initializing zero-element tensors is a no-op" + fan = _calculate_correct_fan(shape, mode) + gain = calculate_gain(nonlinearity, a) + std = gain / np.sqrt(fan) + bound = np.sqrt(3.0) * std # Calculate uniform bounds from standard deviation + return np.random.uniform(-bound, bound, size=shape) + +class _LowRankCrossNet(torch.nn.Module): + def __init__( + self, + lr_crossnet, + ) -> None: + super().__init__() + self._num_layers = lr_crossnet._num_layers + self._in_features = lr_crossnet.bias[0].shape[0] + self._low_rank = lr_crossnet._low_rank + self.V_linears = torch.nn.ModuleList() + self.W_linears = torch.nn.ModuleList() + for i in range(self._num_layers): + self.V_linears.append( + torch.nn.Linear(self._in_features, self._low_rank, bias=False) + ) + self.W_linears.append( + torch.nn.Linear(self._low_rank, self._in_features, bias=True) + ) + self.V_linears[i].weight.data = lr_crossnet.V_kernels[i] + self.W_linears[i].weight.data = lr_crossnet.W_kernels[i] + self.W_linears[i].bias.data = lr_crossnet.bias[i] + + def forward(self, input: torch.Tensor) -> torch.Tensor: + x_0 = input + x_l = x_0 + for layer in range(self._num_layers): + x_l_v = self.V_linears[layer](x_l) + x_l_w = self.W_linears[layer](x_l_v) + x_l = x_0 * x_l_w + x_l # (B, N) + return x_l + + +def replace_crossnet(dlrm): + crossnet = dlrm.inter_arch.crossnet + new_crossnet = _LowRankCrossNet(crossnet) + dlrm.inter_arch.crossnet = new_crossnet + del crossnet + +class SparseArchCatDense(SparseArch): + def forward( + self, + embedded_dense_features, + sparse_features, + ) -> torch.Tensor: + """ + Args: + embedded_dense_features: the output of DenseArch. + sparse_features: the indices/offsets for F embeddingbags in embedding_bag_collection + + Returns: + torch.Tensor: tensor of shape B X (F + 1) X D. + """ + (B, _) = embedded_dense_features.shape + embedding_bag_collection = self.embedding_bag_collection + indices = tuple([sf["values"] for _, sf in sparse_features.items()]) + offsets = tuple([sf["offsets"] for _, sf in sparse_features.items()]) + embedded_sparse_features: List[torch.Tensor] = [] + for i, embedding_bag in enumerate( + embedding_bag_collection.embedding_bags.values() + ): + for feature_name in embedding_bag_collection._feature_names[i]: + f = sparse_features[feature_name] + res = embedding_bag( + f["values"], + f["offsets"], + per_sample_weights=None, + ) + embedded_sparse_features.append(res) + to_cat = [embedded_dense_features] + list(embedded_sparse_features) + out = torch.cat(to_cat, dim=1) + return out + + +class InteractionDCNArchWithoutCat(InteractionDCNArch): + def forward(self, concat_dense_sparse: torch.Tensor) -> torch.Tensor: + """ + Args: + concat_dense_sparse (torch.Tensor): an input tensor of size B X (F*D + D). + + Returns: + torch.Tensor: an output tensor of size B X (F*D + D). + """ + + # size B X (F * D + D) + return self.crossnet(concat_dense_sparse) + + +class IPEX_DLRM_DCN(DLRM_DCN): + """ + Recsys model with DCN modified from the original model from "Deep Learning Recommendation + Model for Personalization and Recommendation Systems" + (https://arxiv.org/abs/1906.00091). Similar to DLRM module but has + DeepCrossNet https://arxiv.org/pdf/2008.13535.pdf as the interaction layer. + + The module assumes all sparse features have the same embedding dimension + (i.e. each EmbeddingBagConfig uses the same embedding_dim). + + The following notation is used throughout the documentation for the models: + + * F: number of sparse features + * D: embedding_dimension of sparse features + * B: batch size + * num_features: number of dense features + + Args: + embedding_bag_collection (EmbeddingBagCollection): collection of embedding bags + used to define `SparseArch`. + dense_in_features (int): the dimensionality of the dense input features. + dense_arch_layer_sizes (List[int]): the layer sizes for the `DenseArch`. + over_arch_layer_sizes (List[int]): the layer sizes for the `OverArch`. + The output dimension of the `InteractionArch` should not be manually + specified here. + dcn_num_layers (int): the number of DCN layers in the interaction. + dcn_low_rank_dim (int): the dimensionality of low rank approximation + used in the dcn layers. + dense_device (Optional[torch.device]): default compute device. + """ + + def __init__( + self, + embedding_bag_collection: EmbeddingBagCollection, + dense_in_features: int, + dense_arch_layer_sizes: List[int], + over_arch_layer_sizes: List[int], + dcn_num_layers: int, + dcn_low_rank_dim: int, + dense_device: Optional[torch.device] = None, + ) -> None: + # initialize DLRM + # sparse arch and dense arch are initialized via DLRM + super().__init__( + embedding_bag_collection, + dense_in_features, + dense_arch_layer_sizes, + over_arch_layer_sizes, + dcn_num_layers, + dcn_low_rank_dim, + dense_device, + ) + + num_sparse_features: int = len(self.sparse_arch.sparse_feature_names) + + embedding_bag_collection = self.sparse_arch.embedding_bag_collection + + self.sparse_arch = SparseArchCatDense(embedding_bag_collection) + + crossnet = self.inter_arch.crossnet + self.inter_arch = InteractionDCNArchWithoutCat( + num_sparse_features=num_sparse_features, + crossnet=crossnet, + ) + + def forward( + self, + dense_features: torch.Tensor, + sparse_features, + ) -> torch.Tensor: + """ + Args: + dense_features (torch.Tensor): the dense features. + sparse_features (KeyedJaggedTensor): the sparse features. + + Returns: + torch.Tensor: logits. + """ + embedded_dense = self.dense_arch(dense_features) + concat_sparse_dense = self.sparse_arch(embedded_dense, sparse_features) + concatenated_dense = self.inter_arch(concat_sparse_dense) + logits = self.over_arch(concatenated_dense) + return logits diff --git a/examples/3.x_api/pytorch/recommendation/dlrm_v2/fp8_quant/cpu/main.py b/examples/3.x_api/pytorch/recommendation/dlrm_v2/fp8_quant/cpu/main.py new file mode 100644 index 00000000000..fe9ec21d4d7 --- /dev/null +++ b/examples/3.x_api/pytorch/recommendation/dlrm_v2/fp8_quant/cpu/main.py @@ -0,0 +1,417 @@ +import argparse +import itertools +import numpy as np +import sys +from torch.profiler import record_function +from pprint import pprint +from typing import List +import time + +import torch + +import torchmetrics as metrics +from pyre_extensions import none_throws +from torch.utils.data import DataLoader +from torchrec import EmbeddingBagCollection +from torchrec.datasets.criteo import DEFAULT_CAT_NAMES, DEFAULT_INT_NAMES +from torchrec.models.dlrm import DLRMTrain +from torchrec.modules.embedding_configs import EmbeddingBagConfig +from torchrec.sparse.jagged_tensor import KeyedJaggedTensor +from tqdm import tqdm +from neural_compressor.torch.quantization import ( + FP8Config, + convert, + finalize_calibration, + prepare, +) +from dlrm_model import IPEX_DLRM_DCN, replace_crossnet +from data_process.dlrm_dataloader import get_dataloader + + +TRAIN_PIPELINE_STAGES = 3 # Number of stages in TrainPipelineSparseDist. + + +def unpack(input: KeyedJaggedTensor) -> dict: + output = {} + for k, v in input.to_dict().items(): + output[k] = {} + output[k]["values"] = v._values.int() + output[k]["offsets"] = v._offsets.int() + return output + + +def load_snapshot(model, model_path): + from torchsnapshot import Snapshot + + snapshot = Snapshot(path=model_path) + snapshot.restore(app_state={"model": model}) + + +def fetch_batch(dataloader): + try: + batch = dataloader.dataset.load_batch() + except: + import torchrec + + dataset = dataloader.source.dataset + if isinstance( + dataset, torchrec.datasets.criteo.InMemoryBinaryCriteoIterDataPipe + ): + sample_list = list(range(dataset.batch_size)) + dense = dataset.dense_arrs[0][sample_list, :] + sparse = [arr[sample_list, :] for arr in dataset.sparse_arrs][ + 0 + ] % dataset.hashes + labels = dataset.labels_arrs[0][sample_list, :] + return dataloader.func(dataset._np_arrays_to_batch(dense, sparse, labels)) + batch = dataloader.func( + dataloader.source.dataset.batch_generator._generate_batch() + ) + return batch + + +class DLRM_DataLoader(object): + def __init__(self, loader=None): + self.loader = loader + self.batch_size = 1 + def __iter__(self): + for dense_feature, sparse_dfeature in [(self.loader.dense_features, self.loader.sparse_features)]: + yield {"dense_features": dense_feature, "sparse_features": sparse_dfeature} + + +def parse_args(argv: List[str]) -> argparse.Namespace: + parser = argparse.ArgumentParser(description="torchrec dlrm example trainer") + parser.add_argument( + "--epochs", + type=int, + default=1, + help="number of epochs to train", + ) + parser.add_argument( + "--batch_size", + type=int, + default=32, + help="batch size to use for training", + ) + parser.add_argument( + "--drop_last_training_batch", + dest="drop_last_training_batch", + action="store_true", + help="Drop the last non-full training batch", + ) + parser.add_argument( + "--limit_val_batches", + type=int, + default=100, + help="number of validation batches", + ) + parser.add_argument( + "--warmup_batches", + type=int, + default=100, + help="number of test batches", + ) + parser.add_argument( + "--num_embeddings", + type=int, + default=100_000, + help="max_ind_size. The number of embeddings in each embedding table. Defaults" + " to 100_000 if num_embeddings_per_feature is not supplied.", + ) + parser.add_argument( + "--num_embeddings_per_feature", + type=str, + default="40000000,39060,17295,7424,20265,3,7122,1543,63,40000000,3067956,405282,10,2209,11938,155,4,976,14,40000000,40000000,40000000,590152,12973,108,36", + help="Comma separated max_ind_size per sparse feature. The number of embeddings" + " in each embedding table. 26 values are expected for the Criteo dataset.", + ) + parser.add_argument( + "--dense_arch_layer_sizes", + type=str, + default="512,256,128", + help="Comma separated layer sizes for dense arch.", + ) + parser.add_argument( + "--over_arch_layer_sizes", + type=str, + default="1024,1024,512,256,1", + help="Comma separated layer sizes for over arch.", + ) + parser.add_argument( + "--embedding_dim", + type=int, + default=128, + help="Size of each embedding.", + ) + parser.add_argument( + "--dcn_num_layers", + type=int, + default=3, + help="Number of DCN layers in interaction layer (only on dlrm with DCN).", + ) + parser.add_argument( + "--dcn_low_rank_dim", + type=int, + default=512, + help="Low rank dimension for DCN in interaction layer (only on dlrm with DCN).", + ) + parser.add_argument( + "--seed", + type=int, + default=1696543516, + help="Random seed for reproducibility.", + ) + parser.add_argument( + "--data_path", + type=str, + default=None, + help="Directory path containing the MLPerf v2 synthetic multi-hot dataset npz files.", + ) + parser.add_argument( + "--model_path", + type=str, + ) + parser.add_argument( + "--inductor", + action="store_true", + help="whether use torch.compile()", + ) + parser.add_argument( + "--pin_memory", + dest="pin_memory", + action="store_true", + help="Use pinned memory when loading data.", + ) + parser.add_argument( + "--mmap_mode", + dest="mmap_mode", + action="store_true", + help="--mmap_mode mmaps the dataset." + " That is, the dataset is kept on disk but is accessed as if it were in memory." + " --mmap_mode is intended mostly for faster debugging. Use --mmap_mode to bypass" + " preloading the dataset when preloading takes too long or when there is " + " insufficient memory available to load the full dataset.", + ) + parser.add_argument( + "--shuffle_batches", + dest="shuffle_batches", + action="store_true", + help="Shuffle each batch during training.", + ) + parser.add_argument( + "--shuffle_training_set", + dest="shuffle_training_set", + action="store_true", + help="Shuffle the training set in memory. This will override mmap_mode", + ) + parser.add_argument( + "--test_batch_size", + type=int, + default=None, + help="batch size to use for validation and testing", + ) + + parser.add_argument("--calib", action="store_true") + parser.add_argument("--accuracy", action="store_true") + parser.add_argument("--quant", action="store_true") + parser.add_argument("--out_dir", type=str, default="inc_fp8/measure", help="A folder to save calibration result") + return parser.parse_args(argv) + + +def _evaluate( + eval_model, + eval_dataloader: DataLoader, + stage: str, + args, +) -> float: + """ + Evaluates model. Computes and prints AUROC + + Args: + model (torch.nn.Module): model for evaluation. + eval_dataloader (DataLoader): Dataloader for either the validation set or test set. + stage (str): "val" or "test". + args (argparse.Namespace): parsed command line args. + + Returns: + float: auroc result + """ + limit_batches = args.limit_val_batches + + benckmark_batch = fetch_batch(eval_dataloader) + benckmark_batch.sparse_features = unpack(benckmark_batch.sparse_features) + def fetch_next(iterator, current_it): + with record_function("generate batch"): + next_batch = next(iterator) + with record_function("unpack KeyJaggedTensor"): + next_batch.sparse_features = unpack(next_batch.sparse_features) + return next_batch + + def eval_step(model, iterator, current_it): + batch = fetch_next(iterator, current_it) + with record_function("model forward"): + t1 = time.time() + logits = model(batch.dense_features, batch.sparse_features) + t2 = time.time() + return logits, batch.labels, t2 - t1 + + pbar = tqdm( + iter(int, 1), + desc=f"Evaluating {stage} set", + total=len(eval_dataloader), + disable=True, + ) + + eval_model.eval() + device = torch.device("cpu") + + iterator = itertools.islice(iter(eval_dataloader), limit_batches) + # Two filler batches are appended to the end of the iterator to keep the pipeline active while the + # last two remaining batches are still in progress awaiting results. + two_filler_batches = itertools.islice( + iter(eval_dataloader), TRAIN_PIPELINE_STAGES - 1 + ) + iterator = itertools.chain(iterator, two_filler_batches) + + preds = [] + labels = [] + + auroc_computer = metrics.AUROC(task="binary").to(device) + + total_t = 0 + it = 0 + ctx1 = torch.no_grad() + ctx2 = torch.autocast("cpu", enabled=True, dtype=torch.bfloat16) + with ctx1, ctx2: + while True: + try: + logits, label, fw_t = eval_step(eval_model, iterator, it) + if it > args.warmup_batches: + total_t += fw_t + pred = torch.sigmoid(logits) + preds.append(pred) + labels.append(label) + pbar.update(1) + it += 1 + except StopIteration: + # Dataset traversal complete + break + + preds = torch.cat(preds) + labels = torch.cat(labels) + + num_samples = labels.shape[0] - args.warmup_batches * args.batch_size + auroc = auroc = auroc_computer(preds.squeeze().float(), labels.float()) + print(f"AUROC over {stage} set: {auroc}.") + print(f"Number of {stage} samples: {num_samples}") + print(f"Throughput: {num_samples/total_t} fps") + print(f"Final AUROC: {auroc} ") + return auroc + + +def construct_model(args): + device: torch.device = torch.device("cpu") + eb_configs = [ + EmbeddingBagConfig( + name=f"t_{feature_name}", + embedding_dim=args.embedding_dim, + num_embeddings=( + none_throws(args.num_embeddings_per_feature)[feature_idx] + if args.num_embeddings is None + else args.num_embeddings + ), + feature_names=[feature_name], + ) + for feature_idx, feature_name in enumerate(DEFAULT_CAT_NAMES) + ] + + dcn_init_fn = IPEX_DLRM_DCN + dlrm_model = dcn_init_fn( + embedding_bag_collection=EmbeddingBagCollection( + tables=eb_configs, device=torch.device("cpu") + ), + dense_in_features=len(DEFAULT_INT_NAMES), + dense_arch_layer_sizes=args.dense_arch_layer_sizes, + over_arch_layer_sizes=args.over_arch_layer_sizes, + dcn_num_layers=args.dcn_num_layers, + dcn_low_rank_dim=args.dcn_low_rank_dim, + dense_device=device, + ) + + train_model = DLRMTrain(dlrm_model) + assert args.model_path + load_snapshot(train_model, args.model_path) + + replace_crossnet(train_model.model) + return train_model + + +def main(argv: List[str]) -> None: + """ + Args: + argv (List[str]): command line args. + + Returns: + None. + """ + + args = parse_args(argv) + for name, val in vars(args).items(): + try: + vars(args)[name] = list(map(int, val.split(","))) + except (ValueError, AttributeError): + pass + + backend = "gloo" + pprint(vars(args)) + + np.random.seed(args.seed) + torch.manual_seed(args.seed) + + if args.num_embeddings_per_feature is not None: + args.num_embeddings = None + + sharded_module_kwargs = {} + if args.over_arch_layer_sizes is not None: + sharded_module_kwargs["over_arch_layer_sizes"] = args.over_arch_layer_sizes + + model = construct_model(args) + model.model.sparse_arch = model.model.sparse_arch.bfloat16() + + qconfig = FP8Config( + fp8_config="E4M3", + use_qdq=True, + scale_method="MAXABS_ARBITRARY", + dump_stats_path=args.out_dir, + ) + + if args.calib: + test_dataloader = get_dataloader(args, backend, "test") + model.model = prepare(model.model, qconfig) + + batch = fetch_batch(test_dataloader) + batch.sparse_features = unpack(batch.sparse_features) + batch_idx = list(range(128000)) + batch = test_dataloader.dataset.load_batch(batch_idx) + batch.sparse_features = unpack(batch.sparse_features) + model.model(batch.dense_features, batch.sparse_features) + + finalize_calibration(model.model) + + if args.quant: + model.model = convert(model.model, qconfig) + + if args.accuracy: + val_dataloader = get_dataloader(args, backend, "val") + model = torch.compile(model) + + _evaluate( + model.model, + val_dataloader, + "test", + args, + ) + + +if __name__ == "__main__": + main(sys.argv[1:]) diff --git a/examples/3.x_api/pytorch/recommendation/dlrm_v2/fp8_quant/cpu/requirements.txt b/examples/3.x_api/pytorch/recommendation/dlrm_v2/fp8_quant/cpu/requirements.txt new file mode 100644 index 00000000000..5ba1ae9c06f --- /dev/null +++ b/examples/3.x_api/pytorch/recommendation/dlrm_v2/fp8_quant/cpu/requirements.txt @@ -0,0 +1,6 @@ +numpy +neural-compressor-pt +torchmetrics +pyre_extensions +torchsnapshot +tqdm diff --git a/examples/3.x_api/pytorch/recommendation/dlrm_v2/fp8_quant/cpu/setup.sh b/examples/3.x_api/pytorch/recommendation/dlrm_v2/fp8_quant/cpu/setup.sh new file mode 100644 index 00000000000..6067d2a1bdf --- /dev/null +++ b/examples/3.x_api/pytorch/recommendation/dlrm_v2/fp8_quant/cpu/setup.sh @@ -0,0 +1,6 @@ +pip install torchrec --index-url https://download.pytorch.org/whl/cpu --no-deps +pip install torchmetrics==1.0.3 +pip install fbgemm-gpu --index-url https://download.pytorch.org/whl/cpu +pip install --pre torchao==0.12.0.dev20250702+cpu --index-url https://download.pytorch.org/whl/nightly/cpu +pip install torch --index-url https://download.pytorch.org/whl/cpu +pip install torchvision --index-url https://download.pytorch.org/whl/cpu \ No newline at end of file From ffe6850e178bc0b746086889bd5449a6c1699061 Mon Sep 17 00:00:00 2001 From: "Wang, Mengni" Date: Fri, 18 Jul 2025 13:33:09 +0800 Subject: [PATCH 2/4] Update README.md --- .../pytorch/recommendation/dlrm_v2/fp8_quant/cpu/README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/3.x_api/pytorch/recommendation/dlrm_v2/fp8_quant/cpu/README.md b/examples/3.x_api/pytorch/recommendation/dlrm_v2/fp8_quant/cpu/README.md index 1a5043492c1..f9b91dfb6d6 100644 --- a/examples/3.x_api/pytorch/recommendation/dlrm_v2/fp8_quant/cpu/README.md +++ b/examples/3.x_api/pytorch/recommendation/dlrm_v2/fp8_quant/cpu/README.md @@ -9,7 +9,7 @@ This document describes the step-by-step instructions for FP8 quantization for [ ### 1. Environment ```shell -bash steup.sh +bash setup.sh pip install -r requirements.txt ``` From 8819171ab5d19dc98580df173e617c25e88a25f7 Mon Sep 17 00:00:00 2001 From: "Wang, Mengni" Date: Thu, 24 Jul 2025 14:57:59 +0800 Subject: [PATCH 3/4] Update dlrm_model.py --- .../pytorch/recommendation/dlrm_v2/fp8_quant/cpu/dlrm_model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/3.x_api/pytorch/recommendation/dlrm_v2/fp8_quant/cpu/dlrm_model.py b/examples/3.x_api/pytorch/recommendation/dlrm_v2/fp8_quant/cpu/dlrm_model.py index eed24e47ec5..c9c46769635 100644 --- a/examples/3.x_api/pytorch/recommendation/dlrm_v2/fp8_quant/cpu/dlrm_model.py +++ b/examples/3.x_api/pytorch/recommendation/dlrm_v2/fp8_quant/cpu/dlrm_model.py @@ -200,7 +200,7 @@ def forward(self, concat_dense_sparse: torch.Tensor) -> torch.Tensor: return self.crossnet(concat_dense_sparse) -class IPEX_DLRM_DCN(DLRM_DCN): +class OPTIMIZED_DLRM_DCN(DLRM_DCN): """ Recsys model with DCN modified from the original model from "Deep Learning Recommendation Model for Personalization and Recommendation Systems" From d3d7dec5572f834ca17685f2e587ab9c6214dcb8 Mon Sep 17 00:00:00 2001 From: "Wang, Mengni" Date: Thu, 24 Jul 2025 14:58:46 +0800 Subject: [PATCH 4/4] Update main.py --- .../pytorch/recommendation/dlrm_v2/fp8_quant/cpu/main.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/3.x_api/pytorch/recommendation/dlrm_v2/fp8_quant/cpu/main.py b/examples/3.x_api/pytorch/recommendation/dlrm_v2/fp8_quant/cpu/main.py index fe9ec21d4d7..5d31eb3d133 100644 --- a/examples/3.x_api/pytorch/recommendation/dlrm_v2/fp8_quant/cpu/main.py +++ b/examples/3.x_api/pytorch/recommendation/dlrm_v2/fp8_quant/cpu/main.py @@ -24,7 +24,7 @@ finalize_calibration, prepare, ) -from dlrm_model import IPEX_DLRM_DCN, replace_crossnet +from dlrm_model import OPTIMIZED_DLRM_DCN, replace_crossnet from data_process.dlrm_dataloader import get_dataloader @@ -325,7 +325,7 @@ def construct_model(args): for feature_idx, feature_name in enumerate(DEFAULT_CAT_NAMES) ] - dcn_init_fn = IPEX_DLRM_DCN + dcn_init_fn = OPTIMIZED_DLRM_DCN dlrm_model = dcn_init_fn( embedding_bag_collection=EmbeddingBagCollection( tables=eb_configs, device=torch.device("cpu")