うめこの開発日記

PCゲームや人工知能関連の話題についての日記

Chainerで複数入力のモデルとデータセットの作り方

Chainer使ってて複数の入力をするモデルを作る必要があって、解決したのでメモ

作りたかったモデルはこんなの

class LSTM_dual_input(Chain):
    def __init__(self):
        super(LSTM_dual_input, self).__init__()
        with self.init_scope():
            self.emb = L.EmbedID(10000, 300, ignore_label=-1)
            self.lstm = L.LSTM(None, 100)
            self.linear = L.Linear(None, 100)
            self.linear2 = L.Linear(None, 2000)
    def reset_state(self):
        self.lstm.reset_state()
    def __call__(self, x1,x2):
        x1_t = F.transpose_sequence(x1)
        for x1_i in x1_t:
            h1 = self.emb(x1_i)
            h2 = self.lstm(h1)
        h3=F.concat([h2, F.relu(self.linear(x2))], axis=1)
        y = self.linear2(h3)
        return y

時系列の入力x1をLSTMでエンコードして、もう一つの入力x2は専用の全結合層に普通に通す。
次の層でLSTMと全結合層と合わせて(concatの部分)、出力層につないでクラス分類するよって構造。

まぁモデルは別にいいんだけど、悩んだのは複数入力する場合のデータの作り方
たぶん簡単にやるならdatasetを次のように作るはず

dataset = [
        (x1_1, label_1),
        (x1_2, label_2),
        (x1_3, label_3)
        ...
        ]

じゃあこれを複数にしよう!って思って俺がやったのがこれ

dataset = [
        ((x1_1, x2_1), label_1),
        ((x1_2, x2_2), label_2),
        ((x1_3, x2_3), label_3)
        ...
        ]

てっきり(データ、ラベル)の形式しかできないんじゃないかと勝手に思い込んでました.....普通に(データ、データ、ラベル)っていけるのね........
正解の複数型のデータセットは以下

dataset = [
        (x1_1, x2_1, label_1),
        (x1_2, x2_2, label_2),
        (x1_3, x2_3, label_3)
        ...
        ]

モデルの__call__で複数入力受けるように(上の例で言えばx1とx2)していれば、受け付けてくれマス