mirror of https://github.com/llvm/torch-mlir
[onnx] Fix `onnx.RNN` for layout attribute (#3620)
The `layout` attribute was not considered for the `onnx.RNN` operation. Added support for the attribute by transposing the inputs/outputs of the RNN when the layout value is valid.
parent af67f9efb0
commit 2511cf46b4
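For context (shapes taken from the ONNX RNN operator spec, not from this patch): `layout` selects where the batch dimension sits in the op's inputs and outputs. The patch canonicalizes `layout = 1` operands to the default `layout = 0` form, runs the existing lowering, and converts the results back. A standalone, illustrative reference of the spec's shape conventions:

```cpp
// Reference for the ONNX `layout` attribute (shapes per the ONNX RNN spec;
// this program is illustrative and not part of the patch).
#include <cstdio>

int main() {
  const char *rows[][3] = {
      {"X", "[seq_length, batch_size, input_size]",
       "[batch_size, seq_length, input_size]"},
      {"initial_h", "[num_directions, batch_size, hidden_size]",
       "[batch_size, num_directions, hidden_size]"},
      {"Y", "[seq_length, num_directions, batch_size, hidden_size]",
       "[batch_size, seq_length, num_directions, hidden_size]"},
      {"Y_h", "[num_directions, batch_size, hidden_size]",
       "[batch_size, num_directions, hidden_size]"},
  };
  for (const auto &r : rows)
    std::printf("%-10s layout=0 %-52s layout=1 %s\n", r[0], r[1], r[2]);
  return 0;
}
```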
```diff
@@ -151,6 +151,22 @@ RnnLayerOutput rnn_layer(ImplicitLocOpBuilder &b, Value X, Value initial_h,
   output.Y_h = loop.getResult(1);
   return output;
 }
+
+static Value StaticTranspose(ImplicitLocOpBuilder b, Value value, int64_t dim0,
+                             int64_t dim1) {
+  auto valueTy = cast<ValueTensorType>(value.getType());
+
+  SmallVector<int64_t> valueShape(valueTy.getSizes());
+  std::swap(valueShape[dim0], valueShape[dim1]);
+  valueTy = b.getType<ValueTensorType>(valueShape, valueTy.getDtype());
+
+  auto intType = b.getType<IntType>();
+  Value dim0v = b.create<ConstantIntOp>(intType, b.getI64IntegerAttr(dim0));
+  Value dim1v = b.create<ConstantIntOp>(intType, b.getI64IntegerAttr(dim1));
+
+  return b.create<AtenTransposeIntOp>(valueTy, value, dim0v, dim1v);
+}
+
 LogicalResult OnnxRnnExpander(OpBinder binder,
                               ConversionPatternRewriter &rewriter) {
   Location loc = binder.getLoc();
```
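The new helper builds the transposed result type by swapping the two static sizes, then emits Torch-dialect integer constants and an `AtenTransposeIntOp`. The shape bookkeeping in isolation (a minimal sketch, with `std::vector` standing in for the `ValueTensorType` size list):

```cpp
#include <cassert>
#include <cstdint>
#include <utility>
#include <vector>

using Shape = std::vector<int64_t>;

// Same size bookkeeping as StaticTranspose: copy the sizes, swap dim0/dim1.
static Shape transposedSizes(const Shape &sizes, int64_t dim0, int64_t dim1) {
  Shape out(sizes);
  std::swap(out[dim0], out[dim1]);
  return out;
}

int main() {
  // A [batch=4, seq=16, input=8] tensor transposed on dims 0/1 becomes
  // [seq=16, batch=4, input=8] -- the type the helper constructs.
  assert((transposedSizes({4, 16, 8}, 0, 1) == Shape{16, 4, 8}));
  return 0;
}
```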
```diff
@@ -201,9 +217,19 @@ LogicalResult OnnxRnnExpander(OpBinder binder,
     return rewriter.notifyMatchFailure(
         binder.op, "Missing required attribute hidden_size");
 
+  // Other attributes
+  int64_t layout;
+  if (binder.s64IntegerAttr(layout, "layout", 0))
+    return rewriter.notifyMatchFailure(binder.op,
+                                       "Unsupported layout attribute type.");
+
+  if (layout < 0 || layout > 1)
+    return rewriter.notifyMatchFailure(binder.op,
+                                       "Unsupported layout attribute value.");
+
   // Result types
   ValueTensorType yTy, Y_hType;
-  if (binder.tensorResultTypeAtIndex(yTy, 0) ||
+  if (binder.tensorResultTypeAtIndex(yTy, 0) &&
       binder.tensorResultTypeAtIndex(Y_hType, 1)) {
     return rewriter.notifyMatchFailure(binder.op,
                                        "At least one output must be present");
```
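Besides adding the `layout` checks, this hunk flips `||` to `&&` in the result-type check: `tensorResultTypeAtIndex` fails when the output at that index is absent, so the old code rejected any `onnx.RNN` that omitted either output, while the fixed code rejects only when both are missing, matching the "At least one output" message. A boolean model of the predicate change (stand-ins, not the real binder API):

```cpp
#include <cassert>

// Boolean model of binder.tensorResultTypeAtIndex: it "fails" (true here)
// when the output at that index is absent. Illustrative only.
static bool rejectOld(bool missingY, bool missingYh) {
  return missingY || missingYh; // old: bail if either output is missing
}
static bool rejectNew(bool missingY, bool missingYh) {
  return missingY && missingYh; // fixed: bail only if both are missing
}

int main() {
  // Only Y_h requested: previously rejected, now accepted.
  assert(rejectOld(/*missingY=*/true, /*missingYh=*/false));
  assert(!rejectNew(/*missingY=*/true, /*missingYh=*/false));
  // Neither output requested: both versions reject, as the message says.
  assert(rejectOld(true, true) && rejectNew(true, true));
  return 0;
}
```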
```diff
@@ -229,6 +255,12 @@ LogicalResult OnnxRnnExpander(OpBinder binder,
     initial_h = nullptr;
   }
 
+  if (layout == 1) {
+    X = StaticTranspose(b, X, 0, 1);
+    if (initial_h)
+      initial_h = StaticTranspose(b, initial_h, 0, 1);
+  }
+
   // validation
   auto xTy = cast<ValueTensorType>(X.getType());
   auto wTy = cast<ValueTensorType>(W.getType());
```
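With `layout == 1`, the operands arrive batch-first, so they are transposed into the batch-second form the rest of the lowering assumes; `initial_h` is an optional operand, hence the guard. The same control flow on plain data (a sketch; `std::optional` models the nullable `Value`):

```cpp
#include <cassert>
#include <cstdint>
#include <optional>
#include <utility>
#include <vector>

using Shape = std::vector<int64_t>;
static Shape swap01(Shape s) { std::swap(s[0], s[1]); return s; }

int main() {
  // layout == 1 input arrives batch-first: [batch, seq, input].
  Shape X{4, 16, 8};
  std::optional<Shape> initialH; // absent, mirroring initial_h == nullptr
  X = swap01(X);                 // -> [seq, batch, input]
  if (initialH)                  // transposed only when actually present
    *initialH = swap01(*initialH);
  assert((X == Shape{16, 4, 8}) && !initialH);
  return 0;
}
```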
```diff
@@ -238,6 +270,7 @@ LogicalResult OnnxRnnExpander(OpBinder binder,
   auto rShape = rTy.getSizes();
   assert(wShape.size() == 3);
 
+  int64_t seq_len = xShape[0];
   int64_t batch_size = xShape[1];
   int64_t x_input_size = xShape[2];
 
```
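The added `seq_len` is read from `xShape[0]`, which is safe because any `layout == 1` input was already transposed above; it feeds the explicit `Y` result type in the final hunk. A worked shape check under assumed sizes (all values made up):

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

int main() {
  // After layout normalization, xShape is [seq_len, batch, input]
  // regardless of the original layout.
  std::vector<int64_t> xShape{16, 4, 8};
  int64_t seq_len = xShape[0], batch_size = xShape[1];
  int64_t num_directions = 1, hidden_size = 32; // forward RNN; assumed size
  // The explicit Y type built later is
  // [seq_len, num_directions, batch_size, hidden_size]:
  std::vector<int64_t> yShape{seq_len, num_directions, batch_size, hidden_size};
  assert((yShape == std::vector<int64_t>{16, 1, 4, 32}));
  return 0;
}
```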
```diff
@@ -368,7 +401,24 @@ LogicalResult OnnxRnnExpander(OpBinder binder,
   Value Y_h_unsqueezed = b.create<AtenUnsqueezeOp>(Y_h_unsqueezed_type,
                                                    rnnLayerOutput.Y_h, cstZero);
 
-  Value Y_unsqueezed = b.create<AtenUnsqueezeOp>(yTy, rnnLayerOutput.Y, cstOne);
+  auto Y_unsqueezed_type = b.getType<ValueTensorType>(
+      llvm::SmallVector<int64_t>{seq_len, num_directions, batch_size,
+                                 hidden_size},
+      cast<ValueTensorType>(rnnLayerOutput.Y_h.getType()).getDtype());
+  Value Y_unsqueezed =
+      b.create<AtenUnsqueezeOp>(Y_unsqueezed_type, rnnLayerOutput.Y, cstOne);
+
+  if (layout == 1) {
+    Y_h_unsqueezed = StaticTranspose(b, Y_h_unsqueezed, 0, 1);
+    Y_unsqueezed = StaticTranspose(b, Y_unsqueezed, 1, 2);
+    Y_unsqueezed = StaticTranspose(b, Y_unsqueezed, 0, 1);
+  }
+
   if (!yTy)
     Y_unsqueezed = cstNone;
   if (!Y_hType)
     Y_h_unsqueezed = cstNone;
 
   rewriter.replaceOp(binder.op, {Y_unsqueezed, Y_h_unsqueezed});
   return success();
 }
```
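Building `Y_unsqueezed_type` explicitly rather than reusing `yTy` matters because, after the earlier `&&` fix, `yTy` may be null when only `Y_h` is requested; presumably that is why the type is rebuilt from the known dimensions. The two `StaticTranspose` calls then turn the layout-0 result `[seq, dirs, batch, hidden]` into the layout-1 `[batch, seq, dirs, hidden]`, while `Y_h` needs a single swap. A standalone check of those index shuffles (illustrative sizes):

```cpp
#include <cassert>
#include <cstdint>
#include <utility>
#include <vector>

using Shape = std::vector<int64_t>;
static Shape swapDims(Shape s, int a, int b) { std::swap(s[a], s[b]); return s; }

int main() {
  // layout == 0: Y is [seq_len, num_directions, batch_size, hidden_size].
  Shape Y{16, 1, 4, 32};
  Y = swapDims(Y, 1, 2); // -> [seq, batch, dirs, hidden]
  Y = swapDims(Y, 0, 1); // -> [batch, seq, dirs, hidden], the layout == 1 form
  assert((Y == Shape{4, 16, 1, 32}));

  // Y_h only needs its leading two dims swapped:
  // [dirs, batch, hidden] -> [batch, dirs, hidden].
  Shape Yh{1, 4, 32};
  Yh = swapDims(Yh, 0, 1);
  assert((Yh == Shape{4, 1, 32}));
  return 0;
}
```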