mirror of https://github.com/llvm/torch-mlir
Add aten.maximum op and conversions from aten->tcf.
* Conversions are very simple, suporting mul, maximum and add (alpha=1 only). * Example added with pass pipeline needed to run. * Much missing off of the golden path but sufficient for such simple cases.pull/109/head
parent
6c702b149f
commit
e60dc2470e
|
@ -0,0 +1,24 @@
|
||||||
|
# -*- Python -*-
|
||||||
|
# This file is licensed under a pytorch-style license
|
||||||
|
# See frontends/pytorch/LICENSE for license information.
|
||||||
|
|
||||||
|
import sys
|
||||||
|
import torch
|
||||||
|
import torch_mlir
|
||||||
|
|
||||||
|
lhs = torch.ones((4, 6, 1))
|
||||||
|
rhs = torch.ones((1, 1, 3)) * 0.6
|
||||||
|
bias = torch.ones((1, 1, 3)) * 0.2
|
||||||
|
threshold = torch.tensor((0.75, 0.25, 0.10))
|
||||||
|
|
||||||
|
mb = torch_mlir.ModuleBuilder()
|
||||||
|
with mb.capture_function("mul_maximum", [lhs, rhs, threshold, bias]) as f:
|
||||||
|
result = torch.maximum(lhs * rhs, threshold)
|
||||||
|
result = result + bias
|
||||||
|
f.returns([result])
|
||||||
|
|
||||||
|
print(f"Result(f{result.size()}) = {result}", file=sys.stderr)
|
||||||
|
# TODO: Currently need to route through:
|
||||||
|
# npcomp-opt -aten-recognize-kernels -convert-aten-to-tcf \
|
||||||
|
# -numpy-public-functions-to-tensor -canonicalize
|
||||||
|
mb.module.operation.print()
|
|
@ -64,6 +64,9 @@ def generate_ops(g: "OpGenerator"):
|
||||||
g.ordinary_binary_op("aten::true_divide(Tensor,Tensor)", "TrueDivideOp",
|
g.ordinary_binary_op("aten::true_divide(Tensor,Tensor)", "TrueDivideOp",
|
||||||
"true_divide")
|
"true_divide")
|
||||||
|
|
||||||
|
g.ordinary_binary_op("aten::maximum(Tensor,Tensor)", "MaximumOp", "maximum")
|
||||||
|
g.ordinary_binary_op("aten::minimum(Tensor,Tensor)", "MinimumOp", "minimum")
|
||||||
|
|
||||||
# Unary-ops. These are all the same so just name munge them.
|
# Unary-ops. These are all the same so just name munge them.
|
||||||
g.print_banner("Unary arithmetic ops")
|
g.print_banner("Unary arithmetic ops")
|
||||||
for uname in [
|
for uname in [
|
||||||
|
|
|
@ -0,0 +1,21 @@
|
||||||
|
//===------------------------------------------------------------*- C++ -*-===//
|
||||||
|
//
|
||||||
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||||
|
// See https://llvm.org/LICENSE.txt for license information.
|
||||||
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||||
|
//
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
#ifndef NPCOMP_CONVERSION_ATENTOTCF_PASSES_H
|
||||||
|
#define NPCOMP_CONVERSION_ATENTOTCF_PASSES_H
|
||||||
|
|
||||||
|
#include "mlir/Pass/Pass.h"
|
||||||
|
#include <memory>
|
||||||
|
|
||||||
|
namespace mlir {
|
||||||
|
namespace NPCOMP {
|
||||||
|
std::unique_ptr<OperationPass<FuncOp>> createConvertATenToTCFPass();
|
||||||
|
}
|
||||||
|
} // namespace mlir
|
||||||
|
|
||||||
|
#endif // NPCOMP_CONVERSION_ATENTOTCF_PATTERNS_H
|
|
@ -0,0 +1,30 @@
|
||||||
|
//===------------------------------------------------------------*- C++ -*-===//
|
||||||
|
//
|
||||||
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||||
|
// See https://llvm.org/LICENSE.txt for license information.
|
||||||
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||||
|
//
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
#ifndef NPCOMP_CONVERSION_ATENTOTCF_PATTERNS_H
|
||||||
|
#define NPCOMP_CONVERSION_ATENTOTCF_PATTERNS_H
|
||||||
|
|
||||||
|
#include <memory>
|
||||||
|
|
||||||
|
namespace mlir {
|
||||||
|
|
||||||
|
class MLIRContext;
|
||||||
|
class OwningRewritePatternList;
|
||||||
|
|
||||||
|
namespace NPCOMP {
|
||||||
|
|
||||||
|
/// Populates patterns for converting core ATen ops to TCF. These patterns
|
||||||
|
/// cover core arithmetic ops that are on the order of 1:1 representationally.
|
||||||
|
/// More advanced patterns are managed elsewhere.
|
||||||
|
void populateCoreATenToTCFPatterns(MLIRContext *context,
|
||||||
|
OwningRewritePatternList &patterns);
|
||||||
|
|
||||||
|
} // namespace NPCOMP
|
||||||
|
} // namespace mlir
|
||||||
|
|
||||||
|
#endif // NPCOMP_CONVERSION_ATENTOTCF_PATTERNS_H
|
|
@ -12,12 +12,12 @@
|
||||||
include "mlir/Pass/PassBase.td"
|
include "mlir/Pass/PassBase.td"
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// TCFToTCP
|
// ATen conversions
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
def ConvertTCFToTCP : Pass<"convert-tcf-to-tcp", "ModuleOp"> {
|
def ConvertATenToTCF : Pass<"convert-aten-to-tcf", "FuncOp"> {
|
||||||
let summary = "Convert TCF to TCP";
|
let summary = "Convert recognized ATen to TCF ops";
|
||||||
let constructor = "mlir::NPCOMP::createConvertTCFToTCPPass()";
|
let constructor = "mlir::NPCOMP::createConvertATenToTCFPass()";
|
||||||
}
|
}
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
@ -38,6 +38,15 @@ def ConvertNumpyToTCF : Pass<"convert-numpy-to-tcf", "FuncOp"> {
|
||||||
let constructor = "mlir::NPCOMP::createConvertNumpyToTCFPass()";
|
let constructor = "mlir::NPCOMP::createConvertNumpyToTCFPass()";
|
||||||
}
|
}
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
// TCFToTCP
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
def ConvertTCFToTCP : Pass<"convert-tcf-to-tcp", "ModuleOp"> {
|
||||||
|
let summary = "Convert TCF to TCP";
|
||||||
|
let constructor = "mlir::NPCOMP::createConvertTCFToTCPPass()";
|
||||||
|
}
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// Conditionally compiled IREE backend passes
|
// Conditionally compiled IREE backend passes
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
|
@ -152,6 +152,44 @@ const Torch::BuildKernelMetadata &TrueDivideOp::getTorchBuildKernelMetadata() {
|
||||||
return metadata;
|
return metadata;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Torch::KernelMetadata MaximumOp::getTorchKernelMetadata() {
|
||||||
|
return getTorchBuildKernelMetadata();
|
||||||
|
}
|
||||||
|
|
||||||
|
const Torch::BuildKernelMetadata &MaximumOp::getTorchBuildKernelMetadata() {
|
||||||
|
using KVC = Torch::KernelValueConversion::BitMask;
|
||||||
|
static Torch::BuildKernelMetadata metadata = ([]() {
|
||||||
|
Torch::BuildKernelMetadata m;
|
||||||
|
m.kernelName = "aten::maximum";
|
||||||
|
m.promoteTrailingOutTensor = true;
|
||||||
|
m.addArgTypes({"Tensor", "Tensor"});
|
||||||
|
m.addArgConversions({KVC::kImmutableTensor, KVC::kImmutableTensor|KVC::kPromoteScalar});
|
||||||
|
m.addReturnTypes({"Tensor"});
|
||||||
|
m.addReturnConversions({KVC::kImmutableTensor});
|
||||||
|
return m;
|
||||||
|
})();
|
||||||
|
return metadata;
|
||||||
|
}
|
||||||
|
|
||||||
|
Torch::KernelMetadata MinimumOp::getTorchKernelMetadata() {
|
||||||
|
return getTorchBuildKernelMetadata();
|
||||||
|
}
|
||||||
|
|
||||||
|
const Torch::BuildKernelMetadata &MinimumOp::getTorchBuildKernelMetadata() {
|
||||||
|
using KVC = Torch::KernelValueConversion::BitMask;
|
||||||
|
static Torch::BuildKernelMetadata metadata = ([]() {
|
||||||
|
Torch::BuildKernelMetadata m;
|
||||||
|
m.kernelName = "aten::minimum";
|
||||||
|
m.promoteTrailingOutTensor = true;
|
||||||
|
m.addArgTypes({"Tensor", "Tensor"});
|
||||||
|
m.addArgConversions({KVC::kImmutableTensor, KVC::kImmutableTensor|KVC::kPromoteScalar});
|
||||||
|
m.addReturnTypes({"Tensor"});
|
||||||
|
m.addReturnConversions({KVC::kImmutableTensor});
|
||||||
|
return m;
|
||||||
|
})();
|
||||||
|
return metadata;
|
||||||
|
}
|
||||||
|
|
||||||
// -----------------------------------------------------------------------------
|
// -----------------------------------------------------------------------------
|
||||||
// Unary arithmetic ops
|
// Unary arithmetic ops
|
||||||
// -----------------------------------------------------------------------------
|
// -----------------------------------------------------------------------------
|
||||||
|
|
|
@ -96,6 +96,28 @@ def aten_TrueDivideOp: aten_Op<"true_divide", [NoSideEffect, DeclareOpInterfaceM
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
def aten_MaximumOp: aten_Op<"maximum", [NoSideEffect, DeclareOpInterfaceMethods<TorchBuildableKernelOpInterface>, DeclareOpInterfaceMethods<TorchKernelOpInterface>]> {
|
||||||
|
let summary = "Recognized op for kernel aten::maximum";
|
||||||
|
let arguments = (ins
|
||||||
|
AnyTorchImmutableTensor:$self,
|
||||||
|
AnyTorchImmutableTensor:$other
|
||||||
|
);
|
||||||
|
let results = (outs
|
||||||
|
AnyTorchImmutableTensor
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
def aten_MinimumOp: aten_Op<"minimum", [NoSideEffect, DeclareOpInterfaceMethods<TorchBuildableKernelOpInterface>, DeclareOpInterfaceMethods<TorchKernelOpInterface>]> {
|
||||||
|
let summary = "Recognized op for kernel aten::minimum";
|
||||||
|
let arguments = (ins
|
||||||
|
AnyTorchImmutableTensor:$self,
|
||||||
|
AnyTorchImmutableTensor:$other
|
||||||
|
);
|
||||||
|
let results = (outs
|
||||||
|
AnyTorchImmutableTensor
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
// -----------------------------------------------------------------------------
|
// -----------------------------------------------------------------------------
|
||||||
// Unary arithmetic ops
|
// Unary arithmetic ops
|
||||||
// -----------------------------------------------------------------------------
|
// -----------------------------------------------------------------------------
|
||||||
|
|
|
@ -0,0 +1,17 @@
|
||||||
|
add_npcomp_conversion_library(NPCOMPATenToTCF
|
||||||
|
ConvertATenToTCFPass.cpp
|
||||||
|
CoreOpConversionPatterns.cpp
|
||||||
|
|
||||||
|
DEPENDS
|
||||||
|
NPCOMPConversionPassIncGen
|
||||||
|
|
||||||
|
LINK_COMPONENTS
|
||||||
|
Core
|
||||||
|
|
||||||
|
LINK_LIBS PUBLIC
|
||||||
|
MLIRIR
|
||||||
|
MLIRPass
|
||||||
|
MLIRTransforms
|
||||||
|
NPCOMPATenDialect
|
||||||
|
NPCOMPTCFDialect
|
||||||
|
)
|
|
@ -0,0 +1,41 @@
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
//
|
||||||
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||||
|
// See https://llvm.org/LICENSE.txt for license information.
|
||||||
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||||
|
//
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
#include "npcomp/Conversion/ATenToTCF/Passes.h"
|
||||||
|
|
||||||
|
#include "../PassDetail.h"
|
||||||
|
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
|
||||||
|
#include "npcomp/Conversion/ATenToTCF/Patterns.h"
|
||||||
|
#include "npcomp/Dialect/TCF/IR/TCFDialect.h"
|
||||||
|
|
||||||
|
using namespace mlir;
|
||||||
|
using namespace mlir::NPCOMP;
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
|
||||||
|
class ConvertATenToTCF : public ConvertATenToTCFBase<ConvertATenToTCF> {
|
||||||
|
public:
|
||||||
|
void getDependentDialects(DialectRegistry ®istry) const override {
|
||||||
|
registry.insert<tcf::TCFDialect>();
|
||||||
|
}
|
||||||
|
|
||||||
|
void runOnOperation() override {
|
||||||
|
FuncOp funcOp = getOperation();
|
||||||
|
MLIRContext *context = &getContext();
|
||||||
|
OwningRewritePatternList patterns;
|
||||||
|
populateCoreATenToTCFPatterns(context, patterns);
|
||||||
|
(void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
std::unique_ptr<OperationPass<FuncOp>>
|
||||||
|
mlir::NPCOMP::createConvertATenToTCFPass() {
|
||||||
|
return std::make_unique<ConvertATenToTCF>();
|
||||||
|
}
|
|
@ -0,0 +1,77 @@
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
//
|
||||||
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||||
|
// See https://llvm.org/LICENSE.txt for license information.
|
||||||
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||||
|
//
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
#include "npcomp/Conversion/ATenToTCF/Patterns.h"
|
||||||
|
|
||||||
|
#include "mlir/IR/MLIRContext.h"
|
||||||
|
#include "mlir/IR/Matchers.h"
|
||||||
|
#include "mlir/IR/PatternMatch.h"
|
||||||
|
#include "npcomp/Dialect/ATen/IR/ATenDialect.h"
|
||||||
|
#include "npcomp/Dialect/TCF/IR/TCFOps.h"
|
||||||
|
|
||||||
|
using namespace mlir;
|
||||||
|
using namespace mlir::NPCOMP;
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
|
||||||
|
/// The ATen AddOp actually has three arguments:
|
||||||
|
/// self, other, alpha
|
||||||
|
/// Alpha is an integer that is multiplied by 'other' prior to adding.
|
||||||
|
class ConvertATenAdd : public OpRewritePattern<aten::AddOp> {
|
||||||
|
using OpRewritePattern::OpRewritePattern;
|
||||||
|
LogicalResult matchAndRewrite(aten::AddOp srcAddOp,
|
||||||
|
PatternRewriter &rewriter) const override {
|
||||||
|
// Special case: Match when alpha is constant 1, which is the default,
|
||||||
|
// quite common and maps directly to a TCF add. Note that regardless of
|
||||||
|
// the type of self/other (i.e. if they are float), alpha emits as an
|
||||||
|
// integer with value 1 when defaulted. It is this specific case that we
|
||||||
|
// are detecting (default value) and will leave all others to the fully
|
||||||
|
// generic conversion.
|
||||||
|
APInt alphaValue;
|
||||||
|
if (matchPattern(srcAddOp.alpha(), m_ConstantInt(&alphaValue)) &&
|
||||||
|
alphaValue.getZExtValue() == 1) {
|
||||||
|
rewriter.replaceOpWithNewOp<tcf::AddOp>(
|
||||||
|
srcAddOp, srcAddOp.getResult().getType(), srcAddOp.self(),
|
||||||
|
srcAddOp.other());
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
|
||||||
|
return rewriter.notifyMatchFailure(
|
||||||
|
srcAddOp, "aten.add to tcf.add currently only supports alpha == 1");
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
/// Common conversion template for true binary elementwise ops.
|
||||||
|
/// This does not apply to the handful of not-actually-binary PyTorch ops that
|
||||||
|
/// have broadcastable self/other operands but may have additional parameters.
|
||||||
|
template <typename SourceOp, typename TargetOp>
|
||||||
|
class ConvertBinaryElementwise : public OpRewritePattern<SourceOp> {
|
||||||
|
public:
|
||||||
|
using OpRewritePattern<SourceOp>::OpRewritePattern;
|
||||||
|
LogicalResult matchAndRewrite(SourceOp srcOp,
|
||||||
|
PatternRewriter &rewriter) const override {
|
||||||
|
auto operands = srcOp.getOperation()->getOperands();
|
||||||
|
auto results = srcOp.getOperation()->getResults();
|
||||||
|
assert(operands.size() == 2 && "expected true binary op");
|
||||||
|
assert(results.size() == 1 && "expected single result op");
|
||||||
|
Type resultType = results[0].getType();
|
||||||
|
rewriter.replaceOpWithNewOp<TargetOp>(
|
||||||
|
srcOp, resultType, srcOp.getOperand(0), srcOp.getOperand(1));
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
void mlir::NPCOMP::populateCoreATenToTCFPatterns(
|
||||||
|
MLIRContext *context, OwningRewritePatternList &patterns) {
|
||||||
|
patterns.insert<ConvertATenAdd>(context);
|
||||||
|
patterns.insert<ConvertBinaryElementwise<aten::MulOp, tcf::MulOp>>(context);
|
||||||
|
patterns.insert<ConvertBinaryElementwise<aten::MaximumOp, tcf::MaxOp>>(
|
||||||
|
context);
|
||||||
|
}
|
|
@ -1,3 +1,4 @@
|
||||||
|
add_subdirectory(ATenToTCF)
|
||||||
add_subdirectory(BasicpyToStd)
|
add_subdirectory(BasicpyToStd)
|
||||||
add_subdirectory(NumpyToTCF)
|
add_subdirectory(NumpyToTCF)
|
||||||
add_subdirectory(TCFToTCP)
|
add_subdirectory(TCFToTCP)
|
||||||
|
|
|
@ -8,6 +8,7 @@
|
||||||
|
|
||||||
#include "npcomp/Conversion/Passes.h"
|
#include "npcomp/Conversion/Passes.h"
|
||||||
|
|
||||||
|
#include "npcomp/Conversion/ATenToTCF/Passes.h"
|
||||||
#include "npcomp/Conversion/BasicpyToStd/Passes.h"
|
#include "npcomp/Conversion/BasicpyToStd/Passes.h"
|
||||||
#include "npcomp/Conversion/NumpyToTCF/Passes.h"
|
#include "npcomp/Conversion/NumpyToTCF/Passes.h"
|
||||||
#include "npcomp/Conversion/TCFToTCP/TCFToTCP.h"
|
#include "npcomp/Conversion/TCFToTCP/TCFToTCP.h"
|
||||||
|
|
|
@ -0,0 +1,26 @@
|
||||||
|
// RUN: npcomp-opt <%s -convert-aten-to-tcf | FileCheck %s --dump-input=fail
|
||||||
|
|
||||||
|
// CHECK-LABEL: @binary_elementwise_ops
|
||||||
|
// NOTE: These are all template expanded, so just testing an examplar op and
|
||||||
|
// special cases.
|
||||||
|
func @binary_elementwise_ops(%arg0: tensor<4x6x1xf32>, %arg1: tensor<1x1x3xf32>) -> tensor<4x6x3xf32> {
|
||||||
|
// CHECK: tcf.mul %arg0, %arg1 : (tensor<4x6x1xf32>, tensor<1x1x3xf32>) -> tensor<4x6x3xf32>
|
||||||
|
%0 = "aten.mul"(%arg0, %arg1) : (tensor<4x6x1xf32>, tensor<1x1x3xf32>) -> tensor<4x6x3xf32>
|
||||||
|
return %0 : tensor<4x6x3xf32>
|
||||||
|
}
|
||||||
|
|
||||||
|
// CHECK-LABEL: @add_alpha_constant1
|
||||||
|
func @add_alpha_constant1(%arg0: tensor<4x6x3xf32>, %arg1: tensor<1x1x3xf32>) -> tensor<4x6x3xf32> {
|
||||||
|
%c1_i64 = constant 1 : i64
|
||||||
|
// CHECK: tcf.add %arg0, %arg1 : (tensor<4x6x3xf32>, tensor<1x1x3xf32>) -> tensor<4x6x3xf32>
|
||||||
|
%0 = "aten.add"(%arg0, %arg1, %c1_i64) : (tensor<4x6x3xf32>, tensor<1x1x3xf32>, i64) -> tensor<4x6x3xf32>
|
||||||
|
return %0 : tensor<4x6x3xf32>
|
||||||
|
}
|
||||||
|
|
||||||
|
// CHECK-LABEL: @add_alpha_non_constant1
|
||||||
|
func @add_alpha_non_constant1(%arg0: tensor<4x6x3xf32>, %arg1: tensor<1x1x3xf32>) -> tensor<4x6x3xf32> {
|
||||||
|
%c1_i64 = constant 2 : i64
|
||||||
|
// CHECK: "aten.add"
|
||||||
|
%0 = "aten.add"(%arg0, %arg1, %c1_i64) : (tensor<4x6x3xf32>, tensor<1x1x3xf32>, i64) -> tensor<4x6x3xf32>
|
||||||
|
return %0 : tensor<4x6x3xf32>
|
||||||
|
}
|
Loading…
Reference in New Issue