[onnx] Fix `onnx.RNN` for layout attribute (#3620)

The `layout` attribute was not considered when lowering the `onnx.RNN`
operation. Added support for the attribute, transposing the inputs /
outputs of the RNN when `layout == 1` and rejecting values outside `[0, 1]`.
Rob Suderman 2024-08-13 14:34:25 -07:00 committed by GitHub
parent af67f9efb0
commit 2511cf46b4
1 changed file with 52 additions and 2 deletions
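
For context: per the ONNX specification, `RNN` with `layout == 0` takes `X` as `[seq_length, batch_size, input_size]` and `initial_h` as `[num_directions, batch_size, hidden_size]`; with `layout == 1` the `batch_size` dimension leads in both, and the outputs are batch-major as well. The diff below normalizes `layout == 1` inputs to the seq-major form before running the recurrence. A minimal shape-level sketch of that normalization (plain C++; the sizes and the `swapDims` helper are illustrative stand-ins, not part of the patch):

```cpp
#include <cassert>
#include <cstdint>
#include <utility>
#include <vector>

// Shape-level stand-in for the patch's StaticTranspose helper, which emits an
// AtenTransposeIntOp (moving the data as well as updating the result type).
static std::vector<int64_t> swapDims(std::vector<int64_t> shape, int64_t dim0,
                                     int64_t dim1) {
  std::swap(shape[dim0], shape[dim1]);
  return shape;
}

int main() {
  // layout == 1 input X: [batch_size, seq_length, input_size].
  std::vector<int64_t> x = {8, 16, 32};
  // Normalize to the seq-major form the recurrence expects:
  // [seq_length, batch_size, input_size].
  x = swapDims(x, 0, 1);
  assert((x == std::vector<int64_t>{16, 8, 32}));

  // layout == 1 initial_h: [batch_size, num_directions, hidden_size].
  std::vector<int64_t> h0 = {8, 1, 64};
  // -> [num_directions, batch_size, hidden_size].
  h0 = swapDims(h0, 0, 1);
  assert((h0 == std::vector<int64_t>{1, 8, 64}));
  return 0;
}
```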


@@ -151,6 +151,22 @@ RnnLayerOutput rnn_layer(ImplicitLocOpBuilder &b, Value X, Value initial_h,
   output.Y_h = loop.getResult(1);
   return output;
 }
+
+static Value StaticTranspose(ImplicitLocOpBuilder b, Value value, int64_t dim0,
+                             int64_t dim1) {
+  auto valueTy = cast<ValueTensorType>(value.getType());
+
+  SmallVector<int64_t> valueShape(valueTy.getSizes());
+  std::swap(valueShape[dim0], valueShape[dim1]);
+  valueTy = b.getType<ValueTensorType>(valueShape, valueTy.getDtype());
+
+  auto intType = b.getType<IntType>();
+  Value dim0v = b.create<ConstantIntOp>(intType, b.getI64IntegerAttr(dim0));
+  Value dim1v = b.create<ConstantIntOp>(intType, b.getI64IntegerAttr(dim1));
+
+  return b.create<AtenTransposeIntOp>(valueTy, value, dim0v, dim1v);
+}
+
 LogicalResult OnnxRnnExpander(OpBinder binder,
                               ConversionPatternRewriter &rewriter) {
   Location loc = binder.getLoc();
@@ -201,9 +217,19 @@ LogicalResult OnnxRnnExpander(OpBinder binder,
     return rewriter.notifyMatchFailure(
         binder.op, "Missing required attribute hidden_size");
 
+  // Other attributes
+  int64_t layout;
+  if (binder.s64IntegerAttr(layout, "layout", 0))
+    return rewriter.notifyMatchFailure(binder.op,
+                                       "Unsupported layout attribute type.");
+
+  if (layout < 0 || layout > 1)
+    return rewriter.notifyMatchFailure(binder.op,
+                                       "Unsupported layout attribute value.");
+
   // Result types
   ValueTensorType yTy, Y_hType;
-  if (binder.tensorResultTypeAtIndex(yTy, 0) ||
+  if (binder.tensorResultTypeAtIndex(yTy, 0) &&
       binder.tensorResultTypeAtIndex(Y_hType, 1)) {
     return rewriter.notifyMatchFailure(binder.op,
                                        "At least one output must be present");
@@ -229,6 +255,12 @@ LogicalResult OnnxRnnExpander(OpBinder binder,
     initial_h = nullptr;
   }
 
+  if (layout == 1) {
+    X = StaticTranspose(b, X, 0, 1);
+    if (initial_h)
+      initial_h = StaticTranspose(b, initial_h, 0, 1);
+  }
+
   // validation
   auto xTy = cast<ValueTensorType>(X.getType());
   auto wTy = cast<ValueTensorType>(W.getType());
@@ -238,6 +270,7 @@ LogicalResult OnnxRnnExpander(OpBinder binder,
   auto rShape = rTy.getSizes();
   assert(wShape.size() == 3);
 
+  int64_t seq_len = xShape[0];
   int64_t batch_size = xShape[1];
   int64_t x_input_size = xShape[2];
@@ -368,7 +401,24 @@ LogicalResult OnnxRnnExpander(OpBinder binder,
   Value Y_h_unsqueezed = b.create<AtenUnsqueezeOp>(Y_h_unsqueezed_type,
                                                    rnnLayerOutput.Y_h, cstZero);
 
-  Value Y_unsqueezed = b.create<AtenUnsqueezeOp>(yTy, rnnLayerOutput.Y, cstOne);
+  auto Y_unsqueezed_type = b.getType<ValueTensorType>(
+      llvm::SmallVector<int64_t>{seq_len, num_directions, batch_size,
+                                 hidden_size},
+      cast<ValueTensorType>(rnnLayerOutput.Y_h.getType()).getDtype());
+  Value Y_unsqueezed =
+      b.create<AtenUnsqueezeOp>(Y_unsqueezed_type, rnnLayerOutput.Y, cstOne);
+
+  if (layout == 1) {
+    Y_h_unsqueezed = StaticTranspose(b, Y_h_unsqueezed, 0, 1);
+    Y_unsqueezed = StaticTranspose(b, Y_unsqueezed, 1, 2);
+    Y_unsqueezed = StaticTranspose(b, Y_unsqueezed, 0, 1);
+  }
+
+  if (!yTy)
+    Y_unsqueezed = cstNone;
+  if (!Y_hType)
+    Y_h_unsqueezed = cstNone;
+
   rewriter.replaceOp(binder.op, {Y_unsqueezed, Y_h_unsqueezed});
   return success();
 }
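
On the output side, `layout == 1` requires `Y` to go from seq-major `[seq_length, num_directions, batch_size, hidden_size]` to batch-major `[batch_size, seq_length, num_directions, hidden_size]`. That permutation is a 3-cycle on the first three dimensions, which no single dimension swap can express, hence the two chained `StaticTranspose` calls in the final hunk. A small standalone check of the composition (plain C++; sizes are illustrative, not from the patch):

```cpp
#include <cassert>
#include <cstdint>
#include <utility>
#include <vector>

int main() {
  // Seq-major Y with seq_len=16, num_directions=1, batch_size=8, hidden_size=64.
  std::vector<int64_t> y = {16, 1, 8, 64};
  std::swap(y[1], y[2]); // -> [seq_length, batch_size, num_directions, hidden_size]
  std::swap(y[0], y[1]); // -> [batch_size, seq_length, num_directions, hidden_size]
  assert((y == std::vector<int64_t>{8, 16, 1, 64}));

  // Y_h only needs its first two dimensions swapped.
  std::vector<int64_t> y_h = {1, 8, 64}; // [num_directions, batch_size, hidden_size]
  std::swap(y_h[0], y_h[1]);             // [batch_size, num_directions, hidden_size]
  assert((y_h == std::vector<int64_t>{8, 1, 64}));
  return 0;
}
```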