機器學習筆記(七)--理解batch_dot函數

在keras中有batch_dot函數,用於計算兩個多維矩陣,官方註釋如下:

def batch_dot(x, y, axes=None):
    """Batchwise dot product.

    `batch_dot` is used to compute dot product of `x` and `y` when
    `x` and `y` are data in batches, i.e. in a shape of
    `(batch_size, :)`.
    `batch_dot` results in a tensor or variable with less dimensions
    than the input. If the number of dimensions is reduced to 1,
    we use `expand_dims` to make sure that ndim is at least 2.

    這個函數是用於計算批次數據‘x’和‘y'的內積,兩個數據的batch_size必須相同。
    函數的輸出張量的維度數量會少於輸入的維度數量和。如果輸出的維度數量減少到1,就會使用
    ’expand_dim‘函數來確保維度數量至少爲2。

    # Arguments
        x: Keras tensor or variable with `ndim >= 2`.  維度數量 >= 2
        y: Keras tensor or variable with `ndim >= 2`.  維度數量 >= 2
        axes: int or tuple(int, int). Target dimensions to be reduced.  
        要減少的目標維度。理論上從0開始(即shape首位),但batch_size是忽略的,故從1開始。若是一
        個整數,則表示兩個輸入的shape的同一位。若是一個tuple或list,則分別指向不同位置。
        注意:無論axes是那種類型,指向的兩個位置上的數值必須一致。

    # Returns
        A tensor with shape equal to the concatenation of `x`'s shape
        (less the dimension that was summed over) and `y`'s shape
        (less the batch dimension and the dimension that was summed over).
        If the final rank is 1, we reshape it to `(batch_size, 1)`.
        
    """

下面對例子進行說明。

 >>> x_batch = K.ones(shape=(32, 20, 1))
 >>> y_batch = K.ones(shape=(32, 30, 20))
 >>> xy_batch_dot = K.batch_dot(x_batch, y_batch, axes=(1, 2))
 >>> K.int_shape(xy_batch_dot)
 (32, 1, 30)

首先我認爲,該函數進行還是普通的矩陣乘法,但是兩個輸入矩陣的格式明顯不符合,所以進行了類似reshape的操作,具體就是將左邊的矩陣的目標位移動到末尾,將右邊矩陣的目標位移動到首位。如上例,去掉batch_size,原規格爲(20,1)和(30,20),因爲axes=(1,2),故變爲(1,20)和(20,30)的矩陣乘法,結果爲(1,30),加上batch_size即爲(32,1,30)。

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章