def rand_arr(a, b, *args, seed=0):
    """Create a uniform random array with values in [a, b) and shape ``args``.

    Args:
        a: Lower bound of the values (inclusive).
        b: Upper bound of the values (exclusive, up to float rounding).
        *args: Output dimensions, forwarded unchanged to ``np.random.rand``.
        seed: Seed applied to NumPy's global RNG before sampling. The default
            of 0 reproduces the original hard-coded behavior: every call with
            the same shape returns the same array, and the global RNG is
            reseeded as a side effect. Pass ``None`` to draw from the global
            RNG's current state without reseeding.

    Returns:
        A float ndarray of shape ``args``.
    """
    # Reseeding is optional so callers can get distinct arrays per call
    # (e.g. independent weight matrices of the same shape).
    if seed is not None:
        np.random.seed(seed)
    # Map samples from [0, 1) linearly onto [a, b).
    return np.random.rand(*args) * (b - a) + a
class LstmNode:
    """A single LSTM node binding a parameter object to its own state object."""

    def __init__(self, lstm_param, lstm_state):
        # Keep references (not copies) to the parameters and the activations.
        self.param = lstm_param
        self.state = lstm_state
        # Slot for the non-recurrent input concatenated with the recurrent
        # input; empty until data is fed in.
        self.xc = None
LSTM只定义了一种节点(LstmNode)，网络为每个时间步创建一个这样的节点；每个节点包含训练参数lstm_param(所有节点共享同一份)和状态参数lstm_state(每个节点各自独立一份)。
训练参数是模型的核心，网络的学习过程就是不断调整训练参数的过程；这些参数应由训练过程自动更新，最好不要人为干预。
状态参数则由每个时间步的输入与训练参数计算得到，会随输入的不同而变化。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18
def bottom_data_is(self, x, s_prev=None, h_prev=None):
    """Store the previous node's cell state and hidden output on this node.

    Args:
        x: Input for this time step (not used in this excerpt).
        s_prev: Cell state of the previous node, or ``None`` when this is the
            first LSTM node in the network; zeros shaped like
            ``self.state.s`` are then used.
        h_prev: Hidden output of the previous node, or ``None`` when this is
            the first node; zeros shaped like ``self.state.h`` are then used.
    """
    # Must be ``is None``: ``ndarray == None`` compares element-wise, and
    # using the resulting array in ``if`` raises ValueError (ambiguous truth
    # value) as soon as a real previous state is passed in.
    if s_prev is None:
        s_prev = np.zeros_like(self.state.s)
    if h_prev is None:
        h_prev = np.zeros_like(self.state.h)
    # Save data for use in backprop.
    self.s_prev = s_prev
    self.h_prev = h_prev
def y_list_is(self, y_list, loss_layer):
    """Update diffs by setting the target sequence with the given loss layer.

    Will *NOT* update parameters. To update parameters, call
    ``self.lstm_param.apply_diff()``.

    Args:
        y_list: Targets, one per input in ``self.x_list`` (lengths must match).
        loss_layer: Object providing ``loss(pred, label)`` and
            ``bottom_diff(pred, label)``; both are applied to each node's
            ``state.h`` and the corresponding target.

    Returns:
        The loss summed over all time steps.
    """
    assert len(y_list) == len(self.x_list)
    idx = len(self.x_list) - 1
    # The last time step (the first node visited going backwards) only gets
    # diffs from its label ...
    loss = loss_layer.loss(self.lstm_node_list[idx].state.h, y_list[idx])
    diff_h = loss_layer.bottom_diff(self.lstm_node_list[idx].state.h, y_list[idx])
    # ... here s is not affecting loss due to h(t+1), hence we set it to zero.
    diff_s = np.zeros(self.lstm_param.mem_cell_ct)
    self.lstm_node_list[idx].top_diff_is(diff_h, diff_s)
    idx -= 1

    # Earlier nodes also get diffs from the next node, so those are added to
    # diff_h; the error is also propagated along the constant error carousel
    # using diff_s.
    while idx >= 0:
        loss += loss_layer.loss(self.lstm_node_list[idx].state.h, y_list[idx])
        diff_h = loss_layer.bottom_diff(self.lstm_node_list[idx].state.h, y_list[idx])
        diff_h += self.lstm_node_list[idx + 1].state.bottom_diff_h
        diff_s = self.lstm_node_list[idx + 1].state.bottom_diff_s
        self.lstm_node_list[idx].top_diff_is(diff_h, diff_s)
        idx -= 1

    # Return the accumulated loss instead of silently discarding it.
    return loss
def x_list_add(self, x):
    """Append input ``x`` to the sequence and feed it to its LSTM node.

    A new node (with freshly allocated state) is created only when the
    sequence has grown past the nodes already available. The node for the
    newest step receives the previous node's ``state.s`` / ``state.h`` as
    its recurrent inputs, except for step 0, which has none.
    """
    self.x_list.append(x)
    if len(self.lstm_node_list) < len(self.x_list):
        fresh_state = LstmState(self.lstm_param.mem_cell_ct, self.lstm_param.x_dim)
        self.lstm_node_list.append(LstmNode(self.lstm_param, fresh_state))

    # Position of the input just appended.
    step = len(self.x_list) - 1
    if step == 0:
        # No recurrent inputs exist for the very first step.
        self.lstm_node_list[step].bottom_data_is(x)
        return

    prev = self.lstm_node_list[step - 1].state
    self.lstm_node_list[step].bottom_data_is(x, prev.s, prev.h)