Nested-Wasserstein Self-Imitation Learning for Sequence Generation

Ruiyi Zhang*, Changyou Chen, Zhe Gan, Zheng Wen, Wenlin Wang, Lawrence Carin

*Corresponding author for this work

Research output: Chapter in Book/Report/Conference proceeding › Conference contribution › Peer-review

3 Scopus citations

Abstract

Reinforcement learning (RL) has been widely studied for improving sequence-generation models. However, the conventional rewards used for RL training typically fail to capture sufficient semantic information and therefore manifest model bias. Further, sparse and delayed rewards make RL exploration inefficient. To alleviate these issues, we propose the concept of nested-Wasserstein distance for distributional semantic matching. To further exploit it, a novel nested-Wasserstein self-imitation learning framework is developed, encouraging the model to exploit historical high-reward sequences for enhanced exploration and better semantic matching. Our solution can be understood as approximately executing proximal policy optimization with Wasserstein trust regions. Experiments on a variety of unconditional and conditional sequence-generation tasks demonstrate that the proposed approach consistently improves performance.
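As a rough illustration of the nested-Wasserstein idea in the abstract, the sketch below treats each sentence as a uniform bag of word embeddings. An inner optimal-transport (OT) distance between the words of two sentences serves as the ground cost between sentences, and an outer OT problem is then solved over two sets of sentences (for example, generated versus reference). Both levels are approximated with entropy-regularized, log-domain Sinkhorn iterations. The cosine ground cost, the uniform weights, the Sinkhorn solver, the regularization `eps`, and every function name here (`cosine_cost`, `sinkhorn_cost`, `inner_wasserstein`, `nested_wasserstein`) are assumptions made for illustration. They are not taken from the paper, whose exact formulation and solver may differ.

```python
# Hypothetical sketch of a nested-Wasserstein distance; not the authors' implementation.
import math
from typing import List, Sequence

Embedding = Sequence[float]       # one word vector (assumed nonzero)
Sentence = Sequence[Embedding]    # a sentence as a uniform bag of word vectors


def cosine_cost(u: Embedding, v: Embedding) -> float:
    """Assumed word-level ground cost: 1 - cosine similarity, which lies in [0, 2]."""
    dot = sum(x * y for x, y in zip(u, v))
    norm_u = math.sqrt(sum(x * x for x in u))
    norm_v = math.sqrt(sum(y * y for y in v))
    return 1.0 - dot / (norm_u * norm_v)


def _logsumexp(values: List[float]) -> float:
    """Numerically stable log(sum(exp(values)))."""
    m = max(values)
    return m + math.log(sum(math.exp(x - m) for x in values))


def sinkhorn_cost(cost: List[List[float]], eps: float = 0.1, iters: int = 200) -> float:
    """Entropy-regularized OT cost <T, C> between uniform marginals (log-domain Sinkhorn)."""
    n, m = len(cost), len(cost[0])
    log_a, log_b = math.log(1.0 / n), math.log(1.0 / m)
    f = [0.0] * n  # dual potentials for the rows
    g = [0.0] * m  # dual potentials for the columns
    for _ in range(iters):
        # Alternate updates so that row sums, then column sums, match the marginals.
        f = [eps * log_a - eps * _logsumexp([(g[j] - cost[i][j]) / eps for j in range(m)])
             for i in range(n)]
        g = [eps * log_b - eps * _logsumexp([(f[i] - cost[i][j]) / eps for i in range(n)])
             for j in range(m)]
    # Transport plan T_ij = exp((f_i + g_j - C_ij) / eps); return its total transport cost.
    return sum(math.exp((f[i] + g[j] - cost[i][j]) / eps) * cost[i][j]
               for i in range(n) for j in range(m))


def inner_wasserstein(s1: Sentence, s2: Sentence, eps: float = 0.1) -> float:
    """Word-level OT distance between two sentences (word-order is ignored)."""
    cost = [[cosine_cost(u, v) for v in s2] for u in s1]
    return sinkhorn_cost(cost, eps)


def nested_wasserstein(generated: Sequence[Sentence], references: Sequence[Sentence],
                       eps: float = 0.1) -> float:
    """Sentence-level OT whose ground cost is itself the inner word-level OT distance."""
    cost = [[inner_wasserstein(x, y, eps) for y in references] for x in generated]
    return sinkhorn_cost(cost, eps)


if __name__ == "__main__":
    a, b = [1.0, 0.0], [0.0, 1.0]   # toy 2-D "word embeddings"
    refs = [[a, b]]
    print(nested_wasserstein([[b, a]], refs))  # near 0: same words, different order
    print(nested_wasserstein([[b, b]], refs))  # about 0.5: half the word mass must move
```

Because each sentence is treated as a bag of words, a reordered sentence is almost free to transport, while replacing a word costs in proportion to how far its embedding moves. The entropic regularization also keeps the distance of a set to itself slightly above zero; a smaller `eps` reduces this bias at the price of slower Sinkhorn convergence.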

Original language: English
Title of host publication: International Conference on Artificial Intelligence and Statistics, Vol. 108
Editors: S Chiappa, R Calandra
Publisher: Addison-Wesley Publ Co
Pages: 422-432
Number of pages: 11
State: Published - 2020
Externally published: Yes
Event: 23rd International Conference on Artificial Intelligence and Statistics (AISTATS)
Duration: Aug 26 2020 to Aug 28 2020

Publication series

Name: Proceedings of Machine Learning Research
Publisher: Addison-Wesley Publ Co
Volume: 108
ISSN (Print): 2640-3498

Conference

Conference: 23rd International Conference on Artificial Intelligence and Statistics (AISTATS)
Period: 08/26/20 to 08/28/20
