Infer the shapes of the outputs of a function based on the shapes of its inputs. This is done as follows:
1. Every NA in the input shapes is replaced with 1, then with 2, then with 3, giving three fully specified versions of the input shapes.
2. Three tensors are generated, one for each of the three shapes from step 1.
3. The function is called on these three tensors and the shapes of its outputs are recorded.
4. The three recorded output shapes are then compared:
   - If the number of dimensions varies, an error is thrown.
   - If the number of dimensions is the same, each dimension is set to NA if its value varies between the three tensors and to the unique value otherwise.
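The steps above can be sketched in base R. This is a hypothetical illustration, not the package's implementation: the name infer_shapes_sketch is made up, plain R arrays stand in for torch tensors, fn is assumed to take the input arrays positionally followed by the entries of param_vals and to return either one array or a list of arrays (one per entry of output_names), and rowwise is not modelled.

```r
# Hypothetical sketch: infer_shapes_sketch() is not the package function.
# Base R arrays stand in for torch tensors; `rowwise` is not modelled.
infer_shapes_sketch <- function(shapes_in, param_vals, output_names, fn,
                                id = "pipeop") {
  # Steps 1-3: replace every NA by 1, 2 and 3 in turn, build zero-filled
  # arrays with those shapes and record the dims of each output.
  runs <- lapply(1:3, function(v) {
    inputs <- lapply(shapes_in, function(shape) {
      shape[is.na(shape)] <- v
      array(0, dim = shape)
    })
    out <- do.call(fn, c(inputs, param_vals))
    if (!is.list(out)) out <- list(out)  # treat a single output uniformly
    lapply(out, dim)
  })

  # Step 4: merge the three recorded shapes of each output.
  result <- lapply(seq_along(output_names), function(i) {
    dims <- lapply(runs, `[[`, i)
    ndims <- vapply(dims, length, integer(1))
    if (length(unique(ndims)) != 1L) {
      stop(sprintf("PipeOp '%s': the number of output dimensions varies.", id))
    }
    # A dimension that differs between the runs is treated as unknown (NA);
    # otherwise the common value is kept.
    vapply(seq_len(ndims[[1]]), function(j) {
      vals <- vapply(dims, function(d) d[[j]], numeric(1))
      if (length(unique(vals)) == 1L) vals[[1]] else NA_real_
    }, numeric(1))
  })
  names(result) <- output_names
  result
}

# Usage: flatten everything except the first (batch) dimension.
flatten <- function(x) array(x, dim = c(dim(x)[1], prod(dim(x)[-1])))
infer_shapes_sketch(list(c(NA, 3, 4)), list(), "output", flatten)
# returns list(output = c(NA, 12))
```

In this example the first dimension takes the values 1, 2 and 3 across the three runs and is therefore reported as NA, while the flattened dimension is 12 in every run and keeps that value.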
Arguments
- shapes_in (list())
  A list of shapes of the input tensors.
- param_vals (list())
  A list of named parameters for the function.
- output_names (character())
  The names of the output tensors.
- fn (function())
  The function to infer the shapes for.
- rowwise (logical(1))
  Whether the function is rowwise.
- id (character(1))
  The id of the PipeOp (for error messages).
Value
(list())
A list of shapes of the output tensors.