mirror of https://github.com/llvm/torch-mlir
[onnx] Fix `onnx.RNN` for layout attribute (#3620)
The `layout` attribute was not considered for the `onnx.RNN` operation. Added support for the attribute to transpose the inputs / outputs of the RNN when valid.
parent af67f9efb0
commit 2511cf46b4
@@ -151,6 +151,22 @@ RnnLayerOutput rnn_layer(ImplicitLocOpBuilder &b, Value X, Value initial_h,
   output.Y_h = loop.getResult(1);
   return output;
 }
 
+static Value StaticTranspose(ImplicitLocOpBuilder b, Value value, int64_t dim0,
+                             int64_t dim1) {
+  auto valueTy = cast<ValueTensorType>(value.getType());
+
+  SmallVector<int64_t> valueShape(valueTy.getSizes());
+  std::swap(valueShape[dim0], valueShape[dim1]);
+  valueTy = b.getType<ValueTensorType>(valueShape, valueTy.getDtype());
+
+  auto intType = b.getType<IntType>();
+  Value dim0v = b.create<ConstantIntOp>(intType, b.getI64IntegerAttr(dim0));
+  Value dim1v = b.create<ConstantIntOp>(intType, b.getI64IntegerAttr(dim1));
+
+  return b.create<AtenTransposeIntOp>(valueTy, value, dim0v, dim1v);
+}
+
 LogicalResult OnnxRnnExpander(OpBinder binder,
                               ConversionPatternRewriter &rewriter) {
   Location loc = binder.getLoc();
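Note on the new helper: `StaticTranspose` emits an `AtenTransposeIntOp` (`torch.aten.transpose.int`), so the two dimensions must be materialized as constant Torch ints, and the result `ValueTensorType` must carry the statically swapped sizes. A minimal standalone sketch of just that shape bookkeeping (plain C++; the helper name and sizes are hypothetical, not part of the patch):

#include <cstdint>
#include <utility>
#include <vector>

// Hypothetical stand-in for the result-type update StaticTranspose
// performs: swap two entries of a static shape.
std::vector<int64_t> swapDims(std::vector<int64_t> shape, int64_t dim0,
                              int64_t dim1) {
  std::swap(shape[dim0], shape[dim1]);
  return shape; // e.g. swapDims({7, 2, 4}, 0, 1) == {2, 7, 4}
}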
@@ -201,9 +217,19 @@ LogicalResult OnnxRnnExpander(OpBinder binder,
     return rewriter.notifyMatchFailure(
         binder.op, "Missing required attribute hidden_size");
 
+  // Other attributes
+  int64_t layout;
+  if (binder.s64IntegerAttr(layout, "layout", 0))
+    return rewriter.notifyMatchFailure(binder.op,
+                                       "Unsupported layout attribute type.");
+
+  if (layout < 0 || layout > 1)
+    return rewriter.notifyMatchFailure(binder.op,
+                                       "Unsupported layout attribute value.");
+
   // Result types
   ValueTensorType yTy, Y_hType;
-  if (binder.tensorResultTypeAtIndex(yTy, 0) ||
+  if (binder.tensorResultTypeAtIndex(yTy, 0) &&
       binder.tensorResultTypeAtIndex(Y_hType, 1)) {
     return rewriter.notifyMatchFailure(binder.op,
                                        "At least one output must be present");
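Two things happen in this hunk. First, per the ONNX specification, `layout` may only be 0 (seq-major, the default) or 1 (batch-major); anything else is rejected. Second, the result-type check switches from `||` to `&&`: `tensorResultTypeAtIndex` fails when the result at that index is absent, so the expander now bails out only when both `Y` and `Y_h` are missing, which is what the "At least one output must be present" diagnostic says. A small sketch of the layout convention for the input `X` (hypothetical helper, not part of the patch):

#include <array>
#include <cstdint>

// Hypothetical helper: expected X shape for each ONNX RNN layout value.
// layout == 0: [seq_length, batch_size, input_size]  (seq-major, default)
// layout == 1: [batch_size, seq_length, input_size]  (batch-major)
std::array<int64_t, 3> expectedXShape(int64_t layout, int64_t seqLen,
                                      int64_t batch, int64_t inputSize) {
  return layout == 1 ? std::array<int64_t, 3>{batch, seqLen, inputSize}
                     : std::array<int64_t, 3>{seqLen, batch, inputSize};
}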
@@ -229,6 +255,12 @@ LogicalResult OnnxRnnExpander(OpBinder binder,
     initial_h = nullptr;
   }
 
+  if (layout == 1) {
+    X = StaticTranspose(b, X, 0, 1);
+    if (initial_h)
+      initial_h = StaticTranspose(b, initial_h, 0, 1);
+  }
+
   // validation
   auto xTy = cast<ValueTensorType>(X.getType());
   auto wTy = cast<ValueTensorType>(W.getType());
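With `layout == 1` the inputs arrive batch-major, so they are normalized to the seq-major form the rest of the expander assumes before any shape validation runs. A shape trace under assumed sizes (batch = 2, seq = 7, input = 4; values hypothetical):

#include <array>
#include <cstdint>
#include <utility>

int main() {
  std::array<int64_t, 3> x = {2, 7, 4}; // layout == 1: [batch, seq, input]
  std::swap(x[0], x[1]);                // StaticTranspose(b, X, 0, 1)
  // x is now {7, 2, 4}, i.e. [seq, batch, input], the seq-major form
  // the downstream rnn_layer lowering expects.
  return 0;
}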
@@ -238,6 +270,7 @@ LogicalResult OnnxRnnExpander(OpBinder binder,
   auto rShape = rTy.getSizes();
   assert(wShape.size() == 3);
 
+  int64_t seq_len = xShape[0];
   int64_t batch_size = xShape[1];
   int64_t x_input_size = xShape[2];
 
@@ -368,7 +401,24 @@ LogicalResult OnnxRnnExpander(OpBinder binder,
   Value Y_h_unsqueezed = b.create<AtenUnsqueezeOp>(Y_h_unsqueezed_type,
                                                    rnnLayerOutput.Y_h, cstZero);
 
-  Value Y_unsqueezed = b.create<AtenUnsqueezeOp>(yTy, rnnLayerOutput.Y, cstOne);
+  auto Y_unsqueezed_type = b.getType<ValueTensorType>(
+      llvm::SmallVector<int64_t>{seq_len, num_directions, batch_size,
+                                 hidden_size},
+      cast<ValueTensorType>(rnnLayerOutput.Y_h.getType()).getDtype());
+  Value Y_unsqueezed =
+      b.create<AtenUnsqueezeOp>(Y_unsqueezed_type, rnnLayerOutput.Y, cstOne);
+
+  if (layout == 1) {
+    Y_h_unsqueezed = StaticTranspose(b, Y_h_unsqueezed, 0, 1);
+    Y_unsqueezed = StaticTranspose(b, Y_unsqueezed, 1, 2);
+    Y_unsqueezed = StaticTranspose(b, Y_unsqueezed, 0, 1);
+  }
+
+  if (!yTy)
+    Y_unsqueezed = cstNone;
+  if (!Y_hType)
+    Y_h_unsqueezed = cstNone;
+
   rewriter.replaceOp(binder.op, {Y_unsqueezed, Y_h_unsqueezed});
   return success();
 }
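The output side mirrors the input normalization. The loop produces `Y` (after the unsqueeze) as [seq_length, num_directions, batch_size, hidden_size] and `Y_h` as [num_directions, batch_size, hidden_size]; `layout == 1` instead wants [batch_size, seq_length, num_directions, hidden_size] and [batch_size, num_directions, hidden_size]. Moving `num_directions` past both `seq_length` and `batch_size` is a 3-cycle of the first three dims, so `Y` needs two pairwise transposes while `Y_h` needs one. A trace under assumed sizes (seq = 7, dirs = 1, batch = 2, hidden = 16; values hypothetical):

#include <array>
#include <cstdint>
#include <utility>

int main() {
  std::array<int64_t, 4> y = {7, 1, 2, 16}; // Y: [seq, dirs, batch, hidden]
  std::swap(y[1], y[2]); // StaticTranspose(b, Y, 1, 2) -> {7, 2, 1, 16}
  std::swap(y[0], y[1]); // StaticTranspose(b, Y, 0, 1) -> {2, 7, 1, 16}
  // y now matches layout == 1's [batch, seq, dirs, hidden].

  std::array<int64_t, 3> yh = {1, 2, 16}; // Y_h: [dirs, batch, hidden]
  std::swap(yh[0], yh[1]); // StaticTranspose(b, Y_h, 0, 1) -> {2, 1, 16}
  return 0;
}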