mirror of https://github.com/llvm/torch-mlir
[uncategorized_lowerings] Add lowering for torch.aten.round.decimals
Implement the missing lowering for the op in a fashion similar to torch inductor. Also fix the data movement and reduce-op-variants patterns so that they correctly handle explicitly declared legal ops. Signed-off-by: Prathamesh Tagore <prathamesh+1@polymagelabs.com> pull/3811/head
parent
140cad5659
commit
0f5a2dc844
|
@ -109,7 +109,7 @@ std::unique_ptr<OperationPass<ModuleOp>> createAdjustCallingConventionsPass();
|
|||
std::unique_ptr<OperationPass<ModuleOp>> createInlineGlobalSlotsPass();
|
||||
|
||||
std::unique_ptr<OperationPass<func::FuncOp>>
|
||||
createReduceOpVariantsPass(StringRef extraLibrary);
|
||||
createReduceOpVariantsPass(StringRef extraLibrary, ArrayRef<std::string> = {});
|
||||
|
||||
std::unique_ptr<OperationPass<func::FuncOp>> createMaximizeValueSemanticsPass();
|
||||
|
||||
|
|
|
@ -148,6 +148,9 @@ def ReduceOpVariants : Pass<"torch-reduce-op-variants", "func::FuncOp"> {
|
|||
let options = [
|
||||
Option<"extraLibrary", "extra-library", "std::string", /*default=*/"",
|
||||
"MLIR module for verifying custom op value semantics">,
|
||||
ListOption<"legalOps", "legal-ops", "std::string",
|
||||
"Comma separated list of operation names that should be considered legal",
|
||||
"llvm::cl::ZeroOrMore">
|
||||
];
|
||||
let description = [{
|
||||
Replaces ops with other ops to reduce the number of variants that
|
||||
|
|
|
@ -2874,7 +2874,10 @@ void mlir::torch::torch_to_linalg::populateDataMovementPatternsAndLegality(
|
|||
patterns.add<ConvertAtenDiagEmbedOp>(typeConverter, context);
|
||||
// Rewrite all special sparse conversions hidden as operators.
|
||||
target.addDynamicallyLegalOp<OperatorOp>([&](Torch::OperatorOp op) {
|
||||
return !ConvertSparseOperatorOp::isSparsePrimitive(op.getNameAttr());
|
||||
// Note: Legality behaviour of torch.operator ops that are not sparse
|
||||
// primitives should be preserved and not modified by this block.
|
||||
return !ConvertSparseOperatorOp::isSparsePrimitive(op.getNameAttr()) &&
|
||||
typeConverter.isLegal(op);
|
||||
});
|
||||
patterns.add<ConvertSparseOperatorOp>(typeConverter, context);
|
||||
}
|
||||
|
|
|
@ -1557,6 +1557,43 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
|
|||
threshold);
|
||||
}
|
||||
|
||||
if (auto operatorOp = dyn_cast<OperatorOp>(op)) {
|
||||
// We do not yet implement lowering for other variants of the op.
|
||||
if (operatorOp.getNameAttr().str() != "torch.aten.round.decimals")
|
||||
return nullptr;
|
||||
|
||||
// Lower the op in a similar fashion as described here:
|
||||
// https://github.com/pytorch/pytorch/blob/de4c2a3b4e89d96334dc678d1c3f2ae51a6630a0/torch/_inductor/decomposition.py#L223.
|
||||
// Note that `aten.round` is converted to `math.roundeven`; we do this
|
||||
// implicitly here because `aten.round` cannot operate on a single input
|
||||
// tensor element which is what we get as payload argument.
|
||||
|
||||
Location loc = op->getLoc();
|
||||
Type i64Type = b.getI64Type();
|
||||
|
||||
auto torchIntOp = dyn_cast<ConstantIntOp>(
|
||||
operatorOp.getOperands().back().getDefiningOp());
|
||||
if (!torchIntOp)
|
||||
return nullptr;
|
||||
int64_t numDecimalsArg = torchIntOp.getValue();
|
||||
|
||||
Value inputTensorElem = payloadArgs[0];
|
||||
Type inputTensorElemType = inputTensorElem.getType();
|
||||
|
||||
auto numDecimals = b.create<arith::ConstantOp>(
|
||||
loc, i64Type, IntegerAttr::get(i64Type, numDecimalsArg));
|
||||
auto const10 = b.create<arith::ConstantOp>(
|
||||
loc, inputTensorElemType, FloatAttr::get(inputTensorElemType, 10));
|
||||
auto tenPowDecimals = b.create<math::FPowIOp>(loc, const10, numDecimals);
|
||||
|
||||
auto mulTenPowDecimalsinputTensorElem =
|
||||
b.create<arith::MulFOp>(loc, inputTensorElem, tenPowDecimals);
|
||||
auto roundOp =
|
||||
b.create<math::RoundEvenOp>(loc, mulTenPowDecimalsinputTensorElem);
|
||||
auto res = b.create<arith::DivFOp>(loc, roundOp, tenPowDecimals);
|
||||
return res;
|
||||
}
|
||||
|
||||
op->emitError("unimplemented lowering in "
|
||||
"createLinalgPayloadCalculationForElementwiseOp");
|
||||
return nullptr;
|
||||
|
@ -1616,9 +1653,14 @@ public:
|
|||
AtenFillScalarOp, AtenFillTensorOp, AtenAtanOp, AtenAcosOp,
|
||||
AtenAtanhOp, AtenAcoshOp, AtenAsinOp, AtenAsinhOp, AtenRealOp,
|
||||
AtenImagOp, AtenDequantizeSelfOp, AtenDequantizeTensorOp,
|
||||
AtenQuantizePerTensorOp, AtenIscloseOp>(op))
|
||||
AtenQuantizePerTensorOp, AtenIscloseOp, OperatorOp>(op))
|
||||
return rewriter.notifyMatchFailure(op, "not a supported elementwise op");
|
||||
|
||||
if (auto operatorOp = dyn_cast<OperatorOp>(op))
|
||||
if (operatorOp.getNameAttr().str() != "torch.aten.round.decimals")
|
||||
return rewriter.notifyMatchFailure(op,
|
||||
"not a supported elementwise op");
|
||||
|
||||
if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
|
||||
return failure();
|
||||
|
||||
|
@ -3375,7 +3417,7 @@ void mlir::torch::torch_to_linalg::populateUncategorizedPatternsAndLegality(
|
|||
AtenTrilOp, AtenRemainderScalarOp, AtenRemainderTensorOp,
|
||||
AtenBitwiseNotOp, AtenRoundOp, AtenFillScalarOp, AtenFillTensorOp,
|
||||
AtenRealOp, AtenImagOp, AtenDequantizeSelfOp, AtenDequantizeTensorOp,
|
||||
AtenQuantizePerTensorOp, AtenIscloseOp>();
|
||||
AtenQuantizePerTensorOp, AtenIscloseOp, OperatorOp>();
|
||||
patterns.add<ConvertElementwiseOp>(typeConverter, context);
|
||||
target.addIllegalOp<AtenNllLossForwardOp>();
|
||||
patterns.add<ConvertAtenDetachOp>(typeConverter, context);
|
||||
|
|
|
@ -70,8 +70,8 @@ void mlir::torch::Torch::createTorchScriptModuleToTorchBackendPipeline(
|
|||
|
||||
void mlir::torch::Torch::createTorchDynamoExportToTorchBackendPipeline(
|
||||
OpPassManager &pm, const TorchLoweringPipelineOptions &options) {
|
||||
pm.addNestedPass<func::FuncOp>(
|
||||
createReduceOpVariantsPass(options.extraLibrary));
|
||||
pm.addNestedPass<func::FuncOp>(createReduceOpVariantsPass(
|
||||
options.extraLibrary, options.backendLegalOps));
|
||||
pm.addNestedPass<func::FuncOp>(createCanonicalizerPass());
|
||||
if (options.decompose) {
|
||||
pm.addNestedPass<func::FuncOp>(
|
||||
|
@ -161,8 +161,8 @@ void mlir::torch::Torch::createTorchSimplificationPipeline(
|
|||
pm.addNestedPass<func::FuncOp>(createCanonicalizerPass());
|
||||
pm.addNestedPass<func::FuncOp>(createRecomposeComplexOpsPass());
|
||||
// Reduce variants of ops to a smaller set of primitives.
|
||||
pm.addNestedPass<func::FuncOp>(
|
||||
createReduceOpVariantsPass(options.extraLibrary));
|
||||
pm.addNestedPass<func::FuncOp>(createReduceOpVariantsPass(
|
||||
options.extraLibrary, options.backendLegalOps));
|
||||
pm.addNestedPass<func::FuncOp>(createCanonicalizerPass());
|
||||
// Remove dead global slots.
|
||||
pm.addPass(createSymbolDCEPass());
|
||||
|
|
|
@ -403,8 +403,9 @@ namespace {
|
|||
struct ReduceOpVariantsPass
|
||||
: public ReduceOpVariantsBase<ReduceOpVariantsPass> {
|
||||
ReduceOpVariantsPass() = default;
|
||||
ReduceOpVariantsPass(StringRef extraLibrary) {
|
||||
ReduceOpVariantsPass(StringRef extraLibrary, ArrayRef<std::string> legalOps) {
|
||||
this->extraLibrary = extraLibrary.str();
|
||||
this->legalOps = legalOps;
|
||||
}
|
||||
void runOnOperation() override {
|
||||
MLIRContext *context = &getContext();
|
||||
|
@ -439,13 +440,15 @@ struct ReduceOpVariantsPass
|
|||
target.addIllegalOp<NonValueTensorLiteralOp>();
|
||||
target.addIllegalOp<AtenBernoulli_FloatOp>();
|
||||
target.addIllegalOp<AtenArangeStartOutOp>();
|
||||
target.markUnknownOpDynamicallyLegal([&extraLibraryModuleSymTable,
|
||||
&specializedNames](Operation *op) {
|
||||
if (isa<OperatorOp>(op)) {
|
||||
if (specializedNames.contains(cast<OperatorOp>(op).getNameAttr())) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
target.addDynamicallyLegalOp<OperatorOp>([&](OperatorOp op) {
|
||||
auto opNameAttr = op.getNameAttr();
|
||||
return llvm::find(legalOps, opNameAttr.str()) != legalOps.end() &&
|
||||
!specializedNames.contains(opNameAttr);
|
||||
});
|
||||
|
||||
target.markUnknownOpDynamicallyLegal([&extraLibraryModuleSymTable](
|
||||
Operation *op) {
|
||||
if (op->hasTrait<Torch::OpTrait::HasValueSemantics>() ||
|
||||
(isa<OperatorOp>(op) &&
|
||||
operatorOpHasValueSemantics(cast<OperatorOp>(op),
|
||||
|
@ -479,6 +482,7 @@ struct ReduceOpVariantsPass
|
|||
} // namespace
|
||||
|
||||
std::unique_ptr<OperationPass<func::FuncOp>>
|
||||
mlir::torch::Torch::createReduceOpVariantsPass(StringRef extraLibrary) {
|
||||
return std::make_unique<ReduceOpVariantsPass>(extraLibrary);
|
||||
mlir::torch::Torch::createReduceOpVariantsPass(StringRef extraLibrary,
|
||||
ArrayRef<std::string> legalOps) {
|
||||
return std::make_unique<ReduceOpVariantsPass>(extraLibrary, legalOps);
|
||||
}
|
||||
|
|
|
@ -102,3 +102,36 @@ func.func @elementwise_sinh(%arg0: !torch.vtensor<[3],f32>) -> !torch.vtensor<[3
|
|||
%0 = torch.aten.sinh %arg0 : !torch.vtensor<[3],f32> -> !torch.vtensor<[3],f32>
|
||||
return %0 : !torch.vtensor<[3],f32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func.func @torch_aten_round_decimals
|
||||
// CHECK: %[[VAL2:.*]] = linalg.generic
|
||||
// CHECK-NEXT: ^bb0(%[[IN:.*]]: f32, %{{.*}}: f32):
|
||||
// CHECK-NEXT: %[[CONST_64:.*]] = arith.constant
|
||||
// CHECK-NEXT: %[[CONST_10:.*]] = arith.constant 1.000000e+01
|
||||
// CHECK-NEXT: %[[VAL4:.*]] = math.fpowi %[[CONST_10]], %[[CONST_64]]
|
||||
// CHECK-NEXT: %[[VAL5:.*]] = arith.mulf %[[IN]], %[[VAL4]]
|
||||
// CHECK-NEXT: %[[VAL6:.*]] = math.roundeven %[[VAL5]]
|
||||
// CHECK-NEXT: %[[VAL7:.*]] = arith.divf %[[VAL6]], %[[VAL4]]
|
||||
// CHECK-NEXT: linalg.yield %[[VAL7]]
|
||||
// CHECK: %[[CAST:.*]] = tensor.cast %[[VAL2]]
|
||||
// CHECK-NEXT: %[[VAL3:.*]] = torch_c.from_builtin_tensor %[[CAST]]
|
||||
// CHECK-NEXT: return %[[VAL3]]
|
||||
func.func @torch_aten_round_decimals(%0: !torch.vtensor<[1,1024,1024,3],f32>) -> !torch.vtensor<[1, 1024,1024,3],f32> {
|
||||
%int0 = torch.constant.int 0
|
||||
%1 = torch.operator "torch.aten.round.decimals"(%0, %int0) : (!torch.vtensor<[1,1024,1024,3],f32>, !torch.int) -> !torch.vtensor<[1,1024,1024,3],f32>
|
||||
return %1 : !torch.vtensor<[1, 1024,1024,3],f32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// Test that unhandled versions of `torch.operator` op are not legalized.
|
||||
func.func @torch.prims.device_put(%arg0: !torch.vtensor<[1,77],si64>) -> !torch.vtensor<[1,77],si64> {
|
||||
%cuda3A0 = torch.constant.device "cuda:0"
|
||||
// expected-error @+1 {{failed to legalize operation 'torch.operator' that was explicitly marked illegal}}
|
||||
%0 = torch.operator "torch.prims.device_put"(%arg0, %cuda3A0) : (!torch.vtensor<[1,77],si64>, !torch.Device) -> !torch.vtensor<[1,77],si64>
|
||||
%int4 = torch.constant.int 4
|
||||
%1 = torch.prims.convert_element_type %0, %int4 : !torch.vtensor<[1,77],si64>, !torch.int -> !torch.vtensor<[1,77],si64>
|
||||
return %1 : !torch.vtensor<[1,77],si64>
|
||||
}
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
// RUN: torch-mlir-opt -pass-pipeline='builtin.module(torch-function-to-torch-backend-pipeline{backend-legal-ops=aten.square,aten.argmax})' -split-input-file %s | FileCheck %s
|
||||
// RUN: torch-mlir-opt -pass-pipeline='builtin.module(torch-function-to-torch-backend-pipeline{backend-legal-ops=aten.square,aten.argmax,torch.aten.round.decimals})' -split-input-file %s | FileCheck %s
|
||||
|
||||
// CHECK-LABEL: func.func @torch.aten.square
|
||||
func.func @torch.aten.square(%arg0: !torch.vtensor<[?,?,?],f32>) -> !torch.vtensor<[?,?,?],f32> {
|
||||
|
@ -25,3 +25,11 @@ func.func @torch.uint8(%arg0: !torch.tensor {torch.type_bound = !torch.vtensor<[
|
|||
%1 = torch.aten.reshape %arg0, %0 : !torch.tensor, !torch.list<int> -> !torch.tensor
|
||||
return %1 : !torch.tensor
|
||||
}
|
||||
|
||||
// Test that "torch.aten.round.decimals" is considered legal when it is explicitly specified in the pass options.
|
||||
// CHECK-LABEL: func.func @torch_aten_round_decimals
|
||||
func.func @torch_aten_round_decimals(%0: !torch.vtensor<[1,1024,1024,3],f32>) -> !torch.vtensor<[1, 1024,1024,3],f32> {
|
||||
%int0 = torch.constant.int 0
|
||||
%1 = torch.operator "torch.aten.round.decimals"(%0, %int0) : (!torch.vtensor<[1,1024,1024,3],f32>, !torch.int) -> !torch.vtensor<[1,1024,1024,3],f32>
|
||||
return %1 : !torch.vtensor<[1, 1024,1024,3],f32>
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue