diff --git a/frontends/pytorch/examples/mul_maximum_e2e.py b/frontends/pytorch/examples/mul_maximum_e2e.py new file mode 100644 index 000000000..49bafa5f2 --- /dev/null +++ b/frontends/pytorch/examples/mul_maximum_e2e.py @@ -0,0 +1,24 @@ +# -*- Python -*- +# This file is licensed under a pytorch-style license +# See frontends/pytorch/LICENSE for license information. + +import sys +import torch +import torch_mlir + +lhs = torch.ones((4, 6, 1)) +rhs = torch.ones((1, 1, 3)) * 0.6 +bias = torch.ones((1, 1, 3)) * 0.2 +threshold = torch.tensor((0.75, 0.25, 0.10)) + +mb = torch_mlir.ModuleBuilder() +with mb.capture_function("mul_maximum", [lhs, rhs, threshold, bias]) as f: + result = torch.maximum(lhs * rhs, threshold) + result = result + bias + f.returns([result]) + +print(f"Result(f{result.size()}) = {result}", file=sys.stderr) +# TODO: Currently need to route through: +# npcomp-opt -aten-recognize-kernels -convert-aten-to-tcf \ +# -numpy-public-functions-to-tensor -canonicalize +mb.module.operation.print() diff --git a/frontends/pytorch/python/torch_mlir_utils/codegen/torch_signature_ods_gen.py b/frontends/pytorch/python/torch_mlir_utils/codegen/torch_signature_ods_gen.py index efb25fcb3..5a2e97000 100644 --- a/frontends/pytorch/python/torch_mlir_utils/codegen/torch_signature_ods_gen.py +++ b/frontends/pytorch/python/torch_mlir_utils/codegen/torch_signature_ods_gen.py @@ -64,6 +64,9 @@ def generate_ops(g: "OpGenerator"): g.ordinary_binary_op("aten::true_divide(Tensor,Tensor)", "TrueDivideOp", "true_divide") + g.ordinary_binary_op("aten::maximum(Tensor,Tensor)", "MaximumOp", "maximum") + g.ordinary_binary_op("aten::minimum(Tensor,Tensor)", "MinimumOp", "minimum") + # Unary-ops. These are all the same so just name munge them. g.print_banner("Unary arithmetic ops") for uname in [ diff --git a/include/npcomp/Conversion/ATenToTCF/Passes.h b/include/npcomp/Conversion/ATenToTCF/Passes.h new file mode 100644 index 000000000..e4920ba63 --- /dev/null +++ b/include/npcomp/Conversion/ATenToTCF/Passes.h @@ -0,0 +1,21 @@ +//===------------------------------------------------------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef NPCOMP_CONVERSION_ATENTOTCF_PASSES_H +#define NPCOMP_CONVERSION_ATENTOTCF_PASSES_H + +#include "mlir/Pass/Pass.h" +#include + +namespace mlir { +namespace NPCOMP { +std::unique_ptr> createConvertATenToTCFPass(); +} +} // namespace mlir + +#endif // NPCOMP_CONVERSION_ATENTOTCF_PATTERNS_H diff --git a/include/npcomp/Conversion/ATenToTCF/Patterns.h b/include/npcomp/Conversion/ATenToTCF/Patterns.h new file mode 100644 index 000000000..a26f90a23 --- /dev/null +++ b/include/npcomp/Conversion/ATenToTCF/Patterns.h @@ -0,0 +1,30 @@ +//===------------------------------------------------------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef NPCOMP_CONVERSION_ATENTOTCF_PATTERNS_H +#define NPCOMP_CONVERSION_ATENTOTCF_PATTERNS_H + +#include + +namespace mlir { + +class MLIRContext; +class OwningRewritePatternList; + +namespace NPCOMP { + +/// Populates patterns for converting core ATen ops to TCF. These patterns +/// cover core arithmetic ops that are on the order of 1:1 representationally. +/// More advanced patterns are managed elsewhere. +void populateCoreATenToTCFPatterns(MLIRContext *context, + OwningRewritePatternList &patterns); + +} // namespace NPCOMP +} // namespace mlir + +#endif // NPCOMP_CONVERSION_ATENTOTCF_PATTERNS_H diff --git a/include/npcomp/Conversion/Passes.td b/include/npcomp/Conversion/Passes.td index d7d34d62f..14801bd2b 100644 --- a/include/npcomp/Conversion/Passes.td +++ b/include/npcomp/Conversion/Passes.td @@ -12,12 +12,12 @@ include "mlir/Pass/PassBase.td" //===----------------------------------------------------------------------===// -// TCFToTCP +// ATen conversions //===----------------------------------------------------------------------===// -def ConvertTCFToTCP : Pass<"convert-tcf-to-tcp", "ModuleOp"> { - let summary = "Convert TCF to TCP"; - let constructor = "mlir::NPCOMP::createConvertTCFToTCPPass()"; +def ConvertATenToTCF : Pass<"convert-aten-to-tcf", "FuncOp"> { + let summary = "Convert recognized ATen to TCF ops"; + let constructor = "mlir::NPCOMP::createConvertATenToTCFPass()"; } //===----------------------------------------------------------------------===// @@ -38,6 +38,15 @@ def ConvertNumpyToTCF : Pass<"convert-numpy-to-tcf", "FuncOp"> { let constructor = "mlir::NPCOMP::createConvertNumpyToTCFPass()"; } +//===----------------------------------------------------------------------===// +// TCFToTCP +//===----------------------------------------------------------------------===// + +def ConvertTCFToTCP : Pass<"convert-tcf-to-tcp", "ModuleOp"> { + let summary = "Convert TCF to TCP"; + let constructor = "mlir::NPCOMP::createConvertTCFToTCPPass()"; +} + //===----------------------------------------------------------------------===// // Conditionally compiled IREE backend passes //===----------------------------------------------------------------------===// diff --git a/include/npcomp/Dialect/ATen/IR/GeneratedATenOps.cpp.inc b/include/npcomp/Dialect/ATen/IR/GeneratedATenOps.cpp.inc index 5b6d453e0..44cd8b5ea 100644 --- a/include/npcomp/Dialect/ATen/IR/GeneratedATenOps.cpp.inc +++ b/include/npcomp/Dialect/ATen/IR/GeneratedATenOps.cpp.inc @@ -152,6 +152,44 @@ const Torch::BuildKernelMetadata &TrueDivideOp::getTorchBuildKernelMetadata() { return metadata; } +Torch::KernelMetadata MaximumOp::getTorchKernelMetadata() { + return getTorchBuildKernelMetadata(); +} + +const Torch::BuildKernelMetadata &MaximumOp::getTorchBuildKernelMetadata() { + using KVC = Torch::KernelValueConversion::BitMask; + static Torch::BuildKernelMetadata metadata = ([]() { + Torch::BuildKernelMetadata m; + m.kernelName = "aten::maximum"; + m.promoteTrailingOutTensor = true; + m.addArgTypes({"Tensor", "Tensor"}); + m.addArgConversions({KVC::kImmutableTensor, KVC::kImmutableTensor|KVC::kPromoteScalar}); + m.addReturnTypes({"Tensor"}); + m.addReturnConversions({KVC::kImmutableTensor}); + return m; + })(); + return metadata; +} + +Torch::KernelMetadata MinimumOp::getTorchKernelMetadata() { + return getTorchBuildKernelMetadata(); +} + +const Torch::BuildKernelMetadata &MinimumOp::getTorchBuildKernelMetadata() { + using KVC = Torch::KernelValueConversion::BitMask; + static Torch::BuildKernelMetadata metadata = ([]() { + Torch::BuildKernelMetadata m; + m.kernelName = "aten::minimum"; + m.promoteTrailingOutTensor = true; + m.addArgTypes({"Tensor", "Tensor"}); + m.addArgConversions({KVC::kImmutableTensor, KVC::kImmutableTensor|KVC::kPromoteScalar}); + m.addReturnTypes({"Tensor"}); + m.addReturnConversions({KVC::kImmutableTensor}); + return m; + })(); + return metadata; +} + // ----------------------------------------------------------------------------- // Unary arithmetic ops // ----------------------------------------------------------------------------- diff --git a/include/npcomp/Dialect/ATen/IR/GeneratedATenOps.td b/include/npcomp/Dialect/ATen/IR/GeneratedATenOps.td index 34a0099bb..6c5b90c96 100644 --- a/include/npcomp/Dialect/ATen/IR/GeneratedATenOps.td +++ b/include/npcomp/Dialect/ATen/IR/GeneratedATenOps.td @@ -96,6 +96,28 @@ def aten_TrueDivideOp: aten_Op<"true_divide", [NoSideEffect, DeclareOpInterfaceM ); } +def aten_MaximumOp: aten_Op<"maximum", [NoSideEffect, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + let summary = "Recognized op for kernel aten::maximum"; + let arguments = (ins + AnyTorchImmutableTensor:$self, + AnyTorchImmutableTensor:$other + ); + let results = (outs + AnyTorchImmutableTensor + ); +} + +def aten_MinimumOp: aten_Op<"minimum", [NoSideEffect, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + let summary = "Recognized op for kernel aten::minimum"; + let arguments = (ins + AnyTorchImmutableTensor:$self, + AnyTorchImmutableTensor:$other + ); + let results = (outs + AnyTorchImmutableTensor + ); +} + // ----------------------------------------------------------------------------- // Unary arithmetic ops // ----------------------------------------------------------------------------- diff --git a/lib/Conversion/ATenToTCF/CMakeLists.txt b/lib/Conversion/ATenToTCF/CMakeLists.txt new file mode 100644 index 000000000..bd73354b6 --- /dev/null +++ b/lib/Conversion/ATenToTCF/CMakeLists.txt @@ -0,0 +1,17 @@ +add_npcomp_conversion_library(NPCOMPATenToTCF + ConvertATenToTCFPass.cpp + CoreOpConversionPatterns.cpp + + DEPENDS + NPCOMPConversionPassIncGen + + LINK_COMPONENTS + Core + + LINK_LIBS PUBLIC + MLIRIR + MLIRPass + MLIRTransforms + NPCOMPATenDialect + NPCOMPTCFDialect +) diff --git a/lib/Conversion/ATenToTCF/ConvertATenToTCFPass.cpp b/lib/Conversion/ATenToTCF/ConvertATenToTCFPass.cpp new file mode 100644 index 000000000..91a5c612a --- /dev/null +++ b/lib/Conversion/ATenToTCF/ConvertATenToTCFPass.cpp @@ -0,0 +1,41 @@ +//===----------------------------------------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "npcomp/Conversion/ATenToTCF/Passes.h" + +#include "../PassDetail.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#include "npcomp/Conversion/ATenToTCF/Patterns.h" +#include "npcomp/Dialect/TCF/IR/TCFDialect.h" + +using namespace mlir; +using namespace mlir::NPCOMP; + +namespace { + +class ConvertATenToTCF : public ConvertATenToTCFBase { +public: + void getDependentDialects(DialectRegistry ®istry) const override { + registry.insert(); + } + + void runOnOperation() override { + FuncOp funcOp = getOperation(); + MLIRContext *context = &getContext(); + OwningRewritePatternList patterns; + populateCoreATenToTCFPatterns(context, patterns); + (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns)); + } +}; + +} // namespace + +std::unique_ptr> +mlir::NPCOMP::createConvertATenToTCFPass() { + return std::make_unique(); +} diff --git a/lib/Conversion/ATenToTCF/CoreOpConversionPatterns.cpp b/lib/Conversion/ATenToTCF/CoreOpConversionPatterns.cpp new file mode 100644 index 000000000..5a8acf350 --- /dev/null +++ b/lib/Conversion/ATenToTCF/CoreOpConversionPatterns.cpp @@ -0,0 +1,77 @@ +//===----------------------------------------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "npcomp/Conversion/ATenToTCF/Patterns.h" + +#include "mlir/IR/MLIRContext.h" +#include "mlir/IR/Matchers.h" +#include "mlir/IR/PatternMatch.h" +#include "npcomp/Dialect/ATen/IR/ATenDialect.h" +#include "npcomp/Dialect/TCF/IR/TCFOps.h" + +using namespace mlir; +using namespace mlir::NPCOMP; + +namespace { + +/// The ATen AddOp actually has three arguments: +/// self, other, alpha +/// Alpha is an integer that is multiplied by 'other' prior to adding. +class ConvertATenAdd : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(aten::AddOp srcAddOp, + PatternRewriter &rewriter) const override { + // Special case: Match when alpha is constant 1, which is the default, + // quite common and maps directly to a TCF add. Note that regardless of + // the type of self/other (i.e. if they are float), alpha emits as an + // integer with value 1 when defaulted. It is this specific case that we + // are detecting (default value) and will leave all others to the fully + // generic conversion. + APInt alphaValue; + if (matchPattern(srcAddOp.alpha(), m_ConstantInt(&alphaValue)) && + alphaValue.getZExtValue() == 1) { + rewriter.replaceOpWithNewOp( + srcAddOp, srcAddOp.getResult().getType(), srcAddOp.self(), + srcAddOp.other()); + return success(); + } + + return rewriter.notifyMatchFailure( + srcAddOp, "aten.add to tcf.add currently only supports alpha == 1"); + } +}; + +/// Common conversion template for true binary elementwise ops. +/// This does not apply to the handful of not-actually-binary PyTorch ops that +/// have broadcastable self/other operands but may have additional parameters. +template +class ConvertBinaryElementwise : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(SourceOp srcOp, + PatternRewriter &rewriter) const override { + auto operands = srcOp.getOperation()->getOperands(); + auto results = srcOp.getOperation()->getResults(); + assert(operands.size() == 2 && "expected true binary op"); + assert(results.size() == 1 && "expected single result op"); + Type resultType = results[0].getType(); + rewriter.replaceOpWithNewOp( + srcOp, resultType, srcOp.getOperand(0), srcOp.getOperand(1)); + return success(); + } +}; + +} // namespace + +void mlir::NPCOMP::populateCoreATenToTCFPatterns( + MLIRContext *context, OwningRewritePatternList &patterns) { + patterns.insert(context); + patterns.insert>(context); + patterns.insert>( + context); +} diff --git a/lib/Conversion/CMakeLists.txt b/lib/Conversion/CMakeLists.txt index 0ed163353..7e23c001f 100644 --- a/lib/Conversion/CMakeLists.txt +++ b/lib/Conversion/CMakeLists.txt @@ -1,3 +1,4 @@ +add_subdirectory(ATenToTCF) add_subdirectory(BasicpyToStd) add_subdirectory(NumpyToTCF) add_subdirectory(TCFToTCP) diff --git a/lib/Conversion/Passes.cpp b/lib/Conversion/Passes.cpp index 2925d85c6..333fe3be2 100644 --- a/lib/Conversion/Passes.cpp +++ b/lib/Conversion/Passes.cpp @@ -8,6 +8,7 @@ #include "npcomp/Conversion/Passes.h" +#include "npcomp/Conversion/ATenToTCF/Passes.h" #include "npcomp/Conversion/BasicpyToStd/Passes.h" #include "npcomp/Conversion/NumpyToTCF/Passes.h" #include "npcomp/Conversion/TCFToTCP/TCFToTCP.h" diff --git a/test/Conversion/ATenToTCF/core_op_conversions.mlir b/test/Conversion/ATenToTCF/core_op_conversions.mlir new file mode 100644 index 000000000..38c301dca --- /dev/null +++ b/test/Conversion/ATenToTCF/core_op_conversions.mlir @@ -0,0 +1,26 @@ +// RUN: npcomp-opt <%s -convert-aten-to-tcf | FileCheck %s --dump-input=fail + +// CHECK-LABEL: @binary_elementwise_ops +// NOTE: These are all template expanded, so just testing an examplar op and +// special cases. +func @binary_elementwise_ops(%arg0: tensor<4x6x1xf32>, %arg1: tensor<1x1x3xf32>) -> tensor<4x6x3xf32> { + // CHECK: tcf.mul %arg0, %arg1 : (tensor<4x6x1xf32>, tensor<1x1x3xf32>) -> tensor<4x6x3xf32> + %0 = "aten.mul"(%arg0, %arg1) : (tensor<4x6x1xf32>, tensor<1x1x3xf32>) -> tensor<4x6x3xf32> + return %0 : tensor<4x6x3xf32> +} + +// CHECK-LABEL: @add_alpha_constant1 +func @add_alpha_constant1(%arg0: tensor<4x6x3xf32>, %arg1: tensor<1x1x3xf32>) -> tensor<4x6x3xf32> { + %c1_i64 = constant 1 : i64 + // CHECK: tcf.add %arg0, %arg1 : (tensor<4x6x3xf32>, tensor<1x1x3xf32>) -> tensor<4x6x3xf32> + %0 = "aten.add"(%arg0, %arg1, %c1_i64) : (tensor<4x6x3xf32>, tensor<1x1x3xf32>, i64) -> tensor<4x6x3xf32> + return %0 : tensor<4x6x3xf32> +} + +// CHECK-LABEL: @add_alpha_non_constant1 +func @add_alpha_non_constant1(%arg0: tensor<4x6x3xf32>, %arg1: tensor<1x1x3xf32>) -> tensor<4x6x3xf32> { + %c1_i64 = constant 2 : i64 + // CHECK: "aten.add" + %0 = "aten.add"(%arg0, %arg1, %c1_i64) : (tensor<4x6x3xf32>, tensor<1x1x3xf32>, i64) -> tensor<4x6x3xf32> + return %0 : tensor<4x6x3xf32> +}