mirror of https://github.com/llvm/torch-mlir
[uncategorized_lowerings] Add lowering for torch.aten.round.decimals
Implement missing lowering for the op in a similar fashion as done by torch inductor. Also fix data movement and reduce op variants patterns to correctly handle explicitly declared legal ops.

Signed-off-by: Prathamesh Tagore <prathamesh+1@polymagelabs.com>

Branch: pull/3811/head
Parent: 140cad5659
Commit: 0f5a2dc844
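For reference, the torch inductor decomposition that the new lowering mirrors reduces the op to a scale, round-half-to-even, unscale sequence. Paraphrasing (not a quote from the linked decomposition), with d the `decimals` operand:

    round_decimals(x, d) = roundeven(x * 10^d) / 10^d

For example, round_decimals(0.125, 2) = roundeven(12.5) / 100 = 12 / 100 = 0.12, since ties round to the nearest even integer.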
@@ -109,7 +109,7 @@ std::unique_ptr<OperationPass<ModuleOp>> createAdjustCallingConventionsPass();
 std::unique_ptr<OperationPass<ModuleOp>> createInlineGlobalSlotsPass();
 
 std::unique_ptr<OperationPass<func::FuncOp>>
-createReduceOpVariantsPass(StringRef extraLibrary);
+createReduceOpVariantsPass(StringRef extraLibrary, ArrayRef<std::string> = {});
 
 std::unique_ptr<OperationPass<func::FuncOp>> createMaximizeValueSemanticsPass();
 
@@ -148,6 +148,9 @@ def ReduceOpVariants : Pass<"torch-reduce-op-variants", "func::FuncOp"> {
   let options = [
     Option<"extraLibrary", "extra-library", "std::string", /*default=*/"",
            "MLIR module for verifying custom op value semantics">,
+    ListOption<"legalOps", "legal-ops", "std::string",
+               "Comma separated list of operation names that should be considered legal",
+               "llvm::cl::ZeroOrMore">
   ];
   let description = [{
     Replaces ops with other ops to reduce the number of variants that
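A minimal sketch of driving the new `legal-ops` option programmatically (illustrative only, not part of the patch; it assumes the torch-mlir pass headers and an already constructed pass manager):

    #include "mlir/Dialect/Func/IR/FuncOps.h"
    #include "mlir/Pass/PassManager.h"
    #include "torch-mlir/Dialect/Torch/Transforms/Passes.h"

    #include <string>
    #include <vector>

    // Keep "torch.aten.round.decimals" legal so ReduceOpVariants leaves the
    // torch.operator op untouched for a backend that consumes it directly.
    void addReduceOpVariantsKeepingRoundDecimals(mlir::OpPassManager &pm) {
      std::vector<std::string> legalOps = {"torch.aten.round.decimals"};
      pm.addNestedPass<mlir::func::FuncOp>(
          mlir::torch::Torch::createReduceOpVariantsPass(/*extraLibrary=*/"",
                                                         legalOps));
      // Call sites that pass only `extraLibrary` keep compiling because the
      // new ArrayRef<std::string> parameter defaults to {}.
    }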
@@ -2874,7 +2874,10 @@ void mlir::torch::torch_to_linalg::populateDataMovementPatternsAndLegality(
   patterns.add<ConvertAtenDiagEmbedOp>(typeConverter, context);
   // Rewrite all special sparse conversions hidden as operators.
   target.addDynamicallyLegalOp<OperatorOp>([&](Torch::OperatorOp op) {
-    return !ConvertSparseOperatorOp::isSparsePrimitive(op.getNameAttr());
+    // Note: Legality behaviour of torch.operator ops that are not sparse
+    // primitives should be conserved and not modified by this block.
+    return !ConvertSparseOperatorOp::isSparsePrimitive(op.getNameAttr()) &&
+           typeConverter.isLegal(op);
   });
   patterns.add<ConvertSparseOperatorOp>(typeConverter, context);
 }
@@ -1557,6 +1557,43 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
         threshold);
   }
 
+  if (auto operatorOp = dyn_cast<OperatorOp>(op)) {
+    // We do not yet implement lowering for other variants of the op.
+    if (operatorOp.getNameAttr().str() != "torch.aten.round.decimals")
+      return nullptr;
+
+    // Lower the op in a similar fashion as described here:
+    // https://github.com/pytorch/pytorch/blob/de4c2a3b4e89d96334dc678d1c3f2ae51a6630a0/torch/_inductor/decomposition.py#L223.
+    // Note that `aten.round` is converted to `math.roundeven`; we do this
+    // implicitly here because `aten.round` cannot operate on a single input
+    // tensor element, which is what we get as the payload argument.
+
+    Location loc = op->getLoc();
+    Type i64Type = b.getI64Type();
+
+    auto torchIntOp = dyn_cast<ConstantIntOp>(
+        operatorOp.getOperands().back().getDefiningOp());
+    if (!torchIntOp)
+      return nullptr;
+    int64_t numDecimalsArg = torchIntOp.getValue();
+
+    Value inputTensorElem = payloadArgs[0];
+    Type inputTensorElemType = inputTensorElem.getType();
+
+    auto numDecimals = b.create<arith::ConstantOp>(
+        loc, i64Type, IntegerAttr::get(i64Type, numDecimalsArg));
+    auto const10 = b.create<arith::ConstantOp>(
+        loc, inputTensorElemType, FloatAttr::get(inputTensorElemType, 10));
+    auto tenPowDecimals = b.create<math::FPowIOp>(loc, const10, numDecimals);
+
+    auto mulTenPowDecimalsinputTensorElem =
+        b.create<arith::MulFOp>(loc, inputTensorElem, tenPowDecimals);
+    auto roundOp =
+        b.create<math::RoundEvenOp>(loc, mulTenPowDecimalsinputTensorElem);
+    auto res = b.create<arith::DivFOp>(loc, roundOp, tenPowDecimals);
+    return res;
+  }
+
   op->emitError("unimplemented lowering in "
                 "createLinalgPayloadCalculationForElementwiseOp");
   return nullptr;
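As a sanity check on the payload computation above, here is a standalone scalar sketch in plain C++ (illustrative only, not part of the patch); `std::nearbyint` rounds half-to-even under the default FE_TONEAREST mode, matching `math.roundeven`:

    #include <cmath>
    #include <cstdio>

    // Mirrors the payload op sequence: 10^decimals (math.fpowi), scale
    // (arith.mulf), round half-to-even (math.roundeven), unscale (arith.divf).
    float roundDecimals(float x, long decimals) {
      float tenPowDecimals = std::pow(10.0f, static_cast<float>(decimals));
      return std::nearbyint(x * tenPowDecimals) / tenPowDecimals;
    }

    int main() {
      std::printf("%g\n", roundDecimals(3.14159f, 3)); // 3.142
      std::printf("%g\n", roundDecimals(2.5f, 0));     // 2 (tie rounds to even)
    }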
@@ -1616,9 +1653,14 @@ public:
              AtenFillScalarOp, AtenFillTensorOp, AtenAtanOp, AtenAcosOp,
              AtenAtanhOp, AtenAcoshOp, AtenAsinOp, AtenAsinhOp, AtenRealOp,
              AtenImagOp, AtenDequantizeSelfOp, AtenDequantizeTensorOp,
-             AtenQuantizePerTensorOp, AtenIscloseOp>(op))
+             AtenQuantizePerTensorOp, AtenIscloseOp, OperatorOp>(op))
       return rewriter.notifyMatchFailure(op, "not a supported elementwise op");
 
+    if (auto operatorOp = dyn_cast<OperatorOp>(op))
+      if (operatorOp.getNameAttr().str() != "torch.aten.round.decimals")
+        return rewriter.notifyMatchFailure(op,
+                                           "not a supported elementwise op");
+
     if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
       return failure();
 
@@ -3375,7 +3417,7 @@ void mlir::torch::torch_to_linalg::populateUncategorizedPatternsAndLegality(
       AtenTrilOp, AtenRemainderScalarOp, AtenRemainderTensorOp,
       AtenBitwiseNotOp, AtenRoundOp, AtenFillScalarOp, AtenFillTensorOp,
       AtenRealOp, AtenImagOp, AtenDequantizeSelfOp, AtenDequantizeTensorOp,
-      AtenQuantizePerTensorOp, AtenIscloseOp>();
+      AtenQuantizePerTensorOp, AtenIscloseOp, OperatorOp>();
   patterns.add<ConvertElementwiseOp>(typeConverter, context);
   target.addIllegalOp<AtenNllLossForwardOp>();
   patterns.add<ConvertAtenDetachOp>(typeConverter, context);
@@ -70,8 +70,8 @@ void mlir::torch::Torch::createTorchScriptModuleToTorchBackendPipeline(
 
 void mlir::torch::Torch::createTorchDynamoExportToTorchBackendPipeline(
     OpPassManager &pm, const TorchLoweringPipelineOptions &options) {
-  pm.addNestedPass<func::FuncOp>(
-      createReduceOpVariantsPass(options.extraLibrary));
+  pm.addNestedPass<func::FuncOp>(createReduceOpVariantsPass(
+      options.extraLibrary, options.backendLegalOps));
   pm.addNestedPass<func::FuncOp>(createCanonicalizerPass());
   if (options.decompose) {
     pm.addNestedPass<func::FuncOp>(
@@ -161,8 +161,8 @@ void mlir::torch::Torch::createTorchSimplificationPipeline(
   pm.addNestedPass<func::FuncOp>(createCanonicalizerPass());
   pm.addNestedPass<func::FuncOp>(createRecomposeComplexOpsPass());
   // Reduce variants of ops to a smaller set of primitives.
-  pm.addNestedPass<func::FuncOp>(
-      createReduceOpVariantsPass(options.extraLibrary));
+  pm.addNestedPass<func::FuncOp>(createReduceOpVariantsPass(
+      options.extraLibrary, options.backendLegalOps));
   pm.addNestedPass<func::FuncOp>(createCanonicalizerPass());
   // Remove dead global slots.
   pm.addPass(createSymbolDCEPass());
@@ -403,8 +403,9 @@ namespace {
 struct ReduceOpVariantsPass
     : public ReduceOpVariantsBase<ReduceOpVariantsPass> {
   ReduceOpVariantsPass() = default;
-  ReduceOpVariantsPass(StringRef extraLibrary) {
+  ReduceOpVariantsPass(StringRef extraLibrary, ArrayRef<std::string> legalOps) {
     this->extraLibrary = extraLibrary.str();
+    this->legalOps = legalOps;
   }
   void runOnOperation() override {
     MLIRContext *context = &getContext();
@@ -439,13 +440,15 @@ struct ReduceOpVariantsPass
     target.addIllegalOp<NonValueTensorLiteralOp>();
     target.addIllegalOp<AtenBernoulli_FloatOp>();
     target.addIllegalOp<AtenArangeStartOutOp>();
-    target.markUnknownOpDynamicallyLegal([&extraLibraryModuleSymTable,
-                                          &specializedNames](Operation *op) {
-      if (isa<OperatorOp>(op)) {
-        if (specializedNames.contains(cast<OperatorOp>(op).getNameAttr())) {
-          return false;
-        }
-      }
+    target.addDynamicallyLegalOp<OperatorOp>([&](OperatorOp op) {
+      auto opNameAttr = op.getNameAttr();
+      return llvm::find(legalOps, opNameAttr.str()) != legalOps.end() &&
+             !specializedNames.contains(opNameAttr);
+    });
+
+    target.markUnknownOpDynamicallyLegal([&extraLibraryModuleSymTable](
+                                             Operation *op) {
       if (op->hasTrait<Torch::OpTrait::HasValueSemantics>() ||
           (isa<OperatorOp>(op) &&
            operatorOpHasValueSemantics(cast<OperatorOp>(op),
@@ -479,6 +482,7 @@ struct ReduceOpVariantsPass
 } // namespace
 
 std::unique_ptr<OperationPass<func::FuncOp>>
-mlir::torch::Torch::createReduceOpVariantsPass(StringRef extraLibrary) {
-  return std::make_unique<ReduceOpVariantsPass>(extraLibrary);
+mlir::torch::Torch::createReduceOpVariantsPass(StringRef extraLibrary,
+                                               ArrayRef<std::string> legalOps) {
+  return std::make_unique<ReduceOpVariantsPass>(extraLibrary, legalOps);
 }
@@ -102,3 +102,36 @@ func.func @elementwise_sinh(%arg0: !torch.vtensor<[3],f32>) -> !torch.vtensor<[3
   %0 = torch.aten.sinh %arg0 : !torch.vtensor<[3],f32> -> !torch.vtensor<[3],f32>
   return %0 : !torch.vtensor<[3],f32>
 }
+
+// -----
+
+// CHECK-LABEL: func.func @torch_aten_round_decimals
+// CHECK: %[[VAL2:.*]] = linalg.generic
+// CHECK-NEXT: ^bb0(%[[IN:.*]]: f32, %{{.*}}: f32):
+// CHECK-NEXT: %[[CONST_64:.*]] = arith.constant
+// CHECK-NEXT: %[[CONST_10:.*]] = arith.constant 1.000000e+01
+// CHECK-NEXT: %[[VAL4:.*]] = math.fpowi %[[CONST_10]], %[[CONST_64]]
+// CHECK-NEXT: %[[VAL5:.*]] = arith.mulf %[[IN]], %[[VAL4]]
+// CHECK-NEXT: %[[VAL6:.*]] = math.roundeven %[[VAL5]]
+// CHECK-NEXT: %[[VAL7:.*]] = arith.divf %[[VAL6]], %[[VAL4]]
+// CHECK-NEXT: linalg.yield %[[VAL7]]
+// CHECK: %[[CAST:.*]] = tensor.cast %[[VAL2]]
+// CHECK-NEXT: %[[VAL3:.*]] = torch_c.from_builtin_tensor %[[CAST]]
+// CHECK-NEXT: return %[[VAL3]]
+func.func @torch_aten_round_decimals(%0: !torch.vtensor<[1,1024,1024,3],f32>) -> !torch.vtensor<[1, 1024,1024,3],f32> {
+  %int0 = torch.constant.int 0
+  %1 = torch.operator "torch.aten.round.decimals"(%0, %int0) : (!torch.vtensor<[1,1024,1024,3],f32>, !torch.int) -> !torch.vtensor<[1,1024,1024,3],f32>
+  return %1 : !torch.vtensor<[1, 1024,1024,3],f32>
+}
+
+// -----
+
+// Test that unhandled versions of `torch.operator` op are not legalized.
+func.func @torch.prims.device_put(%arg0: !torch.vtensor<[1,77],si64>) -> !torch.vtensor<[1,77],si64> {
+  %cuda3A0 = torch.constant.device "cuda:0"
+  // expected-error @+1 {{failed to legalize operation 'torch.operator' that was explicitly marked illegal}}
+  %0 = torch.operator "torch.prims.device_put"(%arg0, %cuda3A0) : (!torch.vtensor<[1,77],si64>, !torch.Device) -> !torch.vtensor<[1,77],si64>
+  %int4 = torch.constant.int 4
+  %1 = torch.prims.convert_element_type %0, %int4 : !torch.vtensor<[1,77],si64>, !torch.int -> !torch.vtensor<[1,77],si64>
+  return %1 : !torch.vtensor<[1,77],si64>
+}
@@ -1,4 +1,4 @@
-// RUN: torch-mlir-opt -pass-pipeline='builtin.module(torch-function-to-torch-backend-pipeline{backend-legal-ops=aten.square,aten.argmax})' -split-input-file %s | FileCheck %s
+// RUN: torch-mlir-opt -pass-pipeline='builtin.module(torch-function-to-torch-backend-pipeline{backend-legal-ops=aten.square,aten.argmax,torch.aten.round.decimals})' -split-input-file %s | FileCheck %s
 
 // CHECK-LABEL: func.func @torch.aten.square
 func.func @torch.aten.square(%arg0: !torch.vtensor<[?,?,?],f32>) -> !torch.vtensor<[?,?,?],f32> {
@@ -25,3 +25,11 @@ func.func @torch.uint8(%arg0: !torch.tensor {torch.type_bound = !torch.vtensor<[
   %1 = torch.aten.reshape %arg0, %0 : !torch.tensor, !torch.list<int> -> !torch.tensor
   return %1 : !torch.tensor
 }
+
+// Test that "torch.aten.round.decimals" was considered legal after explicitly specifying it in pass options.
+// CHECK-LABEL: func.func @torch_aten_round_decimals
+func.func @torch_aten_round_decimals(%0: !torch.vtensor<[1,1024,1024,3],f32>) -> !torch.vtensor<[1, 1024,1024,3],f32> {
+  %int0 = torch.constant.int 0
+  %1 = torch.operator "torch.aten.round.decimals"(%0, %int0) : (!torch.vtensor<[1,1024,1024,3],f32>, !torch.int) -> !torch.vtensor<[1,1024,1024,3],f32>
+  return %1 : !torch.vtensor<[1, 1024,1024,3],f32>
+}