File size: 965 Bytes
899c526
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
import torch
import numpy as np

def check_broadcastable(x, y):
    """Check that ``x`` and ``y`` have broadcast-compatible leading dimensions.

    Both tensors must have the same number of dimensions. Every dimension
    except the last must either be equal in ``x`` and ``y`` or be 1 in one
    of them. The last dimension is excluded from the comparison, so ``x``
    and ``y`` may differ there.

    Raises:
        AssertionError: if the ranks differ or a leading dimension is
            incompatible. Raised explicitly rather than via ``assert`` so the
            check is not stripped when Python runs with ``-O``.
    """
    if len(x.shape) != len(y.shape):
        raise AssertionError(
            f"rank mismatch: x has shape {tuple(x.shape)}, "
            f"y has shape {tuple(y.shape)}"
        )
    for dim, (n, m) in enumerate(zip(x.shape[:-1], y.shape[:-1])):
        if not (n == m or n == 1 or m == 1):
            raise AssertionError(
                f"dimension {dim} is not broadcastable: {n} vs {m} "
                f"(x shape {tuple(x.shape)}, y shape {tuple(y.shape)})"
            )

def broadcast_inputs(x, y):
    """Broadcast the leading dimensions of ``x`` and ``y`` and flatten them.

    All dimensions except the last are broadcast against each other (a size-1
    dimension takes the other tensor's size). The last dimension of each
    tensor is kept as-is and may differ between ``x`` and ``y``.

    Args:
        x: tensor of shape ``(*xs, xd)``.
        y: tensor of shape ``(*ys, yd)`` with ``len(ys) == len(xs)``, or
            ``None``.

    Returns:
        ``(inputs, out_shape)``. When ``y`` is ``None``, ``inputs`` is
        ``(x1,)`` and ``out_shape`` is ``x.shape[:-1]``. Otherwise ``inputs``
        is ``(x1, y1)`` and ``out_shape`` is the broadcast leading shape as a
        tuple. Each returned tensor is a contiguous 2-D tensor with
        ``prod(out_shape)`` rows. It may share memory with the input when no
        copy was needed.

    Raises:
        AssertionError: (from ``check_broadcastable``) if the leading
            dimensions are not broadcast-compatible.
    """
    if y is None:
        xd = x.shape[-1]
        # reshape (unlike view) also accepts non-contiguous inputs.
        return (x.reshape(-1, xd).contiguous(), ), x.shape[:-1]

    check_broadcastable(x, y)

    xs, xd = x.shape[:-1], x.shape[-1]
    ys, yd = y.shape[:-1], y.shape[-1]
    # Per-dimension broadcast rule: a size-1 dim takes the other's size.
    # max(n, m) would be wrong for 1 vs 0, which broadcasts to 0.
    out_shape = [m if n == 1 else n for (n, m) in zip(xs, ys)]

    # expand() is a zero-copy, stride-0 view. reshape() then materialises it
    # only when some dimension was actually expanded (or the input is not
    # viewable). When xs == ys, this is a plain view for contiguous inputs,
    # and no tiled copy is made.
    x1 = x.expand(*out_shape, xd).reshape(-1, xd).contiguous()
    y1 = y.expand(*out_shape, yd).reshape(-1, yd).contiguous()

    return (x1, y1), tuple(out_shape)