[onnx] Lowerings from `onnx.transpose` (#2641)

Lowerings for `transpose` from ONNX to `aten`. Implementation depends on
making multiple `aten.transpose` operations swapping pairs of dimensions.
As `onnx.transpose` can swap around any dimensions it may require
constructing multiple `aten.transpose`.
pull/2660/head
Rob Suderman 2023-12-15 15:30:05 -08:00 committed by GitHub
parent 030b0140d4
commit 705ea958ae
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 125 additions and 0 deletions

View File

@ -146,6 +146,31 @@ struct OpBinder {
return failure();
}
ParseResult s64IntegerArrayAttr(llvm::SmallVector<int64_t> &values,
StringRef nameSuffix,
ArrayRef<int64_t> defaults) {
SmallString<64> name("torch.onnx.");
name.append(nameSuffix);
auto attr = op->getAttr(name);
if (!attr) {
values.append(defaults.begin(), defaults.end());
return success();
}
if (auto arrayAttr = dyn_cast<ArrayAttr>(attr)) {
for (auto element : arrayAttr) {
auto integerAttr = element.dyn_cast<IntegerAttr>();
if (!integerAttr)
return failure();
IntegerType t = cast<IntegerType>(integerAttr.getType());
if (!t.isSigned() || t.getWidth() != 64)
return failure();
values.push_back(integerAttr.getSInt());
}
return success();
}
return failure();
}
ParseResult customOpNameStringAttr(std::string &value, StringRef nameSuffix,
std::string defaultValue = "") {
SmallString<64> name("torch.onnx.");

View File

@ -472,4 +472,73 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
binder.op, resultType, operand);
return success();
});
patterns.onOp(
"Transpose", 13,
[](OpBinder binder, ConversionPatternRewriter &rewriter) {
auto loc = binder.getLoc();
Torch::ValueTensorType resultType;
Value operand;
if (binder.tensorOperand(operand) ||
binder.tensorResultType(resultType))
return failure();
auto operandType = operand.getType().cast<Torch::ValueTensorType>();
TensorType tensorType = operandType.toBuiltinTensor();
if (!tensorType || !tensorType.hasRank())
return failure();
// Default permutation is to reverse orders:
int64_t rank = tensorType.getRank();
llvm::SmallVector<int64_t> reverse(rank);
for (int64_t i = 0; i < rank; ++i) {
reverse[i] = rank - i - 1;
}
llvm::SmallVector<int64_t> permutations;
if (failed(binder.s64IntegerArrayAttr(permutations, "perm", reverse)))
return rewriter.notifyMatchFailure(binder.op,
"Failed to obtain permutations");
if (static_cast<int64_t>(permutations.size()) != rank)
return rewriter.notifyMatchFailure(
binder.op, "Permutation length does not match operand rank");
llvm::SmallVector<int64_t> shape(tensorType.getShape());
llvm::SmallVector<int64_t> current(rank);
for (int64_t i = 0; i < rank; ++i) {
current[i] = i;
}
for (int64_t i = 0; i < rank; ++i) {
if (current[i] == permutations[i])
continue;
int64_t target = i + 1;
for (; target < rank; ++target) {
if (current[target] == permutations[i])
break;
}
std::swap(shape[i], shape[target]);
std::swap(current[i], current[target]);
Value dim0 = rewriter.create<Torch::ConstantIntOp>(
binder.getLoc(), rewriter.getType<Torch::IntType>(),
rewriter.getIntegerAttr(rewriter.getIntegerType(64), i));
Value dim1 = rewriter.create<Torch::ConstantIntOp>(
binder.getLoc(), rewriter.getType<Torch::IntType>(),
rewriter.getIntegerAttr(rewriter.getIntegerType(64), target));
operand = rewriter.create<Torch::AtenTransposeIntOp>(
loc,
Torch::ValueTensorType::get(tensorType.getContext(), shape,
operandType.getOptionalDtype()),
operand, dim0, dim1);
}
rewriter.replaceOp(binder.op, operand);
return success();
});
}

View File

@ -475,6 +475,8 @@ func.func @test_softmax_negative_axis(%arg0: !torch.vtensor<[3,4,5],f32>) -> !to
return %0 : !torch.vtensor<[3,4,5],f32>
}
// -----
// CHECK-LABEL: func.func @test_selu
func.func @test_selu(%arg0: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],f32> attributes {torch.onnx_meta.opset_version = 6 : si64} {
// CHECK-DAG: %[[F1:.+]] = torch.constant.float 1
@ -484,3 +486,32 @@ func.func @test_selu(%arg0: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,
%0 = torch.operator "onnx.Selu"(%arg0) {torch.onnx.alpha = 2.000000e+00 : f32, torch.onnx.gamma = 3.000000e+00 : f32} : (!torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],f32>
return %0 : !torch.vtensor<[3,4,5],f32>
}
// -----
// CHECK-LABEL: func.func @test_transpose_default
func.func @test_transpose_default(%arg0: !torch.vtensor<[2,3,4],f32>) -> !torch.vtensor<[4,3,2],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64} {
// CHECK-DAG: %[[I0:.+]] = torch.constant.int 0
// CHECK-DAG: %[[I2:.+]] = torch.constant.int 2
// CHECK: %[[TRANSPOSE:.+]] = torch.aten.transpose.int %arg0, %[[I0]], %[[I2]] : !torch.vtensor<[2,3,4],f32>, !torch.int, !torch.int -> !torch.vtensor<[4,3,2],f32>
%0 = torch.operator "onnx.Transpose"(%arg0) : (!torch.vtensor<[2,3,4],f32>) -> !torch.vtensor<[4,3,2],f32>
// CHECK: return %[[TRANSPOSE]]
return %0 : !torch.vtensor<[4,3,2],f32>
}
// -----
// CHECK-LABEL: func.func @test_transpose_all_permutations_4
func.func @test_transpose_all_permutations_4(%arg0: !torch.vtensor<[2,3,4],f32>) -> !torch.vtensor<[4,2,3],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64} {
// CHECK-DAG: %[[I0:.+]] = torch.constant.int 0
// CHECK-DAG: %[[I2:.+]] = torch.constant.int 2
// CHECK: %[[TRANSPOSE0:.+]] = torch.aten.transpose.int %arg0, %[[I0]], %[[I2]] : !torch.vtensor<[2,3,4],f32>, !torch.int, !torch.int -> !torch.vtensor<[4,3,2],f32>
// CHECK-DAG: %[[I1:.+]] = torch.constant.int 1
// CHECK-DAG: %[[I2:.+]] = torch.constant.int 2
// CHECK: %[[TRANSPOSE1:.+]] = torch.aten.transpose.int %[[TRANSPOSE0]], %[[I1]], %[[I2]] : !torch.vtensor<[4,3,2],f32>, !torch.int, !torch.int -> !torch.vtensor<[4,2,3],f32>
%0 = torch.operator "onnx.Transpose"(%arg0) {torch.onnx.perm = [2 : si64, 0 : si64, 1 : si64]} : (!torch.vtensor<[2,3,4],f32>) -> !torch.vtensor<[4,2,3],f32>
// CHECK: return %[[TRANSPOSE1]]
return %0 : !torch.vtensor<[4,2,3],f32>
}