核心概念
Decision MetaMamba (DMM) 通過在 Mamba 的輸入層中加入一個標記混合器,並採用殘差加法網路,有效地整合了鄰近步驟和遠距離步驟的信息,從而提升了基於狀態空間模型的離線強化學習決策模型的性能。
這篇研究論文介紹了 Decision MetaMamba (DMM),一種用於離線強化學習的新型序列決策模型。DMM 建立在先進的狀態空間模型 Mamba 的基礎上,並透過修改後的輸入層來提升效能。
背景
傳統的強化學習方法,如基於 Transformer 的模型,通常依賴於位置編碼來處理序列數據,這可能導致行為克隆,並限制了模型對缺乏適當時間步長標註的數據集的適用性。此外,這些模型在處理長序列時可能會遇到計算效率低下的問題。
Mamba 與狀態空間模型
Mamba 是一種最先進的狀態空間模型 (SSM),在各種序列建模任務中展現出超越 Transformer 的性能。Mamba 的架構允許其內部狀態根據輸入進行動態調整,從而實現選擇性信息保留和基於內容的推理。與基於 Transformer 的模型不同,Mamba 不需要位置編碼,從而降低了行為克隆的風險,並增強了模型的泛化能力。
Decision MetaMamba 的創新
DMM 通過在 Mamba 的輸入層中加入一個標記混合器來解決傳統模型的局限性。這個標記混合器旨在融合來自相鄰步驟的信息,從而減輕數據丟失並保留局部關係。此外,DMM 採用了殘差加法網路,以解決 Mamba 塊中殘差乘法可能導致的學習困難。
多模態標記混合器
為了有效處理離線強化學習數據集中的不同輸入模態(狀態、動作和預期回報),DMM 採用了兩種不同的標記混合器:
多模態一維卷積層:在隱藏狀態維度上運作,整合窗口內的相鄰嵌入。
多模態線性層:沿著序列維度整合標記,將連續的序列向量拼接在一起。
實驗結果
在 D4RL MuJoCo、AntMaze 和 Atari 環境中的實驗結果表明,DMM 在性能上優於或與現有模型相當,同時使用的參數數量顯著減少。值得注意的是,DMM 在需要整合鄰近和遠距離序列信息的任務中表現出色,突出了其在處理複雜序列數據方面的優勢。
結論
DMM 是一種基於狀態空間模型的新型強化學習方法,它通過整合多模態標記混合器和殘差加法網路,有效地提升了 Mamba 的性能。DMM 的效率、準確性和泛化能力使其成為離線強化學習中一個有前途的方向。
統計資料
DMM 使用的參數數量比傳統的基於 Transformer 的模型少 90%。
在某些基準測試中,DMM 的推理速度比傳統的 Transformer 模型快五倍。