From e23cabf3a90332e8f05063e25a24a050f1d69232 Mon Sep 17 00:00:00 2001
From: George Petterson
Date: Mon, 8 Nov 2021 03:03:22 -0500
Subject: [PATCH] Add log2

---
 e2e_testing/torchscript/elementwise.py        | 16 +++++++++++
 .../Dialect/Torch/IR/GeneratedAtenOps.td      | 28 +++++++++++++++++++
 .../TorchToLinalg/TorchToLinalg.cpp           |  6 ++--
 lib/Dialect/Torch/Transforms/RefineTypes.cpp  |  2 +-
 .../jit_ir/build_tools/torch_ods_gen.py       |  1 +
 5 files changed, 50 insertions(+), 3 deletions(-)

diff --git a/e2e_testing/torchscript/elementwise.py b/e2e_testing/torchscript/elementwise.py
index 6fdae0271..982febed1 100644
--- a/e2e_testing/torchscript/elementwise.py
+++ b/e2e_testing/torchscript/elementwise.py
@@ -429,3 +429,19 @@ class ElementwiseToDtypeF32ToI64Module(torch.nn.Module):
 @register_test_case(module_factory=lambda: ElementwiseToDtypeF32ToI64Module())
 def ElementwiseToDtypeF32ToI64Module_basic(module, tu: TestUtils):
     module.forward(tu.rand(3, 5))
+
+class ElementwiseLog2Module(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([-1, -1], torch.float32, True),
+    ])
+    def forward(self, a):
+        return torch.log2(a)
+
+@register_test_case(module_factory=lambda: ElementwiseLog2Module())
+def ElementwiseLog2Module_basic(module, tu: TestUtils):
+    module.forward(tu.rand(3, 4))
diff --git a/include/torch-mlir/Dialect/Torch/IR/GeneratedAtenOps.td b/include/torch-mlir/Dialect/Torch/IR/GeneratedAtenOps.td
index 113409b69..672ed7fc7 100644
--- a/include/torch-mlir/Dialect/Torch/IR/GeneratedAtenOps.td
+++ b/include/torch-mlir/Dialect/Torch/IR/GeneratedAtenOps.td
@@ -850,6 +850,34 @@ def Torch_AtenClamp_Op : Torch_Op<"aten.clamp_", [
   let assemblyFormat = "$self `,` $min `,` $max attr-dict `:` type($self) `,` type($min) `,` type($max) `->` type($result)";
 }
 
+def Torch_AtenLog2Op : Torch_Op<"aten.log2", [
+    AllowsTypeRefinement,
+    HasValueSemantics
+  ]> {
+  let summary = "Generated op for `aten::log2 : (Tensor) -> (Tensor)`";
+  let arguments = (ins
+    AnyTorchTensorType:$self
+  );
+  let results = (outs
+    AnyTorchTensorType:$result
+  );
+  let assemblyFormat = "$self attr-dict `:` type($self) `->` type($result)";
+}
+
+def Torch_AtenLog2_Op : Torch_Op<"aten.log2_", [
+    IsTrailingUnderscoreInplaceVariant,
+    AllowsTypeRefinement
+  ]> {
+  let summary = "Generated op for `aten::log2_ : (Tensor) -> (Tensor)`";
+  let arguments = (ins
+    AnyTorchTensorType:$self
+  );
+  let results = (outs
+    AnyTorchTensorType:$result
+  );
+  let assemblyFormat = "$self attr-dict `:` type($self) `->` type($result)";
+}
+
 def Torch_AtenMaximumOp : Torch_Op<"aten.maximum", [
     AllowsTypeRefinement,
     HasValueSemantics
diff --git a/lib/Conversion/TorchToLinalg/TorchToLinalg.cpp b/lib/Conversion/TorchToLinalg/TorchToLinalg.cpp
index 0fdc6e11c..96c7b6d5a 100644
--- a/lib/Conversion/TorchToLinalg/TorchToLinalg.cpp
+++ b/lib/Conversion/TorchToLinalg/TorchToLinalg.cpp
@@ -1284,6 +1284,8 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
     return b.create<math::LogOp>(loc, payloadArgs[0]);
   if (isa<AtenSqrtOp>(op))
     return b.create<math::SqrtOp>(loc, payloadArgs[0]);
+  if (isa<AtenLog2Op>(op))
+    return b.create<math::Log2Op>(loc, payloadArgs[0]);
   if (isa<AtenSigmoidOp>(op)) {
     Type elementType = payloadArgs[0].getType();
     auto one = b.create<arith::ConstantOp>(loc, FloatAttr::get(elementType, 1));
@@ -1700,7 +1702,7 @@ struct ConvertElementwiseOp : ConversionPattern {
              AtenMulTensorOp, AtenDivTensorOp, AtenSubTensorOp,
              AtenLerpTensorOp, AtenSigmoidOp, AtenExpOp, AtenMinimumOp,
              AtenMaximumOp, AtenToDtypeOp, AtenClampOp, AtenRsubScalarOp,
-             AtenLogOp, AtenSqrtOp, AtenFloorOp, AtenPowTensorScalarOp>(op))
+             AtenLogOp, AtenSqrtOp, AtenFloorOp, AtenPowTensorScalarOp, AtenLog2Op>(op))
       return rewriter.notifyMatchFailure(op, "not a supported elementwise op");
 
     if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
@@ -2861,5 +2863,5 @@ public:
         AtenLerpTensorOp, AtenSigmoidOp, AtenMinimumOp,
         AtenMaximumOp, AtenToDtypeOp, AtenClampOp,
         AtenRsubScalarOp, AtenLogOp, AtenSqrtOp, AtenFloorOp,
-        AtenPowTensorScalarOp>();
+        AtenPowTensorScalarOp, AtenLog2Op>();
     patterns.add<ConvertElementwiseOp>(typeConverter, context);
diff --git a/lib/Dialect/Torch/Transforms/RefineTypes.cpp b/lib/Dialect/Torch/Transforms/RefineTypes.cpp
index 1fd8b2d7f..c6296c2e0 100644
--- a/lib/Dialect/Torch/Transforms/RefineTypes.cpp
+++ b/lib/Dialect/Torch/Transforms/RefineTypes.cpp
@@ -230,7 +230,7 @@ public:
         AtenToPrimDeviceOp, AtenCpuOp, AtenContiguousOp, AtenFill_ScalarOp,
         AtenDetachOp, AtenMaskedFill_ScalarOp, AtenCopy_Op, AtenIndexPut_Op,
         AtenCopy_Op, AtenCumsumOp, AtenLayerNormOp, AtenClampOp,
-        AtenRsubScalarOp, AtenLogOp, AtenSqrtOp, AtenFloorOp>(op)) {
+        AtenRsubScalarOp, AtenLogOp, AtenSqrtOp, AtenFloorOp, AtenLog2Op>(op)) {
       return getLatticeElement(op->getResult(0)).join(*operands[0]);
     }
 
diff --git a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py
index 53fbd1345..19bbe3a1b 100644
--- a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py
+++ b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py
@@ -465,6 +465,7 @@ def emit_aten_ops(torch_ir_dir: str, registry: Registry):
         "aten::fmod.Scalar : (Tensor, Scalar) -> (Tensor)",
         "aten::masked_fill.Scalar : (Tensor, Tensor, Scalar) -> (Tensor)",
         "aten::clamp : (Tensor, Scalar?, Scalar?) -> (Tensor)",
+        "aten::log2 : (Tensor) -> (Tensor)",
     ]:
         emit_with_mutating_variants(key)
     # Elementwise tensor compute ops that don't have the standard mutating
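
Reviewer note (not part of the patch): the semantics being added can be sanity-checked in stock PyTorch before building torch-mlir. Below is a minimal sketch that mirrors the new ElementwiseLog2Module test, with the `@export`/`@annotate_args` harness decorators omitted; the log(x)/log(2) identity is only a reference check here, since the lowering itself emits math.log2 directly in the linalg.generic payload.

    import torch

    # Mirrors ElementwiseLog2Module from e2e_testing/torchscript/elementwise.py,
    # minus the torch-mlir test-harness decorators.
    class ElementwiseLog2Module(torch.nn.Module):
        def forward(self, a):
            return torch.log2(a)

    module = ElementwiseLog2Module()
    a = torch.rand(3, 4)  # same shape the registered test case uses
    # Reference identity: log2(x) = log(x) / log(2).
    expected = torch.log(a) / torch.log(torch.tensor(2.0))
    assert torch.allclose(module(a), expected)

Once the patch is applied, the registered ElementwiseLog2Module_basic case is picked up by the TorchScript e2e suite; the exact runner invocation (e.g. a filter argument to tools/torchscript_e2e_test.sh, if present in this checkout) varies, so it is not given here as authoritative.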