From f162e104d49ac5c844075c248f7df4927376aef3 Mon Sep 17 00:00:00 2001
From: Shiwei Tong
Date: Tue, 25 May 2021 16:21:42 +0800
Subject: [PATCH 01/20] [feat] add shared util functions

---
 EduKTM/utils/__init__.py    |  6 ++++
 EduKTM/utils/loss.py        | 71 +++++++++++++++++++++++++++++++++++++
 EduKTM/utils/torch_utils.py | 47 ++++++++++++++++++++++++
 EduKTM/utils/utils.py       | 37 +++++++++++++++++++
 4 files changed, 161 insertions(+)
 create mode 100644 EduKTM/utils/__init__.py
 create mode 100644 EduKTM/utils/loss.py
 create mode 100644 EduKTM/utils/torch_utils.py
 create mode 100644 EduKTM/utils/utils.py

diff --git a/EduKTM/utils/__init__.py b/EduKTM/utils/__init__.py
new file mode 100644
index 0000000..c11a047
--- /dev/null
+++ b/EduKTM/utils/__init__.py
@@ -0,0 +1,6 @@
+# coding: utf-8
+# 2021/5/24 @ tongshiwei
+
+from .utils import *
+from .loss import SequenceLogisticMaskLoss as SLMLoss
+from .torch_utils import *
diff --git a/EduKTM/utils/loss.py b/EduKTM/utils/loss.py
new file mode 100644
index 0000000..e229dae
--- /dev/null
+++ b/EduKTM/utils/loss.py
@@ -0,0 +1,71 @@
+# coding: utf-8
+# 2021/5/24 @ tongshiwei
+__all__ = ["SequenceLogisticMaskLoss", "LogisticMaskLoss"]
+
+import torch
+from torch import nn
+
+from .torch_utils import pick, sequence_mask
+
+
+class SequenceLogisticMaskLoss(nn.Module):
+    """
+    Notes
+    -----
+    The loss has already been averaged, so when calling the trainer's step method, batch_size should be 1
+    """
+
+    def __init__(self, lr=0.0, lw1=0.0, lw2=0.0):
+        super(SequenceLogisticMaskLoss, self).__init__()
+        self.lr = lr
+        self.lw1 = lw1
+        self.lw2 = lw2
+        self.loss = torch.nn.BCELoss(reduction='none')
+
+    def forward(self, pred_rs, pick_index, label, label_mask):
+        if self.lw1 > 0.0 or self.lw2 > 0.0:
+            post_pred_rs = pred_rs[:, 1:]
+            pre_pred_rs = pred_rs[:, :-1]
+            diff = post_pred_rs - pre_pred_rs
+            diff = sequence_mask(diff, label_mask)
+            w1 = torch.mean(torch.norm(diff, 1, -1)) / diff.shape[-1]
+            w2 = torch.mean(torch.norm(diff, 2, -1)) / diff.shape[-1]
+            w1 = w1 * self.lw1 if self.lw1 > 0.0 else 0.0
+            w2 = w2 * self.lw2 if self.lw2 > 0.0 else 0.0
+        else:
+            w1 = 0.0
+            w2 = 0.0
+
+        if self.lr > 0.0:
+            re_pred_rs = pred_rs[:, 1:]
+            re_pred_rs = pick(re_pred_rs, pick_index)
+            wr = sequence_mask(self.loss(re_pred_rs, label.float()), label_mask)
+            wr = torch.mean(wr) * self.lr
+        else:
+            wr = 0.0
+
+        pred_rs = pred_rs[:, 1:]
+        pred_rs = pick(pred_rs, pick_index)
+        loss = sequence_mask(self.loss(pred_rs, label.float()), label_mask)
+        loss = torch.mean(loss) + w1 + w2 + wr
+        return loss
+
+
+class LogisticMaskLoss(nn.Module):
+    """
+    Notes
+    -----
+    The loss has already been averaged, so when calling the trainer's step method, batch_size should be 1
+    """
+
+    def __init__(self):
+        super(LogisticMaskLoss, self).__init__()
+
+        self.loss = torch.nn.BCELoss()
+
+    def forward(self, pred_rs, label, label_mask, *args, **kwargs):
+        loss = sequence_mask(self.loss(pred_rs, label), label_mask)
+        loss = torch.mean(loss)
+        return loss
diff --git a/EduKTM/utils/torch_utils.py b/EduKTM/utils/torch_utils.py
new file mode 100644
index 0000000..1ae0648
--- /dev/null
+++ b/EduKTM/utils/torch_utils.py
@@ -0,0 +1,47 @@
+# coding: utf-8
+# 2021/5/24 @ tongshiwei
+__all__ = ["pick", "tensor2list", "length2mask", "get_sequence_mask", "sequence_mask"]
+
+import torch
+from torch import Tensor
+
+
+def pick(tensor, index, axis=-1):
+    return torch.gather(tensor, axis, index.unsqueeze(axis)).squeeze(axis)
+
+
+def tensor2list(tensor: Tensor):
+    return tensor.cpu().tolist()
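+
+
+# For example (an illustrative sketch of the function below):
+# length2mask([2, 3], max_len=4, valid_mask_val=1, invalid_mask_val=0)
+# -> tensor([[1, 1, 0, 0], [1, 1, 1, 0]])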
+def length2mask(length, max_len, valid_mask_val, invalid_mask_val):
+    mask = []
+
+    if isinstance(valid_mask_val, Tensor):
+        valid_mask_val = tensor2list(valid_mask_val)
+    if isinstance(invalid_mask_val, Tensor):
+        invalid_mask_val = tensor2list(invalid_mask_val)
+    if isinstance(length, Tensor):
+        length = tensor2list(length)
+
+    for _len in length:
+        mask.append([valid_mask_val] * _len + [invalid_mask_val] * (max_len - _len))
+
+    return torch.tensor(mask)
+
+
+def get_sequence_mask(shape, sequence_length, axis=1):
+    assert axis <= len(shape)
+    mask_shape = shape[axis + 1:]
+
+    valid_mask_val = torch.ones(mask_shape)
+    invalid_mask_val = torch.zeros(mask_shape)
+
+    max_len = shape[axis]
+
+    return length2mask(sequence_length, max_len, valid_mask_val, invalid_mask_val)
+
+
+def sequence_mask(tensor: Tensor, sequence_length, axis=1):
+    mask = get_sequence_mask(tensor.shape, sequence_length, axis).to(tensor.device)
+    return tensor * mask
diff --git a/EduKTM/utils/utils.py b/EduKTM/utils/utils.py
new file mode 100644
index 0000000..7f50820
--- /dev/null
+++ b/EduKTM/utils/utils.py
@@ -0,0 +1,37 @@
+# coding: utf-8
+# 2021/5/24 @ tongshiwei
+
+
+__all__ = ["as_list"]
+
+
+def as_list(obj) -> list:
+    r"""A utility function that converts the argument to a list
+    if it is not already.
+
+    Parameters
+    ----------
+    obj : object
+        argument to be converted to a list
+
+    Returns
+    -------
+    list_obj: list
+        If `obj` is a list or tuple, return it. Otherwise,
+        return `[obj]` as a single-element list.
+
+    Examples
+    --------
+    >>> as_list(1)
+    [1]
+    >>> as_list([1])
+    [1]
+    >>> as_list((1, 2))
+    [1, 2]
+    """
+    if isinstance(obj, list):
+        return obj
+    elif isinstance(obj, tuple):
+        return list(obj)
+    else:
+        return [obj]

From 5de840f186844bfa9162933035139a1950dc3535 Mon Sep 17 00:00:00 2001
From: Shiwei Tong
Date: Wed, 26 May 2021 10:15:14 +0800
Subject: [PATCH 02/20] [feat] add annotations

---
 EduKTM/utils/loss.py | 8 ++++++++
 1 file changed, 8 insertions(+)

diff --git a/EduKTM/utils/loss.py b/EduKTM/utils/loss.py
index e229dae..2687388 100644
--- a/EduKTM/utils/loss.py
+++ b/EduKTM/utils/loss.py
@@ -16,6 +16,14 @@ class SequenceLogisticMaskLoss(nn.Module):
     """
 
     def __init__(self, lr=0.0, lw1=0.0, lw2=0.0):
+        """
+
+        Parameters
+        ----------
+        lr: weight of the reconstruction regularization term
+        lw1: weight of the L1 waviness regularization term
+        lw2: weight of the L2 waviness regularization term
+        """
         super(SequenceLogisticMaskLoss, self).__init__()
         self.lr = lr
         self.lw1 = lw1

From 5d1023fe5f81c75b9262f30821d5e6af9f390272 Mon Sep 17 00:00:00 2001
From: Shiwei Tong
Date: Wed, 26 May 2021 10:15:34 +0800
Subject: [PATCH 03/20] [feat] add external library

---
 EduKTM/utils/torch_utils/extlib/__init__.py |   5 +
 EduKTM/utils/torch_utils/extlib/data.py     |  48 ++
 EduKTM/utils/torch_utils/extlib/sampler.py  | 476 ++++++++++++++++++++
 3 files changed, 529 insertions(+)
 create mode 100644 EduKTM/utils/torch_utils/extlib/__init__.py
 create mode 100644 EduKTM/utils/torch_utils/extlib/data.py
 create mode 100644 EduKTM/utils/torch_utils/extlib/sampler.py

diff --git a/EduKTM/utils/torch_utils/extlib/__init__.py b/EduKTM/utils/torch_utils/extlib/__init__.py
new file mode 100644
index 0000000..178d8b3
--- /dev/null
+++ b/EduKTM/utils/torch_utils/extlib/__init__.py
@@ -0,0 +1,5 @@
+# coding: utf-8
+# 2021/5/26 @ tongshiwei
+
+from .data import *
+from .sampler import *
diff --git a/EduKTM/utils/torch_utils/extlib/data.py b/EduKTM/utils/torch_utils/extlib/data.py
new file mode 100644
index 0000000..de86ce2
--- /dev/null
+++ b/EduKTM/utils/torch_utils/extlib/data.py
@@ -0,0 +1,48 @@
+# coding: utf-8
+# 2021/5/25 @ tongshiwei
+# These codes are modified from gluonnlp + +__all__ = ["PadSequence"] + + +class PadSequence: + """Pad the sequence. + + Pad the sequence to the given `length` by inserting `pad_val`. If `clip` is set, + sequence that has length larger than `length` will be clipped. + + Parameters + ---------- + length : int + The maximum length to pad/clip the sequence + pad_val : number + The pad value. Default 0 + clip : bool + """ + + def __init__(self, length, pad_val=0, clip=True): + self._length = length + self._pad_val = pad_val + self._clip = clip + + def __call__(self, sample): + """ + + Parameters + ---------- + sample : list of number or mx.nd.NDArray or np.ndarray + + Returns + ------- + ret : list of number or mx.nd.NDArray or np.ndarray + """ + sample_length = len(sample) + if sample_length >= self._length: + if self._clip and sample_length > self._length: + return sample[:self._length] + else: + return sample + else: + return sample + [ + self._pad_val for _ in range(self._length - sample_length) + ] diff --git a/EduKTM/utils/torch_utils/extlib/sampler.py b/EduKTM/utils/torch_utils/extlib/sampler.py new file mode 100644 index 0000000..9470a95 --- /dev/null +++ b/EduKTM/utils/torch_utils/extlib/sampler.py @@ -0,0 +1,476 @@ +# coding: utf-8 +# 2021/5/26 @ tongshiwei +# These codes are modified from gluonnlp + +import math +import numpy as np +import warnings + + +def _match_bucket_keys(bucket_keys, seq_lengths): + bucket_key_npy = np.array(bucket_keys, dtype=np.int32) + bucket_sample_ids = [list() for _ in range(len(bucket_keys))] + batch_size = 10000 + bucket_key_npy = bucket_key_npy.reshape((1,) + bucket_key_npy.shape) + for begin in range(0, len(seq_lengths), batch_size): + end = min(begin + batch_size, len(seq_lengths)) + diff = bucket_key_npy - np.expand_dims(seq_lengths[begin:end], axis=1) + if diff.ndim == 3: + is_valid_bucket = np.prod(diff >= 0, axis=2) + pad_val = np.sum(diff, axis=2) + else: + is_valid_bucket = diff >= 0 + pad_val = diff + seq_ids_not_found = np.nonzero(is_valid_bucket.sum(axis=1) == 0)[0].tolist() + masked_pad_val = np.ma.array(pad_val, mask=1 - is_valid_bucket) + batch_bucket_id = masked_pad_val.argmin(axis=1).tolist() + if len(seq_ids_not_found) > 0: + raise ValueError('Find elements in seq_lengths that cannot fit in the ' + 'given buckets, seq_length=%s, bucket_keys=%s. ' \ + 'You must increase the bucket size.' + % (str(seq_lengths[seq_ids_not_found]), str(bucket_keys))) + for i, bucket_id in enumerate(batch_bucket_id): + bucket_sample_ids[bucket_id].append(i + begin) + return bucket_sample_ids + + +def _bucket_stats(bucket_sample_ids, seq_lengths): + bucket_average_lengths = [] + bucket_length_stds = [] + for sample_ids in bucket_sample_ids: + if len(sample_ids) > 0: + lengths = seq_lengths[sample_ids] + bucket_average_lengths.append(np.mean(lengths)) + bucket_length_stds.append(np.std(lengths)) + else: + bucket_average_lengths.append(0) + bucket_length_stds.append(0) + return (bucket_average_lengths, bucket_length_stds) + + +class BucketScheme: + r"""Base class for generating bucket keys.""" + + def __call__(self, max_lengths, min_lengths, num_buckets): + """Generate bucket keys based on the lengths of sequences and number of buckets. + + Parameters + ---------- + max_lengths : int or list of int + Maximum of lengths of sequences. + min_lengths : int or list of int + Minimum of lengths of sequences. + num_buckets : int + Number of buckets + + Returns + ------- + bucket_keys : list of int + A list including the keys of the buckets. 
+ """ + raise NotImplementedError + + +class ConstWidthBucket(BucketScheme): + r"""Buckets with constant width.""" + + def __call__(self, max_lengths, min_lengths, num_buckets): + r"""This generate bucket keys given that all the buckets have the same width. + + Parameters + ---------- + max_lengths : int or list of int + Maximum of lengths of sequences. + min_lengths : int or list of int + Minimum of lengths of sequences. + num_buckets : int + Number of buckets + + Returns + ------- + bucket_keys : list of int + A list including the keys of the buckets. + """ + if isinstance(max_lengths, list): + bucket_width_l = [max((1 + max_len - min_len) // num_buckets, 1) + for max_len, min_len in + zip(max_lengths, min_lengths)] + bucket_keys = \ + [tuple(max(max_len - i * width, min_len) for max_len, min_len, width in + zip(max_lengths, min_lengths, bucket_width_l)) + for i in range(num_buckets)] + else: + bucket_width = max((1 + max_lengths - min_lengths) // num_buckets, 1) + bucket_keys = [max(max_lengths - i * bucket_width, min_lengths) + for i in range(num_buckets)] + return bucket_keys + + +class LinearWidthBucket(BucketScheme): + r""" Buckets with linearly increasing width: + :math:`w_i = \alpha * i + 1` for all :math:`i \geq 1`. + """ + + def __call__(self, max_lengths, min_lengths, num_buckets): + r"""This function generates bucket keys with linearly increasing bucket width: + + Parameters + ---------- + max_lengths : int or list of int + Maximum of lengths of sequences. + min_lengths : int or list of int + Minimum of lengths of sequences. + num_buckets : int + Number of buckets + + Returns + ------- + bucket_keys : list of int + A list including the keys of the buckets. + """ + if isinstance(max_lengths, list): + alpha_l = [2 * float(max_len - min_len - num_buckets) + / (num_buckets * (num_buckets + 1)) + for max_len, min_len in + zip(max_lengths, min_lengths)] + bucket_keys = \ + [tuple(int(round(min_len + alpha * (((i + 1) * (i + 2)) / 2) + i + 1)) + for min_len, alpha in zip(min_lengths, alpha_l)) + for i in range(num_buckets)] + bucket_keys[-1] = tuple(max(max_bucket_key, max_len) + for max_bucket_key, max_len + in zip(bucket_keys[-1], max_lengths)) + else: + alpha = 2 * float(max_lengths - min_lengths - num_buckets) \ + / (num_buckets * (num_buckets + 1)) + bucket_keys = [int(round(min_lengths + alpha * (((i + 1) * (i + 2)) / 2) + i + 1)) + for i in range(num_buckets)] + bucket_keys[-1] = max(bucket_keys[-1], max_lengths) + return bucket_keys + + +class ExpWidthBucket(BucketScheme): + r""" Buckets with exponentially increasing width: + :math:`w_i = bucket\_len\_step * w_{i-1}` for all :math:`i \geq 2`. + + Parameters + ---------- + bucket_len_step : float, default 1.1 + This is the increasing factor for the bucket width. + """ + + def __init__(self, bucket_len_step=1.1): + self.bucket_len_step = bucket_len_step + + def __call__(self, max_lengths, min_lengths, num_buckets): + r"""This function generates bucket keys exponentially increasing bucket width. + + Parameters + ---------- + max_lengths : int or list of int + Maximum of lengths of sequences. + min_lengths : int or list of int + Minimum of lengths of sequences. + num_buckets : int + Number of buckets + + Returns + ------- + bucket_keys : list of int + A list including the keys of the buckets. 
+ """ + if isinstance(max_lengths, list): + initial_width_l = [ + (max_len - min_len) * (self.bucket_len_step - 1) + / (math.pow(self.bucket_len_step, num_buckets) - 1) + for max_len, min_len in + zip(max_lengths, min_lengths)] + bucket_keys = \ + [tuple( + int(round(min_len + initial_width * (math.pow(self.bucket_len_step, i + 1) - 1) + / (self.bucket_len_step - 1))) + for min_len, initial_width in zip(min_lengths, initial_width_l)) + for i in range(num_buckets)] + bucket_keys[-1] = tuple(max(max_bucket_key, max_len) + for max_bucket_key, max_len + in zip(bucket_keys[-1], max_lengths)) + else: + initial_width = (max_lengths - min_lengths) * (self.bucket_len_step - 1) \ + / (math.pow(self.bucket_len_step, num_buckets) - 1) + bucket_keys = [ + int(round(min_lengths + initial_width * (math.pow(self.bucket_len_step, i + 1) - 1) + / (self.bucket_len_step - 1))) + for i in range(num_buckets)] + bucket_keys[-1] = max(bucket_keys[-1], max_lengths) + return bucket_keys + + +class Sampler(object): + """Base class for samplers. + + All samplers should subclass `Sampler` and define `__iter__` and `__len__` + methods. + """ + + def __iter__(self): + raise NotImplementedError + + def __len__(self): + raise NotImplementedError + + +class FixedBucketSampler(Sampler): + r"""Assign each data sample to a fixed bucket based on its length. + The bucket keys are either given or generated from the input sequence lengths. + + Parameters + ---------- + lengths : list of int or list of tuple/list of int + The length of the sequences in the input data sample. + batch_size : int + The batch size of the sampler. + num_buckets : int or None, default 10 + The number of buckets. This will not be used if bucket_keys is set. + bucket_keys : None or list of int or list of tuple, default None + The keys that will be used to create the buckets. It should usually be the lengths of the + sequences. If it is None, the bucket_keys will be generated based on the maximum + lengths of the data. + ratio : float, default 0 + Ratio to scale up the batch size of smaller buckets. + Assume the :math:`i` th key is :math:`K_i` , + the default batch size is :math:`B` , the ratio to scale the batch size is + :math:`\alpha` and + the batch size corresponds to the :math:`i` th bucket is :math:`B_i` . We have: + + .. math:: + + B_i = \max(\alpha B \times \frac{\max_j sum(K_j)}{sum(K_i)}, B) + + Thus, setting this to a value larger than 0, like 0.5, will scale up the batch size of the + smaller buckets. + shuffle : bool, default False + Whether to shuffle the batches. + use_average_length : bool, default False + False: each batch contains batch_size sequences, number of sequence elements varies. + True: each batch contains batch_size elements, number of sequences varies. In this case, + ratio option is ignored. + num_shards : int, default 0 + If num_shards > 0, the sampled batch is split into num_shards smaller batches. + The output will have structure of list(list(int)). + If num_shards = 0, the output will have structure of list(int). + This is useful in multi-gpu training and can potentially reduce the number of paddings. + In general, it is set to the number of gpus. + bucket_scheme : BucketScheme, default ConstWidthBucket + It is used to generate bucket keys. 
It supports: + ConstWidthBucket: all the buckets have the same width + LinearWidthBucket: the width of ith bucket follows :math:`w_i = \alpha * i + 1` + ExpWidthBucket: the width of ith bucket follows + :math:`w_i` = bucket_len_step :math:`* w_{i-1}` + Examples + -------- + >>> lengths = [np.random.randint(1, 100) for _ in range(1000)] + >>> sampler = FixedBucketSampler(lengths, 8, ratio=0.5) + >>> print(sampler.stats()) + FixedBucketSampler: + -etc- + """ + + def __init__(self, lengths, batch_size, num_buckets=10, bucket_keys=None, + ratio=0, shuffle=False, use_average_length=False, num_shards=0, + bucket_scheme=ConstWidthBucket()): + assert len(lengths) > 0, 'FixedBucketSampler does not support empty lengths.' + assert batch_size > 0, 'Batch size must be larger than 0.' + assert ratio >= 0, 'batch size scaling ratio cannot be negative.' + self._batch_size = batch_size + self._ratio = ratio + self._lengths = np.array(lengths, dtype=np.int32) + if self._lengths.ndim == 1: + self._single_element = True + attr_num = 1 + else: + assert self._lengths.ndim == 2, \ + 'Elements in lengths must be either int or tuple/list of int. ' \ + 'Received lengths=%s' % str(lengths) + self._single_element = False + attr_num = self._lengths.shape[1] + self._shuffle = shuffle + self._num_shards = num_shards + self._bucket_scheme = bucket_scheme + max_lengths = self._lengths.max(axis=0) + min_lengths = self._lengths.min(axis=0) + if self._single_element: + assert min_lengths > 0, 'Sequence lengths must all be larger than 0.' + else: + for _, ele in enumerate(min_lengths): + assert ele > 0, 'Sequence lengths must all be larger than 0.' + # Generate the buckets + if bucket_keys is None: + assert num_buckets > 0, 'num_buckets must be set when bucket_keys is None. Received ' \ + 'num_buckets=%d' % num_buckets + bucket_keys = bucket_scheme(max_lengths, min_lengths, num_buckets) + else: + if num_buckets is not None: + warnings.warn('num_buckets will not be used if bucket_keys is not None. ' + 'bucket_keys=%s, num_buckets=%d' % (str(bucket_keys), num_buckets)) + assert len(bucket_keys) > 0 + if self._single_element: + assert isinstance(bucket_keys[0], int) + else: + assert isinstance(bucket_keys[0], tuple) + assert len(bucket_keys[0]) == attr_num + bucket_keys = sorted(set(bucket_keys)) + # Assign instances to buckets + bucket_sample_ids = _match_bucket_keys(bucket_keys, self._lengths) + unused_bucket_keys = [key for key, sample_ids in zip(bucket_keys, bucket_sample_ids) + if len(sample_ids) == 0] + if len(unused_bucket_keys) > 0: + warnings.warn('Some buckets are empty and will be removed. 
Unused bucket keys=%s' % + str(unused_bucket_keys)) + # Remove empty buckets + self._bucket_keys = [key for key, sample_ids in zip(bucket_keys, bucket_sample_ids) + if len(sample_ids) > 0] + + self._bucket_sample_ids = [sample_ids for sample_ids in bucket_sample_ids + if len(sample_ids) > 0] + if not use_average_length: + scale_up_keys = [key if self._single_element else sum(key) for key + in self._bucket_keys] + max_scale_up_key = max(scale_up_keys) + self._bucket_batch_sizes = [max(int(max_scale_up_key / float(scale_up_key) + * self._ratio * batch_size), batch_size) + for scale_up_key in scale_up_keys] + else: + if ratio > 0.: + warnings.warn('ratio=%f is ignored when use_average_length is True' % self._ratio) + bucket_average_lengths, bucket_length_stds = _bucket_stats(self._bucket_sample_ids, + self._lengths) + self._bucket_batch_sizes = [max(int(batch_size / (average_length + length_std)), 1) + for average_length, length_std + in zip(bucket_average_lengths, bucket_length_stds)] + self._batch_infos = [] + for bucket_id, sample_ids, bucket_batch_size in \ + zip(range(len(self._bucket_keys) - 1, -1, -1), + self._bucket_sample_ids[::-1], + self._bucket_batch_sizes[::-1]): + for i in range(0, len(sample_ids), bucket_batch_size): + self._batch_infos.append((bucket_id, i)) + + if self._num_shards > 0: + self._sampler_size = int(math.ceil(len(self._batch_infos) / float(self._num_shards))) + else: + self._sampler_size = len(self._batch_infos) + + def __iter__(self): + if self._shuffle: + np.random.shuffle(self._batch_infos) + for bucket_id in range(len(self._bucket_keys)): + np.random.shuffle(self._bucket_sample_ids[bucket_id]) + + if self._num_shards > 0: + for batch_idx in range(0, len(self._batch_infos), self._num_shards): + if batch_idx + self._num_shards > len(self._batch_infos): + batch_idx = len(self._batch_infos) - self._num_shards + batch = self._batch_infos[batch_idx: batch_idx + self._num_shards] + bucket_ids, batch_begins = list(zip(*batch)) + batch_sizes = [self._bucket_batch_sizes[bucket_id] for bucket_id in bucket_ids] + batch_ends = [min(batch_begin + batch_size, + len(self._bucket_sample_ids[bucket_id])) + for bucket_id, batch_begin, batch_size in zip(bucket_ids, + batch_begins, + batch_sizes)] + yield [self._bucket_sample_ids[bucket_id][batch_begin:batch_end] + for bucket_id, batch_begin, batch_end in zip(bucket_ids, + batch_begins, + batch_ends)] + else: + for bucket_id, batch_begin in self._batch_infos: + batch_size = self._bucket_batch_sizes[bucket_id] + batch_end = min(batch_begin + batch_size, len(self._bucket_sample_ids[bucket_id])) + yield self._bucket_sample_ids[bucket_id][batch_begin:batch_end] + + def __len__(self): + return self._sampler_size + + def stats(self): + """Return a string representing the statistics of the bucketing sampler. + + Returns + ------- + ret : str + String representing the statistics of the buckets. + """ + ret = '{name}:\n' \ + ' sample_num={sample_num}, batch_num={batch_num}\n' \ + ' key={bucket_keys}\n' \ + ' cnt={bucket_counts}\n' \ + ' batch_size={bucket_batch_sizes}' \ + .format(name=self.__class__.__name__, + sample_num=len(self._lengths), + batch_num=len(self._batch_infos), + bucket_keys=self._bucket_keys, + bucket_counts=[len(sample_ids) for sample_ids in self._bucket_sample_ids], + bucket_batch_sizes=self._bucket_batch_sizes) + return ret + + +class SortedBucketSampler(Sampler): + r"""Batches are sampled from sorted buckets of data. + + First, partition data in buckets of size `batch_size * mult`. 
+ Each bucket contains `batch_size * mult` elements. The samples inside each bucket are sorted + based on sort_key and then batched. + + Parameters + ---------- + sort_keys : list-like object + The keys to sort the samples. + batch_size : int + Batch size of the sampler. + mult : int or float, default 100 + The multiplier to determine the bucket size. Each bucket will have size `mult * batch_size`. + reverse : bool, default True + Whether to sort in descending order. + shuffle : bool, default False + Whether to shuffle the data. + + Examples + -------- + >>> lengths = [np.random.randint(1, 1000) for _ in range(1000)] + >>> sampler = SortedBucketSampler(lengths, 16) + >>> # The sequence lengths within the batch will be sorted + >>> for i, indices in enumerate(sampler): + ... if i == 0: + ... print([lengths[ele] for ele in indices]) + [-etc-] + """ + + def __init__(self, sort_keys, batch_size, mult=100, reverse=True, shuffle=False): + assert len(sort_keys) > 0 + assert batch_size > 0 + assert mult >= 1, 'Bucket size multiplier must be larger than 1' + self._sort_keys = sort_keys + self._batch_size = batch_size + self._mult = mult + self._total_sample_num = len(self._sort_keys) + self._reverse = reverse + self._shuffle = shuffle + + def __iter__(self): + if self._shuffle: + sample_ids = np.random.permutation(self._total_sample_num) + else: + sample_ids = list(range(self._total_sample_num)) + bucket_size = int(self._mult * self._batch_size) + for bucket_begin in range(0, self._total_sample_num, bucket_size): + bucket_end = min(bucket_begin + bucket_size, self._total_sample_num) + sorted_sample_ids = sorted(sample_ids[bucket_begin:bucket_end], + key=lambda i: self._sort_keys[i], reverse=self._reverse) + batch_begins = list(range(0, len(sorted_sample_ids), self._batch_size)) + if self._shuffle: + np.random.shuffle(batch_begins) + for batch_begin in batch_begins: + batch_end = min(batch_begin + self._batch_size, len(sorted_sample_ids)) + yield sorted_sample_ids[batch_begin:batch_end] + + def __len__(self): + return (len(self._sort_keys) + self._batch_size - 1) // self._batch_size From a411c5bbd9db428bb9b22470df3d914b1e7d8d08 Mon Sep 17 00:00:00 2001 From: Shiwei Tong Date: Wed, 26 May 2021 10:15:48 +0800 Subject: [PATCH 04/20] [feat] add torch utils --- EduKTM/utils/torch_utils/__init__.py | 5 +++++ EduKTM/utils/{torch_utils.py => torch_utils/functional.py} | 0 2 files changed, 5 insertions(+) create mode 100644 EduKTM/utils/torch_utils/__init__.py rename EduKTM/utils/{torch_utils.py => torch_utils/functional.py} (100%) diff --git a/EduKTM/utils/torch_utils/__init__.py b/EduKTM/utils/torch_utils/__init__.py new file mode 100644 index 0000000..3aa1ec3 --- /dev/null +++ b/EduKTM/utils/torch_utils/__init__.py @@ -0,0 +1,5 @@ +# coding: utf-8 +# 2021/5/25 @ tongshiwei + +from .extlib import * +from .functional import * diff --git a/EduKTM/utils/torch_utils.py b/EduKTM/utils/torch_utils/functional.py similarity index 100% rename from EduKTM/utils/torch_utils.py rename to EduKTM/utils/torch_utils/functional.py From 1f38c1d054ae00afde466aee5e3ae4d18bff1261 Mon Sep 17 00:00:00 2001 From: Shiwei Tong Date: Wed, 26 May 2021 12:14:19 +0800 Subject: [PATCH 05/20] [docs] add examples for DKT+ --- examples/DKT+/DKT+.ipynb | 199 ++++++++++++++++++++++++++++ examples/DKT+/DKT+.py | 2 + examples/DKT+/prepare_dataset.ipynb | 70 ++++++++++ 3 files changed, 271 insertions(+) create mode 100644 examples/DKT+/DKT+.ipynb create mode 100644 examples/DKT+/DKT+.py create mode 100644 examples/DKT+/prepare_dataset.ipynb 
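Each batch that the `etl` pipeline used below yields is a five-tuple of tensors; a sketch of
the layout the example notebook consumes (shapes assume a batch of B sequences padded to
length T):

    data         # (B, T)   responses, encoded as item_id * 2 + correct
    data_mask    # (B,)     true (unpadded) length of each response sequence
    label        # (B, T-1) correctness of the next answered item at every step
    pick_index   # (B, T-1) item id of the next answered item at every step
    label_mask   # (B,)     true length of each label sequence
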
diff --git a/examples/DKT+/DKT+.ipynb b/examples/DKT+/DKT+.ipynb
new file mode 100644
index 0000000..16cd5c6
--- /dev/null
+++ b/examples/DKT+/DKT+.ipynb
@@ -0,0 +1,199 @@
+{
+ "cells": [
+  {
+   "cell_type": "markdown",
+   "metadata": {
+    "collapsed": true,
+    "pycharm": {
+     "name": "#%% md\n"
+    }
+   },
+   "source": [
+    "# Deep Knowledge Tracing Plus (DKT+)\n",
+    "\n",
+    "This notebook will show you how to train and use the DKT+ model.\n",
+    "First, we will show how to get the data (here we use a0910c as the dataset).\n",
+    "Then we will show how to train a DKT+ model and persist its parameters.\n",
+    "Finally, we will show how to load the parameters from the file and evaluate on the test dataset.\n",
+    "\n",
+    "The script version can be found in [DKT+.py](DKT+.py)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "source": [
+    "## Data Preparation\n",
+    "\n",
+    "Before we process the data, we need to first acquire the dataset, which is shown in [prepare_dataset.ipynb](prepare_dataset.ipynb)"
+   ],
+   "metadata": {
+    "collapsed": false
+   }
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 1,
+   "outputs": [
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "reading data from ../../data/a0910c/train.json: 3966it [00:00, 26335.00it/s]\n",
+      "batchify: 100%|██████████| 130/130 [00:00<00:00, 1343.40it/s]\n",
+      "reading data from ../../data/a0910c/valid.json: 472it [00:00, 47324.16it/s]\n",
+      "E:\\Program\\EduKTM\\EduKTM\\utils\\torch_utils\\extlib\\sampler.py:327: UserWarning: Some buckets are empty and will be removed. Unused bucket keys=[55, 58, 59, 61, 65, 69, 74, 76, 77, 79, 80, 88, 90, 94, 95, 96, 99]\n",
+      "  warnings.warn('Some buckets are empty and will be removed. Unused bucket keys=%s' %\n",
+      "batchify: 100%|██████████| 84/84 [00:00<00:00, 6016.63it/s]\n",
+      "reading data from ../../data/a0910c/test.json: 1088it [00:00, 21999.30it/s]\n",
+      "E:\\Program\\EduKTM\\EduKTM\\utils\\torch_utils\\extlib\\sampler.py:327: UserWarning: Some buckets are empty and will be removed. Unused bucket keys=[73, 88]\n",
+      "  warnings.warn('Some buckets are empty and will be removed. Unused bucket keys=%s' %\n",
+      "batchify: 100%|██████████| 101/101 [00:00<00:00, 3616.89it/s]\n"
+     ]
+    }
+   ],
+   "source": [
+    "from EduKTM.DKTPlus import etl\n",
+    "batch_size = 64\n",
+    "train = etl(\"../../data/a0910c/train.json\", batch_size)\n",
+    "valid = etl(\"../../data/a0910c/valid.json\", batch_size)\n",
+    "test = etl(\"../../data/a0910c/test.json\", batch_size)"
+   ],
+   "metadata": {
+    "collapsed": false,
+    "pycharm": {
+     "name": "#%%\n"
+    }
+   }
+  },
+  {
+   "cell_type": "markdown",
+   "source": [
+    "## Training and Persistence"
+   ],
+   "metadata": {
+    "collapsed": false,
+    "pycharm": {
+     "name": "#%% md\n"
+    }
+   }
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 2,
+   "outputs": [],
+   "source": [
+    "import logging\n",
+    "logging.getLogger().setLevel(logging.INFO)"
+   ],
+   "metadata": {
+    "collapsed": false,
+    "pycharm": {
+     "name": "#%%\n"
+    }
+   }
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 5,
+   "outputs": [
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "Epoch 0: 100%|██████████| 130/130 [00:03<00:00, 36.62it/s]\n",
+      "evaluating: 100%|██████████| 84/84 [00:00<00:00, 188.42it/s]\n",
+      "Epoch 1: 100%|██████████| 130/130 [00:03<00:00, 35.92it/s]\n",
+      "evaluating: 100%|██████████| 84/84 [00:00<00:00, 197.71it/s]\n",
+      "INFO:root:save parameters to dkt+.params\n"
+     ]
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "[Epoch 0] SLMLoss: 0.490312\n",
+      "[Epoch 0] auc: 0.655370, accuracy: 0.681618\n",
+      "[Epoch 1] SLMLoss: 0.226885\n",
+      "[Epoch 1] auc: 0.671076, accuracy: 0.674871\n"
+     ]
+    }
+   ],
+   "source": [
+    "from EduKTM import DKTPlus\n",
+    "\n",
+    "# dkt_plus = DKTPlus(ku_num=146, hidden_num=100, loss_params={\"lr\": 0.1, \"lw1\": 0.5, \"lw2\": 0.5})\n",
+    "dkt_plus = DKTPlus(ku_num=146, hidden_num=100, loss_params={\"lr\": 0.1, \"lw1\": 0.5, \"lw2\": 0.5})\n",
+    "dkt_plus.train(train, valid, epoch=2)\n",
+    "dkt_plus.save(\"dkt+.params\")"
+   ],
+   "metadata": {
+    "collapsed": false,
+    "pycharm": {
+     "name": "#%%\n"
+    }
+   }
+  },
+  {
+   "cell_type": "markdown",
+   "source": [
+    "## Loading and Testing"
+   ],
+   "metadata": {
+    "collapsed": false
+   }
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 6,
+   "outputs": [
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "INFO:root:load parameters from dkt+.params\n",
+      "evaluating: 100%|██████████| 101/101 [00:00<00:00, 125.49it/s]\n"
+     ]
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "auc: 0.678657, accuracy: 0.674521\n"
+     ]
+    }
+   ],
+   "source": [
+    "dkt_plus.load(\"dkt+.params\")\n",
+    "auc, accuracy = dkt_plus.eval(test)\n",
+    "print(\"auc: %.6f, accuracy: %.6f\" % (auc, accuracy))"
+   ],
+   "metadata": {
+    "collapsed": false,
+    "pycharm": {
+     "name": "#%%\n"
+    }
+   }
+  }
+ ],
+ "metadata": {
+  "kernelspec": {
+   "display_name": "Python 3",
+   "language": "python",
+   "name": "python3"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 2
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython2",
+   "version": "2.7.6"
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 0
+}
\ No newline at end of file
diff --git a/examples/DKT+/DKT+.py b/examples/DKT+/DKT+.py
new file mode 100644
index 0000000..9e4f27c
--- /dev/null
+++ b/examples/DKT+/DKT+.py
@@ -0,0 +1,2 @@
+# coding: utf-8
+# 2021/5/26 @ tongshiwei
diff --git a/examples/DKT+/prepare_dataset.ipynb b/examples/DKT+/prepare_dataset.ipynb
new file mode 100644
index 0000000..81779c8
--- /dev/null
+++ 
b/examples/DKT+/prepare_dataset.ipynb @@ -0,0 +1,70 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 12, + "metadata": { + "collapsed": true, + "pycharm": { + "name": "#%%\n" + } + }, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "downloader, INFO http://base.ustc.edu.cn/data/ktbd/a0910c/readme.txt is saved as ..\\..\\data\\a0910c\\readme.txt\n", + "downloader, INFO http://base.ustc.edu.cn/data/ktbd/a0910c/test.json is saved as ..\\..\\data\\a0910c\\test.json\n", + "downloader, INFO http://base.ustc.edu.cn/data/ktbd/a0910c/train.json is saved as ..\\..\\data\\a0910c\\train.json\n", + "downloader, INFO http://base.ustc.edu.cn/data/ktbd/a0910c/valid.json is saved as ..\\..\\data\\a0910c\\valid.json\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Downloading ..\\..\\data\\a0910c\\readme.txt 100.00%: 21 | 21\n", + "Downloading ..\\..\\data\\a0910c\\test.json 100.00%: 477005 | 477005\n", + "Downloading ..\\..\\data\\a0910c\\train.json 100.00%: 1807148 | 1807148\n", + "Downloading ..\\..\\data\\a0910c\\valid.json 100.00%: 222345 | 222345\n" + ] + }, + { + "data": { + "text/plain": "'../../data'" + }, + "execution_count": 12, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "from EduData import get_data\n", + "\n", + "get_data(\"ktbd-a0910c\", \"../../data\")" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 2 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython2", + "version": "2.7.6" + } + }, + "nbformat": 4, + "nbformat_minor": 0 +} \ No newline at end of file From c81115194593407c154ed95ce30b97dc07646904 Mon Sep 17 00:00:00 2001 From: Shiwei Tong Date: Wed, 26 May 2021 12:14:32 +0800 Subject: [PATCH 06/20] [feat] add DKT+ --- EduKTM/DKTPlus/DKTPlus.py | 118 +++++++++++++++++++++++++++++++++++++ EduKTM/DKTPlus/__init__.py | 5 ++ EduKTM/DKTPlus/etl.py | 69 ++++++++++++++++++++++ EduKTM/__init__.py | 1 + 4 files changed, 193 insertions(+) create mode 100644 EduKTM/DKTPlus/DKTPlus.py create mode 100644 EduKTM/DKTPlus/__init__.py create mode 100644 EduKTM/DKTPlus/etl.py diff --git a/EduKTM/DKTPlus/DKTPlus.py b/EduKTM/DKTPlus/DKTPlus.py new file mode 100644 index 0000000..5bb1ba0 --- /dev/null +++ b/EduKTM/DKTPlus/DKTPlus.py @@ -0,0 +1,118 @@ +# coding: utf-8 +# 2021/5/25 @ tongshiwei + +import logging +import torch +from EduKTM import KTM +from torch import nn +import torch.nn.functional as F +from tqdm import tqdm +from EduKTM.utils import sequence_mask, SLMLoss, tensor2list, pick +from sklearn.metrics import roc_auc_score, accuracy_score +import numpy as np + + +class DKTNet(nn.Module): + def __init__(self, ku_num, hidden_num, add_embedding_layer=False, dropout=0.0, **kwargs): + super(DKTNet, self).__init__() + self.ku_num = ku_num + self.hidden_dim = hidden_num + self.output_dim = ku_num + if add_embedding_layer is True: + self.embeddings = nn.Sequential( + nn.Embedding(ku_num * 2, kwargs["latent_dim"]), + nn.Dropout(kwargs.get("embedding_dropout", 0.2)) + ) + rnn_input_dim = kwargs["latent_dim"] + else: + self.embeddings = lambda x: F.one_hot(x, num_classes=self.output_dim * 2).float() + rnn_input_dim = ku_num * 2 + + self.rnn = nn.RNN(rnn_input_dim, hidden_num, 1, batch_first=True, nonlinearity='tanh') + self.fc = nn.Linear(self.hidden_dim, 
self.output_dim)
+        self.dropout = nn.Dropout(dropout)
+        self.sig = nn.Sigmoid()
+
+    def forward(self, responses, mask=None, begin_state=None):
+        responses = self.embeddings(responses)
+        output, hn = self.rnn(responses)
+        output = self.sig(self.fc(self.dropout(output)))
+        if mask is not None:
+            output = sequence_mask(output, mask)
+        return output, hn
+
+
+class DKTPlus(KTM):
+    def __init__(self, ku_num, hidden_num, net_params: dict = None, loss_params=None):
+        super(DKTPlus, self).__init__()
+        self.dkt_net = DKTNet(
+            ku_num,
+            hidden_num,
+            **(net_params if net_params is not None else {})
+        )
+        self.loss_params = loss_params if loss_params is not None else {}
+
+    def train(self, train_data, test_data=None, *, epoch: int, device="cpu", lr=0.001) -> ...:
+        loss_function = SLMLoss(**self.loss_params)
+
+        trainer = torch.optim.Adam(self.dkt_net.parameters(), lr)
+
+        for e in range(epoch):
+            losses = []
+            for (data, data_mask, label, pick_index, label_mask) in tqdm(train_data, "Epoch %s" % e):
+                # convert to device
+                data: torch.Tensor = data.to(device)
+                data_mask: torch.Tensor = data_mask.to(device)
+                label: torch.Tensor = label.to(device)
+                pick_index: torch.Tensor = pick_index.to(device)
+                label_mask: torch.Tensor = label_mask.to(device)
+
+                # real training
+                predicted_response, _ = self.dkt_net(data, data_mask)
+                loss = loss_function(predicted_response, pick_index, label, label_mask)
+
+                # back propagation
+                trainer.zero_grad()
+                loss.backward()
+                trainer.step()
+
+                losses.append(loss.mean().item())
+            print("[Epoch %d] SLMLoss: %.6f" % (e, float(np.mean(losses))))
+
+            if test_data is not None:
+                auc, accuracy = self.eval(test_data)
+                print("[Epoch %d] auc: %.6f, accuracy: %.6f" % (e, auc, accuracy))
+
+    def eval(self, test_data, device="cpu") -> tuple:
+        self.dkt_net.eval()
+        y_true = []
+        y_pred = []
+
+        for (data, data_mask, label, pick_index, label_mask) in tqdm(test_data, "evaluating"):
+            # convert to device
+            data: torch.Tensor = data.to(device)
+            data_mask: torch.Tensor = data_mask.to(device)
+            label: torch.Tensor = label.to(device)
+            pick_index: torch.Tensor = pick_index.to(device)
+            label_mask: torch.Tensor = label_mask.to(device)
+
+            # real evaluating
+            output, _ = self.dkt_net(data, data_mask)
+            output = output[:, :-1]
+            output = pick(output, pick_index.to(output.device))
+            pred = tensor2list(output)
+            label = tensor2list(label)
+            for i, length in enumerate(label_mask.numpy().tolist()):
+                length = int(length)
+                y_true.extend(label[i][:length])
+                y_pred.extend(pred[i][:length])
+        self.dkt_net.train()
+        return roc_auc_score(y_true, y_pred), accuracy_score(y_true, np.array(y_pred) >= 0.5)
+
+    def save(self, filepath) -> ...:
+        torch.save(self.dkt_net.state_dict(), filepath)
+        logging.info("save parameters to %s" % filepath)
+
+    def load(self, filepath):
+        self.dkt_net.load_state_dict(torch.load(filepath))
+        logging.info("load parameters from %s" % filepath)
diff --git a/EduKTM/DKTPlus/__init__.py b/EduKTM/DKTPlus/__init__.py
new file mode 100644
index 0000000..3257cfa
--- /dev/null
+++ b/EduKTM/DKTPlus/__init__.py
@@ -0,0 +1,5 @@
+# coding: utf-8
+# 2021/5/25 @ tongshiwei
+
+from .DKTPlus import DKTPlus
+from .etl import etl
diff --git a/EduKTM/DKTPlus/etl.py b/EduKTM/DKTPlus/etl.py
new file mode 100644
index 0000000..64d32a6
--- /dev/null
+++ b/EduKTM/DKTPlus/etl.py
@@ -0,0 +1,69 @@
+# coding: utf-8
+# 2021/5/25 @ tongshiwei
+
+import torch
+import json
+from tqdm import tqdm
+from EduKTM.utils.torch_utils import PadSequence, FixedBucketSampler
+
+
+def extract(data_src):
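+    # Each line of data_src is expected to be a JSON list holding one learner's
+    # (item_id, correct) interaction pairs, e.g. [[419, 1], [912, 0]] (an
+    # illustrative line, not taken from the real dataset); long sequences are
+    # split into chunks of at most `step` interactions.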
+    responses = []
+    step = 200
+    with open(data_src) as f:
+        for line in tqdm(f, "reading data from %s" % data_src):
+            data = json.loads(line)
+            for i in range(0, len(data), step):
+                if len(data[i: i + step]) < 2:
+                    continue
+                responses.append(data[i: i + step])
+
+    return responses
+
+
+def transform(raw_data, batch_size, num_buckets=100):
+    # the data transformation interface:
+    # raw_data --> batch_data
+
+    responses = raw_data
+
+    batch_idxes = FixedBucketSampler([len(rs) for rs in responses], batch_size, num_buckets=num_buckets)
+    batch = []
+
+    def index(r):
+        correct = 0 if r[1] <= 0 else 1
+        return r[0] * 2 + correct  # joint encoding of item id and correctness
+
+    for batch_idx in tqdm(batch_idxes, "batchify"):
+        batch_rs = []
+        batch_pick_index = []
+        batch_labels = []
+        for idx in batch_idx:
+            batch_rs.append([index(r) for r in responses[idx]])
+            if len(responses[idx]) <= 1:
+                pick_index, labels = [], []
+            else:
+                pick_index, labels = zip(*[(r[0], 0 if r[1] <= 0 else 1) for r in responses[idx][1:]])
+            batch_pick_index.append(list(pick_index))
+            batch_labels.append(list(labels))
+
+        max_len = max([len(rs) for rs in batch_rs])
+        padder = PadSequence(max_len, pad_val=0)
+        batch_rs, data_mask = zip(*[(padder(rs), len(rs)) for rs in batch_rs])
+
+        max_len = max([len(rs) for rs in batch_labels])
+        padder = PadSequence(max_len, pad_val=0)
+        batch_labels, label_mask = zip(*[(padder(labels), len(labels)) for labels in batch_labels])
+        batch_pick_index = [padder(pick_index) for pick_index in batch_pick_index]
+        # Load: assemble the padded arrays into one batch of tensors
+        batch.append(
+            [torch.tensor(batch_rs), torch.tensor(data_mask), torch.tensor(batch_labels),
+             torch.tensor(batch_pick_index),
+             torch.tensor(label_mask)])
+
+    return batch
+
+
+def etl(data_src, batch_size, **kwargs):
+    raw_data = extract(data_src)
+    return transform(raw_data, batch_size, **kwargs)
diff --git a/EduKTM/__init__.py b/EduKTM/__init__.py
index f0a55e0..9b37be2 100644
--- a/EduKTM/__init__.py
+++ b/EduKTM/__init__.py
@@ -4,3 +4,4 @@
 from .meta import KTM
 from .KPT import KPT
 from .DKT import DKT
+from .DKTPlus import DKTPlus

From 7902852aa1ddc19e6664ee9b6997c5b0cbbd71a7 Mon Sep 17 00:00:00 2001
From: Shiwei Tong
Date: Wed, 26 May 2021 12:14:46 +0800
Subject: [PATCH 07/20] [docs] for DKT+

---
 docs/DKT+.md | 14 ++++++++++++++
 1 file changed, 14 insertions(+)
 create mode 100644 docs/DKT+.md

diff --git a/docs/DKT+.md b/docs/DKT+.md
new file mode 100644
index 0000000..3823163
--- /dev/null
+++ b/docs/DKT+.md
@@ -0,0 +1,14 @@
+# Deep Knowledge Tracing Plus (DKT+)
+
+For the details of DKT+, please refer to the Appendix of the paper: *[Addressing Two Problems in Deep Knowledge Tracing via
+Prediction-Consistent Regularization](https://arxiv.org/pdf/1806.02180.pdf)*.
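+
+DKT+ extends the DKT objective with three regularization terms, implemented here in
+`EduKTM.utils.loss.SequenceLogisticMaskLoss` (exported as `SLMLoss`): a reconstruction term that also
+fits the prediction for the *current* interaction, and two "waviness" terms that penalize the L1 and
+L2 norms of the change in consecutive predictions. Roughly, with `lr`, `lw1` and `lw2` the
+corresponding weights (a sketch of the objective as implemented, not the paper's exact notation):
+
+$$\mathcal{L} = \mathcal{L}_{next} + \lambda_{r}\,\mathcal{L}_{cur} + \lambda_{w_1}\,\overline{\lVert \Delta \rVert_1} + \lambda_{w_2}\,\overline{\lVert \Delta \rVert_2}$$
+
+where $\Delta_t = \mathbf{y}_{t+1} - \mathbf{y}_t$ is the step-to-step change of the predicted
+mastery vector and the norms are averaged over time steps and knowledge concepts.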
+ +```bibtex +@inproceedings{yeung2018addressing, + title={Addressing two problems in deep knowledge tracing via prediction-consistent regularization}, + author={Yeung, Chun-Kit and Yeung, Dit-Yan}, + booktitle={Proceedings of the Fifth Annual ACM Conference on Learning at Scale}, + pages={1--10}, + year={2018} +} +``` \ No newline at end of file From 90217b59d744ba7bcab9fc5b324cc5e9672afa2a Mon Sep 17 00:00:00 2001 From: Shiwei Tong Date: Wed, 26 May 2021 12:15:02 +0800 Subject: [PATCH 08/20] [chore] add gitignore item --- .gitignore | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/.gitignore b/.gitignore index 75d3786..ea2af58 100644 --- a/.gitignore +++ b/.gitignore @@ -109,4 +109,5 @@ venv.bak/ # Pyre type checker .pyre/ -# User Definition \ No newline at end of file +# User Definition +data/ \ No newline at end of file From 0dfe3ab70c535d879f048407c171ea21e0e40e20 Mon Sep 17 00:00:00 2001 From: Shiwei Tong Date: Wed, 26 May 2021 12:15:25 +0800 Subject: [PATCH 09/20] [test] ignore external library --- pytest.ini | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytest.ini b/pytest.ini index c91e312..bf35ca8 100644 --- a/pytest.ini +++ b/pytest.ini @@ -1,6 +1,6 @@ [pytest] # For pytest usage, refer to https://hb4dsai.readthedocs.io/zh/latest/Architecture/Test.html -norecursedirs = docs *build* trash dev examples +norecursedirs = docs *build* trash dev examples EduKTM/utils/torch_utils/extlib # Deal with marker warnings markers = From 4eb10accc5f9c124436bdf84933d47e9a3c650eb Mon Sep 17 00:00:00 2001 From: Shiwei Tong Date: Wed, 26 May 2021 16:46:51 +0800 Subject: [PATCH 10/20] [docs] update examples --- examples/DKT+/DKT+.ipynb | 37 ++++++++++++++++++------------------- examples/DKT+/DKT+.py | 19 +++++++++++++++++++ 2 files changed, 37 insertions(+), 19 deletions(-) diff --git a/examples/DKT+/DKT+.ipynb b/examples/DKT+/DKT+.ipynb index 16cd5c6..c251e7c 100644 --- a/examples/DKT+/DKT+.ipynb +++ b/examples/DKT+/DKT+.ipynb @@ -38,16 +38,16 @@ "name": "stderr", "output_type": "stream", "text": [ - "reading data from ../../data/a0910c/train.json: 3966it [00:00, 26335.00it/s]\n", - "batchify: 100%|██████████| 130/130 [00:00<00:00, 1343.40it/s]\n", - "reading data from ../../data/a0910c/valid.json: 472it [00:00, 47324.16it/s]\n", + "reading data from ../../data/a0910c/train.json: 3966it [00:00, 25996.05it/s]\n", + "batchify: 100%|██████████| 130/130 [00:00<00:00, 1372.38it/s]\n", + "reading data from ../../data/a0910c/valid.json: 472it [00:00, 39390.19it/s]\n", "E:\\Program\\EduKTM\\EduKTM\\utils\\torch_utils\\extlib\\sampler.py:327: UserWarning: Some buckets are empty and will be removed. Unused bucket keys=[55, 58, 59, 61, 65, 69, 74, 76, 77, 79, 80, 88, 90, 94, 95, 96, 99]\n", " warnings.warn('Some buckets are empty and will be removed. Unused bucket keys=%s' %\n", - "batchify: 100%|██████████| 84/84 [00:00<00:00, 6016.63it/s]\n", - "reading data from ../../data/a0910c/test.json: 1088it [00:00, 21999.30it/s]\n", + "batchify: 100%|██████████| 84/84 [00:00<00:00, 6005.04it/s]\n", + "reading data from ../../data/a0910c/test.json: 1088it [00:00, 17315.92it/s]\n", "E:\\Program\\EduKTM\\EduKTM\\utils\\torch_utils\\extlib\\sampler.py:327: UserWarning: Some buckets are empty and will be removed. Unused bucket keys=[73, 88]\n", " warnings.warn('Some buckets are empty and will be removed. 
Unused bucket keys=%s' %\n",
-    "batchify: 100%|██████████| 101/101 [00:00<00:00, 3616.89it/s]\n"
+    "batchify: 100%|██████████| 101/101 [00:00<00:00, 3492.14it/s]\n"
     ]
    }
   ],
   {
    "cell_type": "code",
-   "execution_count": 5,
+   "execution_count": 3,
    "outputs": [
     {
      "name": "stderr",
      "output_type": "stream",
      "text": [
-      "Epoch 0: 100%|██████████| 130/130 [00:03<00:00, 36.62it/s]\n",
-      "evaluating: 100%|██████████| 84/84 [00:00<00:00, 188.42it/s]\n",
-      "Epoch 1: 100%|██████████| 130/130 [00:03<00:00, 35.92it/s]\n",
-      "evaluating: 100%|██████████| 84/84 [00:00<00:00, 197.71it/s]\n",
+      "Epoch 0: 100%|██████████| 130/130 [00:06<00:00, 21.48it/s]\n",
+      "evaluating: 100%|██████████| 84/84 [00:00<00:00, 193.61it/s]\n",
+      "Epoch 1: 100%|██████████| 130/130 [00:05<00:00, 21.93it/s]\n",
+      "evaluating: 100%|██████████| 84/84 [00:00<00:00, 199.11it/s]\n",
       "INFO:root:save parameters to dkt+.params\n"
      ]
     },
     {
      "name": "stdout",
      "output_type": "stream",
      "text": [
-      "[Epoch 0] SLMLoss: 0.490312\n",
-      "[Epoch 0] auc: 0.655370, accuracy: 0.681618\n",
-      "[Epoch 1] SLMLoss: 0.226885\n",
-      "[Epoch 1] auc: 0.671076, accuracy: 0.674871\n"
+      "[Epoch 0] SLMLoss: 0.553947\n",
+      "[Epoch 0] auc: 0.661187, accuracy: 0.688492\n",
+      "[Epoch 1] SLMLoss: 0.278718\n",
+      "[Epoch 1] auc: 0.672982, accuracy: 0.679581\n"
      ]
     }
    ],
    "source": [
     "from EduKTM import DKTPlus\n",
     "\n",
-    "# dkt_plus = DKTPlus(ku_num=146, hidden_num=100, loss_params={\"lr\": 0.1, \"lw1\": 0.5, \"lw2\": 0.5})\n",
     "dkt_plus = DKTPlus(ku_num=146, hidden_num=100, loss_params={\"lr\": 0.1, \"lw1\": 0.5, \"lw2\": 0.5})\n",
     "dkt_plus.train(train, valid, epoch=2)\n",
     "dkt_plus.save(\"dkt+.params\")"
   {
    "cell_type": "code",
-   "execution_count": 6,
+   "execution_count": 4,
    "outputs": [
     {
      "name": "stderr",
      "output_type": "stream",
      "text": [
       "INFO:root:load parameters from dkt+.params\n",
-      "evaluating: 100%|██████████| 101/101 [00:00<00:00, 125.49it/s]\n"
+      "evaluating: 100%|██████████| 101/101 [00:00<00:00, 129.17it/s]\n"
      ]
     },
     {
      "name": "stdout",
      "output_type": "stream",
      "text": [
-      "auc: 0.678657, accuracy: 0.674521\n"
+      "auc: 0.681138, accuracy: 0.678663\n"
      ]
     }
    ],
diff --git a/examples/DKT+/DKT+.py b/examples/DKT+/DKT+.py
index 9e4f27c..a9eb020 100644
--- a/examples/DKT+/DKT+.py
+++ b/examples/DKT+/DKT+.py
@@ -1,2 +1,21 @@
 # coding: utf-8
 # 2021/5/26 @ tongshiwei
+import logging
+from EduKTM.DKTPlus import etl
+
+from EduKTM import DKTPlus
+
+batch_size = 64
+train = etl("../../data/a0910c/train.json", batch_size)
+valid = etl("../../data/a0910c/valid.json", batch_size)
+test = etl("../../data/a0910c/test.json", batch_size)
+
+logging.getLogger().setLevel(logging.INFO)
+
+dkt_plus = DKTPlus(ku_num=146, hidden_num=100, loss_params={"lr": 0.1, "lw1": 0.5, "lw2": 0.5})
+dkt_plus.train(train, valid, epoch=2)
+dkt_plus.save("dkt+.params")
+
+dkt_plus.load("dkt+.params")
+auc, accuracy = dkt_plus.eval(test)
+print("auc: %.6f, accuracy: %.6f" % (auc, accuracy))

From 82a9ab4c27d6823d97b70bfbe018a99810c78092 Mon Sep 17 00:00:00 2001
From: Shiwei Tong
Date: Wed, 26 May 2021 16:47:08 +0800
Subject: [PATCH 11/20] [feat] add util functions for tests

---
 EduKTM/utils/tests.py | 16 ++++++++++++++++
 1 file changed, 16 insertions(+)
 create mode 100644 EduKTM/utils/tests.py

diff --git a/EduKTM/utils/tests.py b/EduKTM/utils/tests.py
new file mode 100644
index 0000000..7e3fe9f
--- /dev/null
+++ b/EduKTM/utils/tests.py
@@ -0,0 +1,16 @@
+# coding: utf-8
+# 2021/5/26 @ tongshiwei
+
+def pseudo_data_generation(ku_num, record_num=10, max_length=20):
+    # generate a pseudo data stream for testing
+    import random
+    random.seed(10)
+
+    raw_data = [
+        [
+            (random.randint(0, ku_num - 1), random.randint(-1, 1))
+            for _ in range(random.randint(2, max_length))
+        ] for _ in range(record_num)
+    ]
+
+    return raw_data

From 4a815eda4c6d1373baca27d858e27c1c95691251 Mon Sep 17 00:00:00 2001
From: Shiwei Tong
Date: Wed, 26 May 2021 16:50:29 +0800
Subject: [PATCH 12/20] [test] for DKT+

---
 tests/dkt+/__init__.py      |  2 ++
 tests/dkt+/conftest.py      | 19 +++++++++++++++++++
 tests/dkt+/test_dkt_plus.py | 13 +++++++++++++
 3 files changed, 34 insertions(+)
 create mode 100644 tests/dkt+/__init__.py
 create mode 100644 tests/dkt+/conftest.py
 create mode 100644 tests/dkt+/test_dkt_plus.py

diff --git a/tests/dkt+/__init__.py b/tests/dkt+/__init__.py
new file mode 100644
index 0000000..9e4f27c
--- /dev/null
+++ b/tests/dkt+/__init__.py
@@ -0,0 +1,2 @@
+# coding: utf-8
+# 2021/5/26 @ tongshiwei
diff --git a/tests/dkt+/conftest.py b/tests/dkt+/conftest.py
new file mode 100644
index 0000000..bfb816e
--- /dev/null
+++ b/tests/dkt+/conftest.py
@@ -0,0 +1,19 @@
+# coding: utf-8
+# 2021/5/26 @ tongshiwei
+import pytest
+
+from EduKTM.utils.tests import pseudo_data_generation
+from EduKTM.DKTPlus.etl import transform
+
+
+@pytest.fixture(scope="package")
+def conf():
+    ques_num = 10
+    hidden_num = 10
+    return ques_num, hidden_num
+
+
+@pytest.fixture(scope="package")
+def data(conf):
+    ques_num, _ = conf
+    return transform(pseudo_data_generation(ques_num), 32)
diff --git a/tests/dkt+/test_dkt_plus.py b/tests/dkt+/test_dkt_plus.py
new file mode 100644
index 0000000..0520bd5
--- /dev/null
+++ b/tests/dkt+/test_dkt_plus.py
@@ -0,0 +1,13 @@
+# coding: utf-8
+# 2021/5/26 @ tongshiwei
+
+from EduKTM import DKTPlus
+
+
+def test_train(data, conf, tmp_path):
+    ku_num, hidden_size = conf
+    dkt_plus = DKTPlus(ku_num, hidden_size)
+    dkt_plus.train(data, test_data=data, epoch=2)
+    filepath = tmp_path / "dkt+.params"
+    dkt_plus.save(filepath)
+    dkt_plus.load(filepath)

From 302b24958d6f190d9582fef4ce6487ffde0e101d Mon Sep 17 00:00:00 2001
From: Shiwei Tong
Date: Wed, 26 May 2021 17:15:32 +0800
Subject: [PATCH 13/20] [test] ignorance

---
 EduKTM/utils/loss.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/EduKTM/utils/loss.py b/EduKTM/utils/loss.py
index 2687388..6b0a5f8 100644
--- a/EduKTM/utils/loss.py
+++ b/EduKTM/utils/loss.py
@@ -61,7 +61,7 @@ def forward(self, pred_rs, pick_index, label, label_mask):
         return loss
 
 
-class LogisticMaskLoss(nn.Module):
+class LogisticMaskLoss(nn.Module):  # pragma: no cover
     """
     Notes
     -----

From 4d8539055b0350bffad7a1a9eb0ee81e96af1d9b Mon Sep 17 00:00:00 2001
From: Shiwei Tong
Date: Wed, 26 May 2021 17:15:44 +0800
Subject: [PATCH 14/20] [test] coverage

---
 tests/dkt+/test_dkt_plus.py | 15 ++++++++++++---
 1 file changed, 12 insertions(+), 3 deletions(-)

diff --git a/tests/dkt+/test_dkt_plus.py b/tests/dkt+/test_dkt_plus.py
index 0520bd5..32178d1 100644
--- a/tests/dkt+/test_dkt_plus.py
+++ b/tests/dkt+/test_dkt_plus.py
@@ -1,12 +1,21 @@
 # coding: utf-8
 # 2021/5/26 @ tongshiwei
+import pytest
 
 from EduKTM import DKTPlus
 
 
-def test_train(data, conf, tmp_path):
-    ku_num, hidden_size = conf
-    dkt_plus = DKTPlus(ku_num, hidden_size)
+@pytest.mark.parametrize("lr", [0, 0.1])
+@pytest.mark.parametrize("lw1", [0, 0.5])
+@pytest.mark.parametrize("lw2", [0, 0.5])
+@pytest.mark.parametrize("add_embedding_layer", [True, False])
+def test_train(data, conf, tmp_path, lr, lw1, lw2, add_embedding_layer):
+    ku_num, hidden_num = conf
+    dkt_plus = DKTPlus(
+        ku_num,
hidden_num, + net_params={"add_embedding_layer": add_embedding_layer}, + loss_params={"lr": lr, "lw1": lw1, "lw2": lw2} + ) dkt_plus.train(data, test_data=data, epoch=2) filepath = tmp_path / "dkt+.params" dkt_plus.save(filepath) From b31ad08b1713fecca279904c677ad89662184090 Mon Sep 17 00:00:00 2001 From: Shiwei Tong Date: Wed, 26 May 2021 17:16:02 +0800 Subject: [PATCH 15/20] [fix] missing latent dim --- EduKTM/DKTPlus/DKTPlus.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/EduKTM/DKTPlus/DKTPlus.py b/EduKTM/DKTPlus/DKTPlus.py index 5bb1ba0..cf22148 100644 --- a/EduKTM/DKTPlus/DKTPlus.py +++ b/EduKTM/DKTPlus/DKTPlus.py @@ -13,17 +13,18 @@ class DKTNet(nn.Module): - def __init__(self, ku_num, hidden_num, add_embedding_layer=False, dropout=0.0, **kwargs): + def __init__(self, ku_num, hidden_num, add_embedding_layer=False, embedding_dim=None, dropout=0.0, **kwargs): super(DKTNet, self).__init__() self.ku_num = ku_num self.hidden_dim = hidden_num self.output_dim = ku_num if add_embedding_layer is True: + embedding_dim = self.hidden_dim if embedding_dim is None else embedding_dim self.embeddings = nn.Sequential( - nn.Embedding(ku_num * 2, kwargs["latent_dim"]), + nn.Embedding(ku_num * 2, embedding_dim), nn.Dropout(kwargs.get("embedding_dropout", 0.2)) ) - rnn_input_dim = kwargs["latent_dim"] + rnn_input_dim = embedding_dim else: self.embeddings = lambda x: F.one_hot(x, num_classes=self.output_dim * 2).float() rnn_input_dim = ku_num * 2 From 5fc136f30074686ec91d599ae38973217ff4223f Mon Sep 17 00:00:00 2001 From: Shiwei Tong Date: Wed, 26 May 2021 17:16:36 +0800 Subject: [PATCH 16/20] [test] ignorance --- EduKTM/DKTPlus/etl.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/EduKTM/DKTPlus/etl.py b/EduKTM/DKTPlus/etl.py index 64d32a6..2182c7a 100644 --- a/EduKTM/DKTPlus/etl.py +++ b/EduKTM/DKTPlus/etl.py @@ -7,7 +7,7 @@ from EduKTM.utils.torch_utils import PadSequence, FixedBucketSampler -def extract(data_src): +def extract(data_src): # pragma: no cover responses = [] step = 200 with open(data_src) as f: @@ -40,7 +40,7 @@ def index(r): batch_labels = [] for idx in batch_idx: batch_rs.append([index(r) for r in responses[idx]]) - if len(responses[idx]) <= 1: + if len(responses[idx]) <= 1: # pragma: no cover pick_index, labels = [], [] else: pick_index, labels = zip(*[(r[0], 0 if r[1] <= 0 else 1) for r in responses[idx][1:]]) @@ -64,6 +64,6 @@ def index(r): return batch -def etl(data_src, batch_size, **kwargs): +def etl(data_src, batch_size, **kwargs): # pragma: no cover raw_data = extract(data_src) return transform(raw_data, batch_size, **kwargs) From fb6d99caea6078b4c39a5b00716ae35ef3c02b26 Mon Sep 17 00:00:00 2001 From: Shiwei Tong Date: Wed, 26 May 2021 17:16:51 +0800 Subject: [PATCH 17/20] [test] ignore external lib --- setup.cfg | 1 + 1 file changed, 1 insertion(+) diff --git a/setup.cfg b/setup.cfg index bf1569b..46866d5 100644 --- a/setup.cfg +++ b/setup.cfg @@ -1,5 +1,6 @@ [coverage:run] source=EduKTM +omit=EduKTM/utils/torch_utils/extlib/* [coverage:report] exclude_lines = pragma: no cover From 26798e7445f6433dfc4ce2c740e50e89d7048618 Mon Sep 17 00:00:00 2001 From: Shiwei Tong Date: Wed, 26 May 2021 17:21:24 +0800 Subject: [PATCH 18/20] [feat] update version --- setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.py b/setup.py index 8f46ad8..b7df6ab 100644 --- a/setup.py +++ b/setup.py @@ -8,7 +8,7 @@ setup( name='EduKTM', - version='0.0.4', + version='0.0.5', extras_require={ 'test': test_deps, }, From 
87f251fb77d8835d74ee76fcb252df72ccd7f9fa Mon Sep 17 00:00:00 2001 From: Shiwei Tong Date: Wed, 26 May 2021 17:21:43 +0800 Subject: [PATCH 19/20] [docs] record changes --- CHANGE.txt | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/CHANGE.txt b/CHANGE.txt index 38323b6..c1e0a1b 100644 --- a/CHANGE.txt +++ b/CHANGE.txt @@ -1,3 +1,7 @@ +v0.0.5: + * add DKT+ + * add some util functions + v0.0.4: * fix potential ModuleNotFoundError From 7e76aa6ad91211b2112b660d6b21583f82aa08d1 Mon Sep 17 00:00:00 2001 From: Shiwei Tong Date: Wed, 26 May 2021 17:22:00 +0800 Subject: [PATCH 20/20] [docs] add links to DKT+ --- README.md | 1 + 1 file changed, 1 insertion(+) diff --git a/README.md b/README.md index 71809e6..264440b 100644 --- a/README.md +++ b/README.md @@ -14,6 +14,7 @@ The Model Zoo of Knowledge Tracing Models * [KPT,EKPT](EduKTM/KPT) [[doc]](docs/KPT.md) [[example]](examples/KPT) * [DKT](EduKTM/DKT) [[doc]](docs/DKT.md) [[example]](examples/DKT) +* [DKT+](EduKTM/DKTPlus) [[doc]](docs/DKT+.md) [[example]](examples/DKT+) ## Contribute