mirror of https://github.com/llvm/torch-mlir
[onnx] Lowerings from `onnx.transpose` (#2641)
Lowerings for `transpose` from ONNX to `aten`. Implementation depends on making multiple `aten.transpose` operations swapping pairs of dimensions. As `onnx.transpose` can swap around any dimensions it may require constructing multiple `aten.transpose`.pull/2660/head
parent
030b0140d4
commit
705ea958ae
|
@ -146,6 +146,31 @@ struct OpBinder {
|
|||
return failure();
|
||||
}
|
||||
|
||||
ParseResult s64IntegerArrayAttr(llvm::SmallVector<int64_t> &values,
|
||||
StringRef nameSuffix,
|
||||
ArrayRef<int64_t> defaults) {
|
||||
SmallString<64> name("torch.onnx.");
|
||||
name.append(nameSuffix);
|
||||
auto attr = op->getAttr(name);
|
||||
if (!attr) {
|
||||
values.append(defaults.begin(), defaults.end());
|
||||
return success();
|
||||
}
|
||||
if (auto arrayAttr = dyn_cast<ArrayAttr>(attr)) {
|
||||
for (auto element : arrayAttr) {
|
||||
auto integerAttr = element.dyn_cast<IntegerAttr>();
|
||||
if (!integerAttr)
|
||||
return failure();
|
||||
IntegerType t = cast<IntegerType>(integerAttr.getType());
|
||||
if (!t.isSigned() || t.getWidth() != 64)
|
||||
return failure();
|
||||
values.push_back(integerAttr.getSInt());
|
||||
}
|
||||
return success();
|
||||
}
|
||||
return failure();
|
||||
}
|
||||
|
||||
ParseResult customOpNameStringAttr(std::string &value, StringRef nameSuffix,
|
||||
std::string defaultValue = "") {
|
||||
SmallString<64> name("torch.onnx.");
|
||||
|
|
|
@ -472,4 +472,73 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
|
|||
binder.op, resultType, operand);
|
||||
return success();
|
||||
});
|
||||
|
||||
patterns.onOp(
|
||||
"Transpose", 13,
|
||||
[](OpBinder binder, ConversionPatternRewriter &rewriter) {
|
||||
auto loc = binder.getLoc();
|
||||
Torch::ValueTensorType resultType;
|
||||
Value operand;
|
||||
if (binder.tensorOperand(operand) ||
|
||||
binder.tensorResultType(resultType))
|
||||
return failure();
|
||||
auto operandType = operand.getType().cast<Torch::ValueTensorType>();
|
||||
TensorType tensorType = operandType.toBuiltinTensor();
|
||||
if (!tensorType || !tensorType.hasRank())
|
||||
return failure();
|
||||
|
||||
// Default permutation is to reverse orders:
|
||||
int64_t rank = tensorType.getRank();
|
||||
llvm::SmallVector<int64_t> reverse(rank);
|
||||
for (int64_t i = 0; i < rank; ++i) {
|
||||
reverse[i] = rank - i - 1;
|
||||
}
|
||||
|
||||
llvm::SmallVector<int64_t> permutations;
|
||||
if (failed(binder.s64IntegerArrayAttr(permutations, "perm", reverse)))
|
||||
return rewriter.notifyMatchFailure(binder.op,
|
||||
"Failed to obtain permutations");
|
||||
|
||||
if (static_cast<int64_t>(permutations.size()) != rank)
|
||||
return rewriter.notifyMatchFailure(
|
||||
binder.op, "Permutation length does not match operand rank");
|
||||
|
||||
llvm::SmallVector<int64_t> shape(tensorType.getShape());
|
||||
llvm::SmallVector<int64_t> current(rank);
|
||||
for (int64_t i = 0; i < rank; ++i) {
|
||||
current[i] = i;
|
||||
}
|
||||
|
||||
for (int64_t i = 0; i < rank; ++i) {
|
||||
if (current[i] == permutations[i])
|
||||
continue;
|
||||
|
||||
int64_t target = i + 1;
|
||||
for (; target < rank; ++target) {
|
||||
if (current[target] == permutations[i])
|
||||
break;
|
||||
}
|
||||
|
||||
std::swap(shape[i], shape[target]);
|
||||
std::swap(current[i], current[target]);
|
||||
|
||||
Value dim0 = rewriter.create<Torch::ConstantIntOp>(
|
||||
binder.getLoc(), rewriter.getType<Torch::IntType>(),
|
||||
rewriter.getIntegerAttr(rewriter.getIntegerType(64), i));
|
||||
|
||||
Value dim1 = rewriter.create<Torch::ConstantIntOp>(
|
||||
binder.getLoc(), rewriter.getType<Torch::IntType>(),
|
||||
rewriter.getIntegerAttr(rewriter.getIntegerType(64), target));
|
||||
|
||||
operand = rewriter.create<Torch::AtenTransposeIntOp>(
|
||||
loc,
|
||||
Torch::ValueTensorType::get(tensorType.getContext(), shape,
|
||||
operandType.getOptionalDtype()),
|
||||
operand, dim0, dim1);
|
||||
}
|
||||
|
||||
rewriter.replaceOp(binder.op, operand);
|
||||
|
||||
return success();
|
||||
});
|
||||
}
|
||||
|
|
|
@ -475,6 +475,8 @@ func.func @test_softmax_negative_axis(%arg0: !torch.vtensor<[3,4,5],f32>) -> !to
|
|||
return %0 : !torch.vtensor<[3,4,5],f32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func.func @test_selu
|
||||
func.func @test_selu(%arg0: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],f32> attributes {torch.onnx_meta.opset_version = 6 : si64} {
|
||||
// CHECK-DAG: %[[F1:.+]] = torch.constant.float 1
|
||||
|
@ -484,3 +486,32 @@ func.func @test_selu(%arg0: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,
|
|||
%0 = torch.operator "onnx.Selu"(%arg0) {torch.onnx.alpha = 2.000000e+00 : f32, torch.onnx.gamma = 3.000000e+00 : f32} : (!torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],f32>
|
||||
return %0 : !torch.vtensor<[3,4,5],f32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func.func @test_transpose_default
|
||||
func.func @test_transpose_default(%arg0: !torch.vtensor<[2,3,4],f32>) -> !torch.vtensor<[4,3,2],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64} {
|
||||
// CHECK-DAG: %[[I0:.+]] = torch.constant.int 0
|
||||
// CHECK-DAG: %[[I2:.+]] = torch.constant.int 2
|
||||
// CHECK: %[[TRANSPOSE:.+]] = torch.aten.transpose.int %arg0, %[[I0]], %[[I2]] : !torch.vtensor<[2,3,4],f32>, !torch.int, !torch.int -> !torch.vtensor<[4,3,2],f32>
|
||||
%0 = torch.operator "onnx.Transpose"(%arg0) : (!torch.vtensor<[2,3,4],f32>) -> !torch.vtensor<[4,3,2],f32>
|
||||
|
||||
// CHECK: return %[[TRANSPOSE]]
|
||||
return %0 : !torch.vtensor<[4,3,2],f32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func.func @test_transpose_all_permutations_4
|
||||
func.func @test_transpose_all_permutations_4(%arg0: !torch.vtensor<[2,3,4],f32>) -> !torch.vtensor<[4,2,3],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64} {
|
||||
// CHECK-DAG: %[[I0:.+]] = torch.constant.int 0
|
||||
// CHECK-DAG: %[[I2:.+]] = torch.constant.int 2
|
||||
// CHECK: %[[TRANSPOSE0:.+]] = torch.aten.transpose.int %arg0, %[[I0]], %[[I2]] : !torch.vtensor<[2,3,4],f32>, !torch.int, !torch.int -> !torch.vtensor<[4,3,2],f32>
|
||||
// CHECK-DAG: %[[I1:.+]] = torch.constant.int 1
|
||||
// CHECK-DAG: %[[I2:.+]] = torch.constant.int 2
|
||||
// CHECK: %[[TRANSPOSE1:.+]] = torch.aten.transpose.int %[[TRANSPOSE0]], %[[I1]], %[[I2]] : !torch.vtensor<[4,3,2],f32>, !torch.int, !torch.int -> !torch.vtensor<[4,2,3],f32>
|
||||
%0 = torch.operator "onnx.Transpose"(%arg0) {torch.onnx.perm = [2 : si64, 0 : si64, 1 : si64]} : (!torch.vtensor<[2,3,4],f32>) -> !torch.vtensor<[4,2,3],f32>
|
||||
|
||||
// CHECK: return %[[TRANSPOSE1]]
|
||||
return %0 : !torch.vtensor<[4,2,3],f32>
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue