Chainerで複数入力のモデルとデータセットの作り方
Chainer使ってて複数の入力をするモデルを作る必要があって、解決したのでメモ
作りたかったモデルはこんなの
class LSTM_dual_input(Chain): def __init__(self): super(LSTM_dual_input, self).__init__() with self.init_scope(): self.emb = L.EmbedID(10000, 300, ignore_label=-1) self.lstm = L.LSTM(None, 100) self.linear = L.Linear(None, 100) self.linear2 = L.Linear(None, 2000) def reset_state(self): self.lstm.reset_state() def __call__(self, x1,x2): x1_t = F.transpose_sequence(x1) for x1_i in x1_t: h1 = self.emb(x1_i) h2 = self.lstm(h1) h3=F.concat([h2, F.relu(self.linear(x2))], axis=1) y = self.linear2(h3) return y
時系列の入力x1をLSTMでエンコードして、もう一つの入力x2は専用の全結合層に普通に通す。
次の層でLSTMと全結合層と合わせて(concatの部分)、出力層につないでクラス分類するよって構造。
まぁモデルは別にいいんだけど、悩んだのは複数入力する場合のデータの作り方
たぶん簡単にやるならdatasetを次のように作るはず
dataset = [ (x1_1, label_1), (x1_2, label_2), (x1_3, label_3) ... ]
じゃあこれを複数にしよう!って思って俺がやったのがこれ
dataset = [ ((x1_1, x2_1), label_1), ((x1_2, x2_2), label_2), ((x1_3, x2_3), label_3) ... ]
てっきり(データ、ラベル)の形式しかできないんじゃないかと勝手に思い込んでました.....普通に(データ、データ、ラベル)っていけるのね........
正解の複数型のデータセットは以下
dataset = [ (x1_1, x2_1, label_1), (x1_2, x2_2, label_2), (x1_3, x2_3, label_3) ... ]
モデルの__call__で複数入力受けるように(上の例で言えばx1とx2)していれば、受け付けてくれマス