Add aten.maximum op and conversions from aten->tcf.

* Conversions are very simple, suporting mul, maximum and add (alpha=1 only).
* Example added with pass pipeline needed to run.
* Much missing off of the golden path but sufficient for such simple cases.
pull/109/head
Stella Laurenzo 2020-11-04 16:54:52 -08:00
parent 6c702b149f
commit e60dc2470e
13 changed files with 314 additions and 4 deletions

View File

@ -0,0 +1,24 @@
# -*- Python -*-
# This file is licensed under a pytorch-style license
# See frontends/pytorch/LICENSE for license information.
import sys
import torch
import torch_mlir
lhs = torch.ones((4, 6, 1))
rhs = torch.ones((1, 1, 3)) * 0.6
bias = torch.ones((1, 1, 3)) * 0.2
threshold = torch.tensor((0.75, 0.25, 0.10))
mb = torch_mlir.ModuleBuilder()
with mb.capture_function("mul_maximum", [lhs, rhs, threshold, bias]) as f:
result = torch.maximum(lhs * rhs, threshold)
result = result + bias
f.returns([result])
print(f"Result(f{result.size()}) = {result}", file=sys.stderr)
# TODO: Currently need to route through:
# npcomp-opt -aten-recognize-kernels -convert-aten-to-tcf \
# -numpy-public-functions-to-tensor -canonicalize
mb.module.operation.print()

View File

@ -64,6 +64,9 @@ def generate_ops(g: "OpGenerator"):
g.ordinary_binary_op("aten::true_divide(Tensor,Tensor)", "TrueDivideOp",
"true_divide")
g.ordinary_binary_op("aten::maximum(Tensor,Tensor)", "MaximumOp", "maximum")
g.ordinary_binary_op("aten::minimum(Tensor,Tensor)", "MinimumOp", "minimum")
# Unary-ops. These are all the same so just name munge them.
g.print_banner("Unary arithmetic ops")
for uname in [

View File

@ -0,0 +1,21 @@
//===------------------------------------------------------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#ifndef NPCOMP_CONVERSION_ATENTOTCF_PASSES_H
#define NPCOMP_CONVERSION_ATENTOTCF_PASSES_H
#include "mlir/Pass/Pass.h"
#include <memory>
namespace mlir {
namespace NPCOMP {
std::unique_ptr<OperationPass<FuncOp>> createConvertATenToTCFPass();
}
} // namespace mlir
#endif // NPCOMP_CONVERSION_ATENTOTCF_PATTERNS_H

View File

@ -0,0 +1,30 @@
//===------------------------------------------------------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#ifndef NPCOMP_CONVERSION_ATENTOTCF_PATTERNS_H
#define NPCOMP_CONVERSION_ATENTOTCF_PATTERNS_H
#include <memory>
namespace mlir {
class MLIRContext;
class OwningRewritePatternList;
namespace NPCOMP {
/// Populates patterns for converting core ATen ops to TCF. These patterns
/// cover core arithmetic ops that are on the order of 1:1 representationally.
/// More advanced patterns are managed elsewhere.
void populateCoreATenToTCFPatterns(MLIRContext *context,
OwningRewritePatternList &patterns);
} // namespace NPCOMP
} // namespace mlir
#endif // NPCOMP_CONVERSION_ATENTOTCF_PATTERNS_H

View File

@ -12,12 +12,12 @@
include "mlir/Pass/PassBase.td"
//===----------------------------------------------------------------------===//
// TCFToTCP
// ATen conversions
//===----------------------------------------------------------------------===//
def ConvertTCFToTCP : Pass<"convert-tcf-to-tcp", "ModuleOp"> {
let summary = "Convert TCF to TCP";
let constructor = "mlir::NPCOMP::createConvertTCFToTCPPass()";
def ConvertATenToTCF : Pass<"convert-aten-to-tcf", "FuncOp"> {
let summary = "Convert recognized ATen to TCF ops";
let constructor = "mlir::NPCOMP::createConvertATenToTCFPass()";
}
//===----------------------------------------------------------------------===//
@ -38,6 +38,15 @@ def ConvertNumpyToTCF : Pass<"convert-numpy-to-tcf", "FuncOp"> {
let constructor = "mlir::NPCOMP::createConvertNumpyToTCFPass()";
}
//===----------------------------------------------------------------------===//
// TCFToTCP
//===----------------------------------------------------------------------===//
def ConvertTCFToTCP : Pass<"convert-tcf-to-tcp", "ModuleOp"> {
let summary = "Convert TCF to TCP";
let constructor = "mlir::NPCOMP::createConvertTCFToTCPPass()";
}
//===----------------------------------------------------------------------===//
// Conditionally compiled IREE backend passes
//===----------------------------------------------------------------------===//

View File

@ -152,6 +152,44 @@ const Torch::BuildKernelMetadata &TrueDivideOp::getTorchBuildKernelMetadata() {
return metadata;
}
Torch::KernelMetadata MaximumOp::getTorchKernelMetadata() {
return getTorchBuildKernelMetadata();
}
const Torch::BuildKernelMetadata &MaximumOp::getTorchBuildKernelMetadata() {
using KVC = Torch::KernelValueConversion::BitMask;
static Torch::BuildKernelMetadata metadata = ([]() {
Torch::BuildKernelMetadata m;
m.kernelName = "aten::maximum";
m.promoteTrailingOutTensor = true;
m.addArgTypes({"Tensor", "Tensor"});
m.addArgConversions({KVC::kImmutableTensor, KVC::kImmutableTensor|KVC::kPromoteScalar});
m.addReturnTypes({"Tensor"});
m.addReturnConversions({KVC::kImmutableTensor});
return m;
})();
return metadata;
}
Torch::KernelMetadata MinimumOp::getTorchKernelMetadata() {
return getTorchBuildKernelMetadata();
}
const Torch::BuildKernelMetadata &MinimumOp::getTorchBuildKernelMetadata() {
using KVC = Torch::KernelValueConversion::BitMask;
static Torch::BuildKernelMetadata metadata = ([]() {
Torch::BuildKernelMetadata m;
m.kernelName = "aten::minimum";
m.promoteTrailingOutTensor = true;
m.addArgTypes({"Tensor", "Tensor"});
m.addArgConversions({KVC::kImmutableTensor, KVC::kImmutableTensor|KVC::kPromoteScalar});
m.addReturnTypes({"Tensor"});
m.addReturnConversions({KVC::kImmutableTensor});
return m;
})();
return metadata;
}
// -----------------------------------------------------------------------------
// Unary arithmetic ops
// -----------------------------------------------------------------------------

View File

@ -96,6 +96,28 @@ def aten_TrueDivideOp: aten_Op<"true_divide", [NoSideEffect, DeclareOpInterfaceM
);
}
def aten_MaximumOp: aten_Op<"maximum", [NoSideEffect, DeclareOpInterfaceMethods<TorchBuildableKernelOpInterface>, DeclareOpInterfaceMethods<TorchKernelOpInterface>]> {
let summary = "Recognized op for kernel aten::maximum";
let arguments = (ins
AnyTorchImmutableTensor:$self,
AnyTorchImmutableTensor:$other
);
let results = (outs
AnyTorchImmutableTensor
);
}
def aten_MinimumOp: aten_Op<"minimum", [NoSideEffect, DeclareOpInterfaceMethods<TorchBuildableKernelOpInterface>, DeclareOpInterfaceMethods<TorchKernelOpInterface>]> {
let summary = "Recognized op for kernel aten::minimum";
let arguments = (ins
AnyTorchImmutableTensor:$self,
AnyTorchImmutableTensor:$other
);
let results = (outs
AnyTorchImmutableTensor
);
}
// -----------------------------------------------------------------------------
// Unary arithmetic ops
// -----------------------------------------------------------------------------

View File

@ -0,0 +1,17 @@
add_npcomp_conversion_library(NPCOMPATenToTCF
ConvertATenToTCFPass.cpp
CoreOpConversionPatterns.cpp
DEPENDS
NPCOMPConversionPassIncGen
LINK_COMPONENTS
Core
LINK_LIBS PUBLIC
MLIRIR
MLIRPass
MLIRTransforms
NPCOMPATenDialect
NPCOMPTCFDialect
)

View File

@ -0,0 +1,41 @@
//===----------------------------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "npcomp/Conversion/ATenToTCF/Passes.h"
#include "../PassDetail.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "npcomp/Conversion/ATenToTCF/Patterns.h"
#include "npcomp/Dialect/TCF/IR/TCFDialect.h"
using namespace mlir;
using namespace mlir::NPCOMP;
namespace {
class ConvertATenToTCF : public ConvertATenToTCFBase<ConvertATenToTCF> {
public:
void getDependentDialects(DialectRegistry &registry) const override {
registry.insert<tcf::TCFDialect>();
}
void runOnOperation() override {
FuncOp funcOp = getOperation();
MLIRContext *context = &getContext();
OwningRewritePatternList patterns;
populateCoreATenToTCFPatterns(context, patterns);
(void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
}
};
} // namespace
std::unique_ptr<OperationPass<FuncOp>>
mlir::NPCOMP::createConvertATenToTCFPass() {
return std::make_unique<ConvertATenToTCF>();
}

View File

@ -0,0 +1,77 @@
//===----------------------------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "npcomp/Conversion/ATenToTCF/Patterns.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
#include "npcomp/Dialect/ATen/IR/ATenDialect.h"
#include "npcomp/Dialect/TCF/IR/TCFOps.h"
using namespace mlir;
using namespace mlir::NPCOMP;
namespace {
/// The ATen AddOp actually has three arguments:
/// self, other, alpha
/// Alpha is an integer that is multiplied by 'other' prior to adding.
class ConvertATenAdd : public OpRewritePattern<aten::AddOp> {
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(aten::AddOp srcAddOp,
PatternRewriter &rewriter) const override {
// Special case: Match when alpha is constant 1, which is the default,
// quite common and maps directly to a TCF add. Note that regardless of
// the type of self/other (i.e. if they are float), alpha emits as an
// integer with value 1 when defaulted. It is this specific case that we
// are detecting (default value) and will leave all others to the fully
// generic conversion.
APInt alphaValue;
if (matchPattern(srcAddOp.alpha(), m_ConstantInt(&alphaValue)) &&
alphaValue.getZExtValue() == 1) {
rewriter.replaceOpWithNewOp<tcf::AddOp>(
srcAddOp, srcAddOp.getResult().getType(), srcAddOp.self(),
srcAddOp.other());
return success();
}
return rewriter.notifyMatchFailure(
srcAddOp, "aten.add to tcf.add currently only supports alpha == 1");
}
};
/// Common conversion template for true binary elementwise ops.
/// This does not apply to the handful of not-actually-binary PyTorch ops that
/// have broadcastable self/other operands but may have additional parameters.
template <typename SourceOp, typename TargetOp>
class ConvertBinaryElementwise : public OpRewritePattern<SourceOp> {
public:
using OpRewritePattern<SourceOp>::OpRewritePattern;
LogicalResult matchAndRewrite(SourceOp srcOp,
PatternRewriter &rewriter) const override {
auto operands = srcOp.getOperation()->getOperands();
auto results = srcOp.getOperation()->getResults();
assert(operands.size() == 2 && "expected true binary op");
assert(results.size() == 1 && "expected single result op");
Type resultType = results[0].getType();
rewriter.replaceOpWithNewOp<TargetOp>(
srcOp, resultType, srcOp.getOperand(0), srcOp.getOperand(1));
return success();
}
};
} // namespace
void mlir::NPCOMP::populateCoreATenToTCFPatterns(
MLIRContext *context, OwningRewritePatternList &patterns) {
patterns.insert<ConvertATenAdd>(context);
patterns.insert<ConvertBinaryElementwise<aten::MulOp, tcf::MulOp>>(context);
patterns.insert<ConvertBinaryElementwise<aten::MaximumOp, tcf::MaxOp>>(
context);
}

View File

@ -1,3 +1,4 @@
add_subdirectory(ATenToTCF)
add_subdirectory(BasicpyToStd)
add_subdirectory(NumpyToTCF)
add_subdirectory(TCFToTCP)

View File

@ -8,6 +8,7 @@
#include "npcomp/Conversion/Passes.h"
#include "npcomp/Conversion/ATenToTCF/Passes.h"
#include "npcomp/Conversion/BasicpyToStd/Passes.h"
#include "npcomp/Conversion/NumpyToTCF/Passes.h"
#include "npcomp/Conversion/TCFToTCP/TCFToTCP.h"

View File

@ -0,0 +1,26 @@
// RUN: npcomp-opt <%s -convert-aten-to-tcf | FileCheck %s --dump-input=fail
// CHECK-LABEL: @binary_elementwise_ops
// NOTE: These are all template expanded, so just testing an examplar op and
// special cases.
func @binary_elementwise_ops(%arg0: tensor<4x6x1xf32>, %arg1: tensor<1x1x3xf32>) -> tensor<4x6x3xf32> {
// CHECK: tcf.mul %arg0, %arg1 : (tensor<4x6x1xf32>, tensor<1x1x3xf32>) -> tensor<4x6x3xf32>
%0 = "aten.mul"(%arg0, %arg1) : (tensor<4x6x1xf32>, tensor<1x1x3xf32>) -> tensor<4x6x3xf32>
return %0 : tensor<4x6x3xf32>
}
// CHECK-LABEL: @add_alpha_constant1
func @add_alpha_constant1(%arg0: tensor<4x6x3xf32>, %arg1: tensor<1x1x3xf32>) -> tensor<4x6x3xf32> {
%c1_i64 = constant 1 : i64
// CHECK: tcf.add %arg0, %arg1 : (tensor<4x6x3xf32>, tensor<1x1x3xf32>) -> tensor<4x6x3xf32>
%0 = "aten.add"(%arg0, %arg1, %c1_i64) : (tensor<4x6x3xf32>, tensor<1x1x3xf32>, i64) -> tensor<4x6x3xf32>
return %0 : tensor<4x6x3xf32>
}
// CHECK-LABEL: @add_alpha_non_constant1
func @add_alpha_non_constant1(%arg0: tensor<4x6x3xf32>, %arg1: tensor<1x1x3xf32>) -> tensor<4x6x3xf32> {
%c1_i64 = constant 2 : i64
// CHECK: "aten.add"
%0 = "aten.add"(%arg0, %arg1, %c1_i64) : (tensor<4x6x3xf32>, tensor<1x1x3xf32>, i64) -> tensor<4x6x3xf32>
return %0 : tensor<4x6x3xf32>
}