mirror of https://github.com/llvm/torch-mlir
Add aten.maximum op and conversions from aten->tcf.
* Conversions are very simple, suporting mul, maximum and add (alpha=1 only). * Example added with pass pipeline needed to run. * Much missing off of the golden path but sufficient for such simple cases.pull/109/head
parent
6c702b149f
commit
e60dc2470e
|
@ -0,0 +1,24 @@
|
|||
# -*- Python -*-
|
||||
# This file is licensed under a pytorch-style license
|
||||
# See frontends/pytorch/LICENSE for license information.
|
||||
|
||||
import sys
|
||||
import torch
|
||||
import torch_mlir
|
||||
|
||||
lhs = torch.ones((4, 6, 1))
|
||||
rhs = torch.ones((1, 1, 3)) * 0.6
|
||||
bias = torch.ones((1, 1, 3)) * 0.2
|
||||
threshold = torch.tensor((0.75, 0.25, 0.10))
|
||||
|
||||
mb = torch_mlir.ModuleBuilder()
|
||||
with mb.capture_function("mul_maximum", [lhs, rhs, threshold, bias]) as f:
|
||||
result = torch.maximum(lhs * rhs, threshold)
|
||||
result = result + bias
|
||||
f.returns([result])
|
||||
|
||||
print(f"Result(f{result.size()}) = {result}", file=sys.stderr)
|
||||
# TODO: Currently need to route through:
|
||||
# npcomp-opt -aten-recognize-kernels -convert-aten-to-tcf \
|
||||
# -numpy-public-functions-to-tensor -canonicalize
|
||||
mb.module.operation.print()
|
|
@ -64,6 +64,9 @@ def generate_ops(g: "OpGenerator"):
|
|||
g.ordinary_binary_op("aten::true_divide(Tensor,Tensor)", "TrueDivideOp",
|
||||
"true_divide")
|
||||
|
||||
g.ordinary_binary_op("aten::maximum(Tensor,Tensor)", "MaximumOp", "maximum")
|
||||
g.ordinary_binary_op("aten::minimum(Tensor,Tensor)", "MinimumOp", "minimum")
|
||||
|
||||
# Unary-ops. These are all the same so just name munge them.
|
||||
g.print_banner("Unary arithmetic ops")
|
||||
for uname in [
|
||||
|
|
|
@ -0,0 +1,21 @@
|
|||
//===------------------------------------------------------------*- C++ -*-===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#ifndef NPCOMP_CONVERSION_ATENTOTCF_PASSES_H
|
||||
#define NPCOMP_CONVERSION_ATENTOTCF_PASSES_H
|
||||
|
||||
#include "mlir/Pass/Pass.h"
|
||||
#include <memory>
|
||||
|
||||
namespace mlir {
|
||||
namespace NPCOMP {
|
||||
std::unique_ptr<OperationPass<FuncOp>> createConvertATenToTCFPass();
|
||||
}
|
||||
} // namespace mlir
|
||||
|
||||
#endif // NPCOMP_CONVERSION_ATENTOTCF_PATTERNS_H
|
|
@ -0,0 +1,30 @@
|
|||
//===------------------------------------------------------------*- C++ -*-===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#ifndef NPCOMP_CONVERSION_ATENTOTCF_PATTERNS_H
|
||||
#define NPCOMP_CONVERSION_ATENTOTCF_PATTERNS_H
|
||||
|
||||
#include <memory>
|
||||
|
||||
namespace mlir {
|
||||
|
||||
class MLIRContext;
|
||||
class OwningRewritePatternList;
|
||||
|
||||
namespace NPCOMP {
|
||||
|
||||
/// Populates patterns for converting core ATen ops to TCF. These patterns
|
||||
/// cover core arithmetic ops that are on the order of 1:1 representationally.
|
||||
/// More advanced patterns are managed elsewhere.
|
||||
void populateCoreATenToTCFPatterns(MLIRContext *context,
|
||||
OwningRewritePatternList &patterns);
|
||||
|
||||
} // namespace NPCOMP
|
||||
} // namespace mlir
|
||||
|
||||
#endif // NPCOMP_CONVERSION_ATENTOTCF_PATTERNS_H
|
|
@ -12,12 +12,12 @@
|
|||
include "mlir/Pass/PassBase.td"
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// TCFToTCP
|
||||
// ATen conversions
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
def ConvertTCFToTCP : Pass<"convert-tcf-to-tcp", "ModuleOp"> {
|
||||
let summary = "Convert TCF to TCP";
|
||||
let constructor = "mlir::NPCOMP::createConvertTCFToTCPPass()";
|
||||
def ConvertATenToTCF : Pass<"convert-aten-to-tcf", "FuncOp"> {
|
||||
let summary = "Convert recognized ATen to TCF ops";
|
||||
let constructor = "mlir::NPCOMP::createConvertATenToTCFPass()";
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@ -38,6 +38,15 @@ def ConvertNumpyToTCF : Pass<"convert-numpy-to-tcf", "FuncOp"> {
|
|||
let constructor = "mlir::NPCOMP::createConvertNumpyToTCFPass()";
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// TCFToTCP
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
def ConvertTCFToTCP : Pass<"convert-tcf-to-tcp", "ModuleOp"> {
|
||||
let summary = "Convert TCF to TCP";
|
||||
let constructor = "mlir::NPCOMP::createConvertTCFToTCPPass()";
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Conditionally compiled IREE backend passes
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
|
|
@ -152,6 +152,44 @@ const Torch::BuildKernelMetadata &TrueDivideOp::getTorchBuildKernelMetadata() {
|
|||
return metadata;
|
||||
}
|
||||
|
||||
Torch::KernelMetadata MaximumOp::getTorchKernelMetadata() {
|
||||
return getTorchBuildKernelMetadata();
|
||||
}
|
||||
|
||||
const Torch::BuildKernelMetadata &MaximumOp::getTorchBuildKernelMetadata() {
|
||||
using KVC = Torch::KernelValueConversion::BitMask;
|
||||
static Torch::BuildKernelMetadata metadata = ([]() {
|
||||
Torch::BuildKernelMetadata m;
|
||||
m.kernelName = "aten::maximum";
|
||||
m.promoteTrailingOutTensor = true;
|
||||
m.addArgTypes({"Tensor", "Tensor"});
|
||||
m.addArgConversions({KVC::kImmutableTensor, KVC::kImmutableTensor|KVC::kPromoteScalar});
|
||||
m.addReturnTypes({"Tensor"});
|
||||
m.addReturnConversions({KVC::kImmutableTensor});
|
||||
return m;
|
||||
})();
|
||||
return metadata;
|
||||
}
|
||||
|
||||
Torch::KernelMetadata MinimumOp::getTorchKernelMetadata() {
|
||||
return getTorchBuildKernelMetadata();
|
||||
}
|
||||
|
||||
const Torch::BuildKernelMetadata &MinimumOp::getTorchBuildKernelMetadata() {
|
||||
using KVC = Torch::KernelValueConversion::BitMask;
|
||||
static Torch::BuildKernelMetadata metadata = ([]() {
|
||||
Torch::BuildKernelMetadata m;
|
||||
m.kernelName = "aten::minimum";
|
||||
m.promoteTrailingOutTensor = true;
|
||||
m.addArgTypes({"Tensor", "Tensor"});
|
||||
m.addArgConversions({KVC::kImmutableTensor, KVC::kImmutableTensor|KVC::kPromoteScalar});
|
||||
m.addReturnTypes({"Tensor"});
|
||||
m.addReturnConversions({KVC::kImmutableTensor});
|
||||
return m;
|
||||
})();
|
||||
return metadata;
|
||||
}
|
||||
|
||||
// -----------------------------------------------------------------------------
|
||||
// Unary arithmetic ops
|
||||
// -----------------------------------------------------------------------------
|
||||
|
|
|
@ -96,6 +96,28 @@ def aten_TrueDivideOp: aten_Op<"true_divide", [NoSideEffect, DeclareOpInterfaceM
|
|||
);
|
||||
}
|
||||
|
||||
def aten_MaximumOp: aten_Op<"maximum", [NoSideEffect, DeclareOpInterfaceMethods<TorchBuildableKernelOpInterface>, DeclareOpInterfaceMethods<TorchKernelOpInterface>]> {
|
||||
let summary = "Recognized op for kernel aten::maximum";
|
||||
let arguments = (ins
|
||||
AnyTorchImmutableTensor:$self,
|
||||
AnyTorchImmutableTensor:$other
|
||||
);
|
||||
let results = (outs
|
||||
AnyTorchImmutableTensor
|
||||
);
|
||||
}
|
||||
|
||||
def aten_MinimumOp: aten_Op<"minimum", [NoSideEffect, DeclareOpInterfaceMethods<TorchBuildableKernelOpInterface>, DeclareOpInterfaceMethods<TorchKernelOpInterface>]> {
|
||||
let summary = "Recognized op for kernel aten::minimum";
|
||||
let arguments = (ins
|
||||
AnyTorchImmutableTensor:$self,
|
||||
AnyTorchImmutableTensor:$other
|
||||
);
|
||||
let results = (outs
|
||||
AnyTorchImmutableTensor
|
||||
);
|
||||
}
|
||||
|
||||
// -----------------------------------------------------------------------------
|
||||
// Unary arithmetic ops
|
||||
// -----------------------------------------------------------------------------
|
||||
|
|
|
@ -0,0 +1,17 @@
|
|||
add_npcomp_conversion_library(NPCOMPATenToTCF
|
||||
ConvertATenToTCFPass.cpp
|
||||
CoreOpConversionPatterns.cpp
|
||||
|
||||
DEPENDS
|
||||
NPCOMPConversionPassIncGen
|
||||
|
||||
LINK_COMPONENTS
|
||||
Core
|
||||
|
||||
LINK_LIBS PUBLIC
|
||||
MLIRIR
|
||||
MLIRPass
|
||||
MLIRTransforms
|
||||
NPCOMPATenDialect
|
||||
NPCOMPTCFDialect
|
||||
)
|
|
@ -0,0 +1,41 @@
|
|||
//===----------------------------------------------------------------------===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "npcomp/Conversion/ATenToTCF/Passes.h"
|
||||
|
||||
#include "../PassDetail.h"
|
||||
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
|
||||
#include "npcomp/Conversion/ATenToTCF/Patterns.h"
|
||||
#include "npcomp/Dialect/TCF/IR/TCFDialect.h"
|
||||
|
||||
using namespace mlir;
|
||||
using namespace mlir::NPCOMP;
|
||||
|
||||
namespace {
|
||||
|
||||
class ConvertATenToTCF : public ConvertATenToTCFBase<ConvertATenToTCF> {
|
||||
public:
|
||||
void getDependentDialects(DialectRegistry ®istry) const override {
|
||||
registry.insert<tcf::TCFDialect>();
|
||||
}
|
||||
|
||||
void runOnOperation() override {
|
||||
FuncOp funcOp = getOperation();
|
||||
MLIRContext *context = &getContext();
|
||||
OwningRewritePatternList patterns;
|
||||
populateCoreATenToTCFPatterns(context, patterns);
|
||||
(void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace
|
||||
|
||||
std::unique_ptr<OperationPass<FuncOp>>
|
||||
mlir::NPCOMP::createConvertATenToTCFPass() {
|
||||
return std::make_unique<ConvertATenToTCF>();
|
||||
}
|
|
@ -0,0 +1,77 @@
|
|||
//===----------------------------------------------------------------------===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "npcomp/Conversion/ATenToTCF/Patterns.h"
|
||||
|
||||
#include "mlir/IR/MLIRContext.h"
|
||||
#include "mlir/IR/Matchers.h"
|
||||
#include "mlir/IR/PatternMatch.h"
|
||||
#include "npcomp/Dialect/ATen/IR/ATenDialect.h"
|
||||
#include "npcomp/Dialect/TCF/IR/TCFOps.h"
|
||||
|
||||
using namespace mlir;
|
||||
using namespace mlir::NPCOMP;
|
||||
|
||||
namespace {
|
||||
|
||||
/// The ATen AddOp actually has three arguments:
|
||||
/// self, other, alpha
|
||||
/// Alpha is an integer that is multiplied by 'other' prior to adding.
|
||||
class ConvertATenAdd : public OpRewritePattern<aten::AddOp> {
|
||||
using OpRewritePattern::OpRewritePattern;
|
||||
LogicalResult matchAndRewrite(aten::AddOp srcAddOp,
|
||||
PatternRewriter &rewriter) const override {
|
||||
// Special case: Match when alpha is constant 1, which is the default,
|
||||
// quite common and maps directly to a TCF add. Note that regardless of
|
||||
// the type of self/other (i.e. if they are float), alpha emits as an
|
||||
// integer with value 1 when defaulted. It is this specific case that we
|
||||
// are detecting (default value) and will leave all others to the fully
|
||||
// generic conversion.
|
||||
APInt alphaValue;
|
||||
if (matchPattern(srcAddOp.alpha(), m_ConstantInt(&alphaValue)) &&
|
||||
alphaValue.getZExtValue() == 1) {
|
||||
rewriter.replaceOpWithNewOp<tcf::AddOp>(
|
||||
srcAddOp, srcAddOp.getResult().getType(), srcAddOp.self(),
|
||||
srcAddOp.other());
|
||||
return success();
|
||||
}
|
||||
|
||||
return rewriter.notifyMatchFailure(
|
||||
srcAddOp, "aten.add to tcf.add currently only supports alpha == 1");
|
||||
}
|
||||
};
|
||||
|
||||
/// Common conversion template for true binary elementwise ops.
|
||||
/// This does not apply to the handful of not-actually-binary PyTorch ops that
|
||||
/// have broadcastable self/other operands but may have additional parameters.
|
||||
template <typename SourceOp, typename TargetOp>
|
||||
class ConvertBinaryElementwise : public OpRewritePattern<SourceOp> {
|
||||
public:
|
||||
using OpRewritePattern<SourceOp>::OpRewritePattern;
|
||||
LogicalResult matchAndRewrite(SourceOp srcOp,
|
||||
PatternRewriter &rewriter) const override {
|
||||
auto operands = srcOp.getOperation()->getOperands();
|
||||
auto results = srcOp.getOperation()->getResults();
|
||||
assert(operands.size() == 2 && "expected true binary op");
|
||||
assert(results.size() == 1 && "expected single result op");
|
||||
Type resultType = results[0].getType();
|
||||
rewriter.replaceOpWithNewOp<TargetOp>(
|
||||
srcOp, resultType, srcOp.getOperand(0), srcOp.getOperand(1));
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace
|
||||
|
||||
void mlir::NPCOMP::populateCoreATenToTCFPatterns(
|
||||
MLIRContext *context, OwningRewritePatternList &patterns) {
|
||||
patterns.insert<ConvertATenAdd>(context);
|
||||
patterns.insert<ConvertBinaryElementwise<aten::MulOp, tcf::MulOp>>(context);
|
||||
patterns.insert<ConvertBinaryElementwise<aten::MaximumOp, tcf::MaxOp>>(
|
||||
context);
|
||||
}
|
|
@ -1,3 +1,4 @@
|
|||
add_subdirectory(ATenToTCF)
|
||||
add_subdirectory(BasicpyToStd)
|
||||
add_subdirectory(NumpyToTCF)
|
||||
add_subdirectory(TCFToTCP)
|
||||
|
|
|
@ -8,6 +8,7 @@
|
|||
|
||||
#include "npcomp/Conversion/Passes.h"
|
||||
|
||||
#include "npcomp/Conversion/ATenToTCF/Passes.h"
|
||||
#include "npcomp/Conversion/BasicpyToStd/Passes.h"
|
||||
#include "npcomp/Conversion/NumpyToTCF/Passes.h"
|
||||
#include "npcomp/Conversion/TCFToTCP/TCFToTCP.h"
|
||||
|
|
|
@ -0,0 +1,26 @@
|
|||
// RUN: npcomp-opt <%s -convert-aten-to-tcf | FileCheck %s --dump-input=fail
|
||||
|
||||
// CHECK-LABEL: @binary_elementwise_ops
|
||||
// NOTE: These are all template expanded, so just testing an examplar op and
|
||||
// special cases.
|
||||
func @binary_elementwise_ops(%arg0: tensor<4x6x1xf32>, %arg1: tensor<1x1x3xf32>) -> tensor<4x6x3xf32> {
|
||||
// CHECK: tcf.mul %arg0, %arg1 : (tensor<4x6x1xf32>, tensor<1x1x3xf32>) -> tensor<4x6x3xf32>
|
||||
%0 = "aten.mul"(%arg0, %arg1) : (tensor<4x6x1xf32>, tensor<1x1x3xf32>) -> tensor<4x6x3xf32>
|
||||
return %0 : tensor<4x6x3xf32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @add_alpha_constant1
|
||||
func @add_alpha_constant1(%arg0: tensor<4x6x3xf32>, %arg1: tensor<1x1x3xf32>) -> tensor<4x6x3xf32> {
|
||||
%c1_i64 = constant 1 : i64
|
||||
// CHECK: tcf.add %arg0, %arg1 : (tensor<4x6x3xf32>, tensor<1x1x3xf32>) -> tensor<4x6x3xf32>
|
||||
%0 = "aten.add"(%arg0, %arg1, %c1_i64) : (tensor<4x6x3xf32>, tensor<1x1x3xf32>, i64) -> tensor<4x6x3xf32>
|
||||
return %0 : tensor<4x6x3xf32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @add_alpha_non_constant1
|
||||
func @add_alpha_non_constant1(%arg0: tensor<4x6x3xf32>, %arg1: tensor<1x1x3xf32>) -> tensor<4x6x3xf32> {
|
||||
%c1_i64 = constant 2 : i64
|
||||
// CHECK: "aten.add"
|
||||
%0 = "aten.add"(%arg0, %arg1, %c1_i64) : (tensor<4x6x3xf32>, tensor<1x1x3xf32>, i64) -> tensor<4x6x3xf32>
|
||||
return %0 : tensor<4x6x3xf32>
|
||||
}
|
Loading…
Reference in New Issue