mirror of https://github.com/llvm/torch-mlir
[onnx] Lowerings from `onnx.transpose` (#2641)
Lowerings for `transpose` from ONNX to `aten`. Implementation depends on making multiple `aten.transpose` operations swapping pairs of dimensions. As `onnx.transpose` can swap around any dimensions it may require constructing multiple `aten.transpose`.pull/2660/head
parent
030b0140d4
commit
705ea958ae
|
@ -146,6 +146,31 @@ struct OpBinder {
|
||||||
return failure();
|
return failure();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
ParseResult s64IntegerArrayAttr(llvm::SmallVector<int64_t> &values,
|
||||||
|
StringRef nameSuffix,
|
||||||
|
ArrayRef<int64_t> defaults) {
|
||||||
|
SmallString<64> name("torch.onnx.");
|
||||||
|
name.append(nameSuffix);
|
||||||
|
auto attr = op->getAttr(name);
|
||||||
|
if (!attr) {
|
||||||
|
values.append(defaults.begin(), defaults.end());
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
if (auto arrayAttr = dyn_cast<ArrayAttr>(attr)) {
|
||||||
|
for (auto element : arrayAttr) {
|
||||||
|
auto integerAttr = element.dyn_cast<IntegerAttr>();
|
||||||
|
if (!integerAttr)
|
||||||
|
return failure();
|
||||||
|
IntegerType t = cast<IntegerType>(integerAttr.getType());
|
||||||
|
if (!t.isSigned() || t.getWidth() != 64)
|
||||||
|
return failure();
|
||||||
|
values.push_back(integerAttr.getSInt());
|
||||||
|
}
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
return failure();
|
||||||
|
}
|
||||||
|
|
||||||
ParseResult customOpNameStringAttr(std::string &value, StringRef nameSuffix,
|
ParseResult customOpNameStringAttr(std::string &value, StringRef nameSuffix,
|
||||||
std::string defaultValue = "") {
|
std::string defaultValue = "") {
|
||||||
SmallString<64> name("torch.onnx.");
|
SmallString<64> name("torch.onnx.");
|
||||||
|
|
|
@ -472,4 +472,73 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
|
||||||
binder.op, resultType, operand);
|
binder.op, resultType, operand);
|
||||||
return success();
|
return success();
|
||||||
});
|
});
|
||||||
|
|
||||||
|
patterns.onOp(
|
||||||
|
"Transpose", 13,
|
||||||
|
[](OpBinder binder, ConversionPatternRewriter &rewriter) {
|
||||||
|
auto loc = binder.getLoc();
|
||||||
|
Torch::ValueTensorType resultType;
|
||||||
|
Value operand;
|
||||||
|
if (binder.tensorOperand(operand) ||
|
||||||
|
binder.tensorResultType(resultType))
|
||||||
|
return failure();
|
||||||
|
auto operandType = operand.getType().cast<Torch::ValueTensorType>();
|
||||||
|
TensorType tensorType = operandType.toBuiltinTensor();
|
||||||
|
if (!tensorType || !tensorType.hasRank())
|
||||||
|
return failure();
|
||||||
|
|
||||||
|
// Default permutation is to reverse orders:
|
||||||
|
int64_t rank = tensorType.getRank();
|
||||||
|
llvm::SmallVector<int64_t> reverse(rank);
|
||||||
|
for (int64_t i = 0; i < rank; ++i) {
|
||||||
|
reverse[i] = rank - i - 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
llvm::SmallVector<int64_t> permutations;
|
||||||
|
if (failed(binder.s64IntegerArrayAttr(permutations, "perm", reverse)))
|
||||||
|
return rewriter.notifyMatchFailure(binder.op,
|
||||||
|
"Failed to obtain permutations");
|
||||||
|
|
||||||
|
if (static_cast<int64_t>(permutations.size()) != rank)
|
||||||
|
return rewriter.notifyMatchFailure(
|
||||||
|
binder.op, "Permutation length does not match operand rank");
|
||||||
|
|
||||||
|
llvm::SmallVector<int64_t> shape(tensorType.getShape());
|
||||||
|
llvm::SmallVector<int64_t> current(rank);
|
||||||
|
for (int64_t i = 0; i < rank; ++i) {
|
||||||
|
current[i] = i;
|
||||||
|
}
|
||||||
|
|
||||||
|
for (int64_t i = 0; i < rank; ++i) {
|
||||||
|
if (current[i] == permutations[i])
|
||||||
|
continue;
|
||||||
|
|
||||||
|
int64_t target = i + 1;
|
||||||
|
for (; target < rank; ++target) {
|
||||||
|
if (current[target] == permutations[i])
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
|
||||||
|
std::swap(shape[i], shape[target]);
|
||||||
|
std::swap(current[i], current[target]);
|
||||||
|
|
||||||
|
Value dim0 = rewriter.create<Torch::ConstantIntOp>(
|
||||||
|
binder.getLoc(), rewriter.getType<Torch::IntType>(),
|
||||||
|
rewriter.getIntegerAttr(rewriter.getIntegerType(64), i));
|
||||||
|
|
||||||
|
Value dim1 = rewriter.create<Torch::ConstantIntOp>(
|
||||||
|
binder.getLoc(), rewriter.getType<Torch::IntType>(),
|
||||||
|
rewriter.getIntegerAttr(rewriter.getIntegerType(64), target));
|
||||||
|
|
||||||
|
operand = rewriter.create<Torch::AtenTransposeIntOp>(
|
||||||
|
loc,
|
||||||
|
Torch::ValueTensorType::get(tensorType.getContext(), shape,
|
||||||
|
operandType.getOptionalDtype()),
|
||||||
|
operand, dim0, dim1);
|
||||||
|
}
|
||||||
|
|
||||||
|
rewriter.replaceOp(binder.op, operand);
|
||||||
|
|
||||||
|
return success();
|
||||||
|
});
|
||||||
}
|
}
|
||||||
|
|
|
@ -475,6 +475,8 @@ func.func @test_softmax_negative_axis(%arg0: !torch.vtensor<[3,4,5],f32>) -> !to
|
||||||
return %0 : !torch.vtensor<[3,4,5],f32>
|
return %0 : !torch.vtensor<[3,4,5],f32>
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
// CHECK-LABEL: func.func @test_selu
|
// CHECK-LABEL: func.func @test_selu
|
||||||
func.func @test_selu(%arg0: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],f32> attributes {torch.onnx_meta.opset_version = 6 : si64} {
|
func.func @test_selu(%arg0: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],f32> attributes {torch.onnx_meta.opset_version = 6 : si64} {
|
||||||
// CHECK-DAG: %[[F1:.+]] = torch.constant.float 1
|
// CHECK-DAG: %[[F1:.+]] = torch.constant.float 1
|
||||||
|
@ -484,3 +486,32 @@ func.func @test_selu(%arg0: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,
|
||||||
%0 = torch.operator "onnx.Selu"(%arg0) {torch.onnx.alpha = 2.000000e+00 : f32, torch.onnx.gamma = 3.000000e+00 : f32} : (!torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],f32>
|
%0 = torch.operator "onnx.Selu"(%arg0) {torch.onnx.alpha = 2.000000e+00 : f32, torch.onnx.gamma = 3.000000e+00 : f32} : (!torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],f32>
|
||||||
return %0 : !torch.vtensor<[3,4,5],f32>
|
return %0 : !torch.vtensor<[3,4,5],f32>
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
// CHECK-LABEL: func.func @test_transpose_default
|
||||||
|
func.func @test_transpose_default(%arg0: !torch.vtensor<[2,3,4],f32>) -> !torch.vtensor<[4,3,2],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64} {
|
||||||
|
// CHECK-DAG: %[[I0:.+]] = torch.constant.int 0
|
||||||
|
// CHECK-DAG: %[[I2:.+]] = torch.constant.int 2
|
||||||
|
// CHECK: %[[TRANSPOSE:.+]] = torch.aten.transpose.int %arg0, %[[I0]], %[[I2]] : !torch.vtensor<[2,3,4],f32>, !torch.int, !torch.int -> !torch.vtensor<[4,3,2],f32>
|
||||||
|
%0 = torch.operator "onnx.Transpose"(%arg0) : (!torch.vtensor<[2,3,4],f32>) -> !torch.vtensor<[4,3,2],f32>
|
||||||
|
|
||||||
|
// CHECK: return %[[TRANSPOSE]]
|
||||||
|
return %0 : !torch.vtensor<[4,3,2],f32>
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
// CHECK-LABEL: func.func @test_transpose_all_permutations_4
|
||||||
|
func.func @test_transpose_all_permutations_4(%arg0: !torch.vtensor<[2,3,4],f32>) -> !torch.vtensor<[4,2,3],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64} {
|
||||||
|
// CHECK-DAG: %[[I0:.+]] = torch.constant.int 0
|
||||||
|
// CHECK-DAG: %[[I2:.+]] = torch.constant.int 2
|
||||||
|
// CHECK: %[[TRANSPOSE0:.+]] = torch.aten.transpose.int %arg0, %[[I0]], %[[I2]] : !torch.vtensor<[2,3,4],f32>, !torch.int, !torch.int -> !torch.vtensor<[4,3,2],f32>
|
||||||
|
// CHECK-DAG: %[[I1:.+]] = torch.constant.int 1
|
||||||
|
// CHECK-DAG: %[[I2:.+]] = torch.constant.int 2
|
||||||
|
// CHECK: %[[TRANSPOSE1:.+]] = torch.aten.transpose.int %[[TRANSPOSE0]], %[[I1]], %[[I2]] : !torch.vtensor<[4,3,2],f32>, !torch.int, !torch.int -> !torch.vtensor<[4,2,3],f32>
|
||||||
|
%0 = torch.operator "onnx.Transpose"(%arg0) {torch.onnx.perm = [2 : si64, 0 : si64, 1 : si64]} : (!torch.vtensor<[2,3,4],f32>) -> !torch.vtensor<[4,2,3],f32>
|
||||||
|
|
||||||
|
// CHECK: return %[[TRANSPOSE1]]
|
||||||
|
return %0 : !torch.vtensor<[4,2,3],f32>
|
||||||
|
}
|
||||||
|
|
Loading…
Reference in New Issue