
IPW code in the ESCM2 model #991


@zhjcp

Hi, I've been reading PaddlePaddle's implementation of the ESCM2 model, and I have a question about the IPW part of the code:

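    import paddle

    def counterfact_ipw(self, loss_cvr, ctr_num, O, ctr_out_one):
        # Propensity score: predicted CTR multiplied by ctr_num
        PS = paddle.multiply(
            ctr_out_one, paddle.cast(ctr_num, dtype="float32"))
        # Floor the propensity at 1e-6 to avoid division by zero
        min_v = paddle.full_like(PS, 0.000001)
        PS = paddle.maximum(PS, min_v)
        IPS = paddle.reciprocal(PS)
        # Batch size B: sum of a tensor of ones shaped like O
        batch_shape = paddle.full_like(O, 1)
        batch_size = paddle.sum(
            paddle.cast(batch_shape, dtype="float32"), axis=0)
        # TODO: this should be a hyperparameter
        # IPS is always positive, so only the upper bound of 15 takes effect
        IPS = paddle.clip(IPS, min=-15, max=15)  # online trick
        IPS = paddle.multiply(IPS, batch_size)
        # Treat the IPW weights as constants (stop gradient)
        IPS.stop_gradient = True
        loss_cvr = paddle.multiply(loss_cvr, IPS)
        # Mask with O so only observed (clicked) samples contribute
        loss_cvr = paddle.multiply(loss_cvr, O)
        return paddle.mean(loss_cvr)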
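Written out as a formula, the function above computes the following. The symbols are my own shorthand, not from the code or the paper: $\ell_i$ is `loss_cvr` for sample $i$, $\hat{p}_i$ is `ctr_out_one`, $O_i \in \{0,1\}$ is `O`, $B$ is `batch_size`, and I assume `ctr_num` is a single scalar $N$. The lower clip bound of -15 is dropped because $1/\text{PS}$ is always positive.

$$
\mathcal{L}_{\text{code}}
= \frac{1}{B}\sum_{i=1}^{B} O_i\,\ell_i \cdot B \cdot \min\!\left(\frac{1}{\max(N\hat{p}_i,\,10^{-6})},\ 15\right)
= \sum_{i=1}^{B} O_i\,\ell_i \,\min\!\left(\frac{1}{\max(N\hat{p}_i,\,10^{-6})},\ 15\right)
$$

If $N\hat{p}_i \ge 1/15$ for every clicked sample, so that neither the floor nor the clip is active, this reduces to

$$
\mathcal{L}_{\text{code}}
= \frac{1}{N}\sum_{i=1}^{B}\frac{O_i\,\ell_i}{\hat{p}_i}
= \frac{B}{N}\cdot\frac{1}{B}\sum_{i=1}^{B}\frac{O_i\,\ell_i}{\hat{p}_i}
= \frac{B}{N}\,\mathcal{L}_{\text{IPW}},
$$

where $\mathcal{L}_{\text{IPW}}$ is the usual IPW loss that divides each clicked sample's loss by its predicted CTR and averages over the batch.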

Why does the code multiply by `ctr_num` and `batch_size`? This doesn't quite match the IPW formula described in the ESCM2 paper. Isn't IPW usually just dividing by the CTR probability?
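To make the difference concrete, here is a minimal pure-Python sketch that mirrors `counterfact_ipw` element by element and compares it with the textbook IPW loss (average over the batch of click indicator × loss / predicted CTR). The function names `standard_ipw_loss` and `snippet_ipw_loss`, the toy batch, and the assumption that `ctr_num` is a single scalar (for example, the number of clicks in the batch) are all hypothetical and are not taken from PaddleRec.

```python
def standard_ipw_loss(losses, clicks, p_ctr, eps=1e-6):
    """Textbook IPW: (1/B) * sum_i O_i * loss_i / p_i.

    The eps floor is only for numerical safety and mirrors the 1e-6 floor
    in the Paddle snippet.
    """
    b = len(losses)
    return sum(o * l / max(p, eps) for l, o, p in zip(losses, clicks, p_ctr)) / b


def snippet_ipw_loss(losses, clicks, p_ctr, ctr_num, clip_max=15.0, eps=1e-6):
    """Pure-Python mirror of counterfact_ipw, assuming ctr_num is a scalar."""
    b = len(losses)                                     # batch_size
    total = 0.0
    for l, o, p in zip(losses, clicks, p_ctr):
        ps = max(p * ctr_num, eps)                      # PS = max(ctr_out_one * ctr_num, 1e-6)
        ips = min(max(1.0 / ps, -clip_max), clip_max)   # clip(1 / PS, -15, 15)
        ips *= b                                        # IPS * batch_size
        total += l * ips * o                            # loss_cvr * IPS * O
    return total / b                                    # paddle.mean over the batch


if __name__ == "__main__":
    # Hypothetical toy batch: B = 4 samples, 2 of them clicked.
    losses = [0.5, 0.2, 0.9, 0.4]
    clicks = [1, 0, 1, 0]
    p_ctr = [0.5, 0.1, 0.25, 0.2]
    std = standard_ipw_loss(losses, clicks, p_ctr)
    code = snippet_ipw_loss(losses, clicks, p_ctr, ctr_num=sum(clicks))
    print("textbook IPW:", std)
    print("snippet IPW: ", code)
    print("ratio:", code / std)
```

With this toy batch ($B = 4$, $N = 2$) no weight reaches the clip, so the snippet's loss is exactly $B/N = 2$ times the textbook value. Once $1/(N\hat{p}_i)$ exceeds 15 for some clicked sample, that sample's weight is capped and the ratio is no longer exactly $B/N$.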
