Add torch-adjust-calling-conventions pass.

This pass incorporates torch.type_bound info and also removes NoneType
returns (eventually it will rewrite tuple types too, but can't yet
because !basicpy.TupleType doesn't track element types).

Recommend looking at adjust-calling-conventions.mlir first to see what
it is doing, and holding your nose for the implementation of the pass.
I decided to implement this with the conversion framework, because it
gives us *some* goodies for type conversion -- mainly avoiding large
amounts of tricky RAUW dances. Unfortunately, the conversion framework
isn't a perfect fit for a couple reasons:
- the incorporation of torch.type_bound is a context-sensitive rewrite
  (requires looking at the arg attr, not just the type).
- NoneType conversion is 1->0, which requires some special handling
- (not implemented yet) 1->N tuple type conversions require special
  handling.
It's a little bit scary, but on balance doing it the other way would
have its own downsides.
pull/202/head
Sean Silva 2021-04-01 17:36:18 -07:00
parent 464feacba9
commit 30356c41c8
10 changed files with 375 additions and 0 deletions

View File

@ -15,6 +15,7 @@
#include "mlir/IR/FunctionSupport.h" #include "mlir/IR/FunctionSupport.h"
#include "mlir/IR/OpDefinition.h" #include "mlir/IR/OpDefinition.h"
#include "mlir/IR/SymbolTable.h" #include "mlir/IR/SymbolTable.h"
#include "mlir/Interfaces/CastInterfaces.h"
#include "mlir/Interfaces/SideEffectInterfaces.h" #include "mlir/Interfaces/SideEffectInterfaces.h"
#include "npcomp/Typing/Analysis/CPA/Interfaces.h" #include "npcomp/Typing/Analysis/CPA/Interfaces.h"

View File

@ -12,6 +12,7 @@
include "npcomp/Dialect/Numpy/IR/NumpyDialect.td" include "npcomp/Dialect/Numpy/IR/NumpyDialect.td"
include "npcomp/Typing/Analysis/CPA/Interfaces.td" include "npcomp/Typing/Analysis/CPA/Interfaces.td"
include "mlir/Interfaces/SideEffectInterfaces.td" include "mlir/Interfaces/SideEffectInterfaces.td"
include "mlir/Interfaces/CastInterfaces.td"
include "mlir/IR/SymbolInterfaces.td" include "mlir/IR/SymbolInterfaces.td"
//----------------------------------------------------------------------------// //----------------------------------------------------------------------------//
@ -35,6 +36,25 @@ def Numpy_NarrowOp : Numpy_Op<"narrow", []> {
}]; }];
} }
def Numpy_StaticInfoCastOp : Numpy_Op<"static_info_cast", [
DeclareOpInterfaceMethods<CastOpInterface>,
NoSideEffect]> {
let summary = "Adds/removes static information from an array type.";
let description = [{
This op does not imply any runtime code. Semantically it is an identity
function.
}];
let arguments = (ins
Numpy_AnyArray:$operand
);
let results = (outs
Numpy_AnyArray:$result
);
let assemblyFormat = [{
$operand attr-dict `:` type($operand) `to` type($result)
}];
}
//----------------------------------------------------------------------------// //----------------------------------------------------------------------------//
// NdArray type handling // NdArray type handling
//----------------------------------------------------------------------------// //----------------------------------------------------------------------------//

View File

@ -26,6 +26,8 @@ createPrepareForGlobalizeObjectGraphPass();
/// See the documentation on torch-globalize-object-graph for more details. /// See the documentation on torch-globalize-object-graph for more details.
void createGlobalizePipeline(OpPassManager &pm); void createGlobalizePipeline(OpPassManager &pm);
std::unique_ptr<OperationPass<ModuleOp>> createAdjustCallingConventionsPass();
} // namespace Torch } // namespace Torch
/// Registers all Torch transformation passes. /// Registers all Torch transformation passes.

View File

@ -101,4 +101,28 @@ def PrepareForGlobalizeObjectGraph
}]; }];
} }
def AdjustCallingConventions
: Pass<"torch-adjust-calling-conventions", "ModuleOp"> {
let summary = "Adjust the calling conventions of functions";
let constructor = "mlir::NPCOMP::Torch::createAdjustCallingConventionsPass()";
let description = [{
Adjusts the calling conventions of functions in the module, with the aim of
preparing them for backends and further lowering passes. As this changes
the module calling convention, it should be considered a legalization
step towards reaching IR that is suitable for an appropriate backend.
All transformations are context-free and suitable for documenting
at the user level if needed to clarify the eventual calling convention
of compiled artifacts.
This is not an optimization.
The transformations performed are:
- `torch.type_bound` annotations are incorporated into the type of the
function arguments, which should be `!numpy.ndarray<...>`'s.
- Python-isms are rewritten to MLIR-isms
- NoneType return is rewritten to the absence of a return value.
- (Not implemented yet) Tuple return is rewritten to multiple return
values
}];
}
#endif // NPCOMP_TORCH_PASSES #endif // NPCOMP_TORCH_PASSES

View File

@ -11,6 +11,7 @@
#include "mlir/IR/FunctionImplementation.h" #include "mlir/IR/FunctionImplementation.h"
#include "mlir/IR/OpImplementation.h" #include "mlir/IR/OpImplementation.h"
#include "mlir/IR/PatternMatch.h" #include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"
#include "npcomp/Dialect/Basicpy/IR/BasicpyDialect.h" #include "npcomp/Dialect/Basicpy/IR/BasicpyDialect.h"
#include "npcomp/Dialect/Numpy/IR/NumpyDialect.h" #include "npcomp/Dialect/Numpy/IR/NumpyDialect.h"
@ -58,6 +59,24 @@ void BuiltinUfuncCallOp::addCPAConstraints(Typing::CPA::Context &context) {
} }
} }
//----------------------------------------------------------------------------//
// StaticInfoCast
//----------------------------------------------------------------------------//
bool StaticInfoCastOp::areCastCompatible(mlir::TypeRange inputs,
mlir::TypeRange outputs) {
auto input = inputs[0].cast<NdArrayType>();
auto output = outputs[0].cast<NdArrayType>();
if (input.getOptionalShape() && output.getOptionalShape()) {
if (failed(verifyCompatibleShape(*input.getOptionalShape(),
*output.getOptionalShape())))
return false;
}
return input.getDtype() == output.getDtype() ||
input.getDtype().isa<AnyDtypeType>() ||
output.getDtype().isa<AnyDtypeType>();
}
//----------------------------------------------------------------------------// //----------------------------------------------------------------------------//
// CreateArrayFromTensorOp // CreateArrayFromTensorOp
//----------------------------------------------------------------------------// //----------------------------------------------------------------------------//

View File

@ -0,0 +1,250 @@
//===- AdjustCallingConventions.cpp ------------------------*- C++-*-===//
//
// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "PassDetail.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/IR/BlockAndValueMapping.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/Transforms/DialectConversion.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "npcomp/Dialect/Basicpy/IR/BasicpyDialect.h"
#include "npcomp/Dialect/Basicpy/IR/BasicpyOps.h"
#include "npcomp/Dialect/Numpy/IR/NumpyDialect.h"
#include "npcomp/Dialect/Numpy/IR/NumpyOps.h"
#include "npcomp/Dialect/Torch/IR/TorchDialect.h"
#include "npcomp/Dialect/Torch/IR/TorchOps.h"
#include "npcomp/Dialect/Torch/Transforms/Passes.h"
using namespace mlir;
using namespace mlir::NPCOMP;
using namespace mlir::NPCOMP::Torch;
// Map from func name and arg index to the type bound for that arg.
// This is needed because to rewrite calls, we need the non-local information
// from the func definition.
// We also benefit from populating this all at once, which avoids ordering
// issues between rewriting of func ops vs call ops.
using TypeBoundMap = DenseMap<std::pair<StringRef, int>, Type> ;
namespace {
class AdjustCallingConventionForFunc : public OpConversionPattern<FuncOp> {
public:
using OpConversionPattern::OpConversionPattern;
LogicalResult
matchAndRewrite(FuncOp func, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
MLIRContext *context = func.getContext();
auto typeBoundIdent = Identifier::get("torch.type_bound", context);
TypeConverter::SignatureConversion conversion(func.getNumArguments());
// The TypeConverter hooks for type conversion are "context free", so we
// cannot use the usual helpers here for populating SignatureConversion and
// new result types.
//
// The incoporation of the torch.type_bound arg attr is context-dependent.
for (auto type : llvm::enumerate(func.getArgumentTypes())) {
if (auto ndarray = type.value().dyn_cast<Numpy::NdArrayType>()) {
auto typeBoundAttr =
func.getArgAttrOfType<TypeAttr>(type.index(), typeBoundIdent);
conversion.addInputs(type.index(), typeBoundAttr
? typeBoundAttr.getValue()
: type.value());
continue;
// type is attached to ndarray type.
// TODO: check if more specific?
} else if (auto none = type.value().dyn_cast<Basicpy::NoneType>()) {
continue;
}
// TODO: add tuple type.
conversion.addInputs(type.index(), type.value());
}
SmallVector<Type> newResultTypes;
for (auto type : func.getType().getResults()) {
if (auto none = type.dyn_cast<Basicpy::NoneType>()) {
continue;
}
newResultTypes.push_back(type);
}
rewriter.applySignatureConversion(&func.getBody(), conversion,
typeConverter);
rewriter.updateRootInPlace(func, [&] {
func.setType(FunctionType::get(
getContext(), conversion.getConvertedTypes(), newResultTypes));
// Clear out the type bounds, now that the type incorporates them.
for (int i = 0, e = func.getNumArguments(); i != e; i++)
func.removeArgAttr(i, typeBoundIdent);
});
return success();
}
};
} // namespace
namespace {
class AdjustCallingConventionForCall : public OpConversionPattern<CallOp> {
public:
AdjustCallingConventionForCall(TypeConverter &converter, MLIRContext *context,
TypeBoundMap &typeBoundMap)
: OpConversionPattern<CallOp>(converter, context),
typeBoundMap(typeBoundMap) {}
LogicalResult
matchAndRewrite(CallOp call, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
SmallVector<Type> convertedResults;
if (failed(typeConverter->convertTypes(call.getResultTypes(),
convertedResults)))
return failure();
SmallVector<Value> newOperands;
for (auto operand : llvm::enumerate(operands)) {
if (operand.value().getType().isa<Basicpy::NoneType>())
continue;
auto it = typeBoundMap.find({call.callee(), operand.index()});
if (it != typeBoundMap.end()) {
newOperands.push_back(rewriter.create<Numpy::StaticInfoCastOp>(
call.getLoc(), it->second, operand.value()));
continue;
}
newOperands.push_back(operand.value());
}
CallOp newCall = rewriter.create<CallOp>(call.getLoc(), call.callee(),
convertedResults, newOperands);
int newOpResultIdx = 0;
SmallVector<Value> newResults;
for (auto type : call.getResultTypes()) {
if (type.isa<Basicpy::NoneType>()) {
newResults.push_back(
rewriter.create<Basicpy::SingletonOp>(call.getLoc(), type));
continue;
}
newResults.push_back(newCall.getResult(newOpResultIdx++));
}
rewriter.replaceOp(call, newResults);
return success();
}
private:
TypeBoundMap &typeBoundMap;
};
} // namespace
namespace {
class AdjustCallingConventionForReturn : public OpConversionPattern<ReturnOp> {
public:
using OpConversionPattern::OpConversionPattern;
LogicalResult
matchAndRewrite(ReturnOp op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
SmallVector<Value> newOperands;
for (auto operand : llvm::enumerate(operands)) {
if (!operand.value())
continue;
if (operand.value().getType().isa<Basicpy::NoneType>())
continue;
newOperands.push_back(operand.value());
}
rewriter.replaceOpWithNewOp<ReturnOp>(op, newOperands);
return success();
}
};
} // namespace
static LogicalResult adjustCallingConventions(FuncOp func,
TypeBoundMap &typeBoundMap) {
MLIRContext *context = func.getContext();
RewritePatternSet patterns(context);
// TODO: TupleTypes
TypeConverter typeConverter;
typeConverter.addConversion([](Type type) { return type; });
typeConverter.addConversion(
[](Basicpy::NoneType type,
SmallVectorImpl<Type> &types) -> Optional<LogicalResult> {
return success();
});
typeConverter.addArgumentMaterialization(
[](OpBuilder &builder, Numpy::NdArrayType type, ValueRange inputs,
Location loc) -> Value {
assert(inputs.size() == 1);
assert(inputs[0].getType().isa<Numpy::NdArrayType>());
return builder.create<Numpy::StaticInfoCastOp>(loc, type, inputs[0]);
});
patterns.add<AdjustCallingConventionForFunc>(typeConverter, context);
patterns.add<AdjustCallingConventionForCall>(typeConverter, context,
typeBoundMap);
patterns.add<AdjustCallingConventionForReturn>(typeConverter, context);
ConversionTarget target(*context);
target.addDynamicallyLegalOp<FuncOp>([](FuncOp func) {
for (int i = 0, e = func.getNumArguments(); i != e; i++) {
if (func.getArgAttr(i, "torch.type_bound"))
return false;
if (func.getArgumentTypes()[i].isa<Basicpy::NoneType>())
return false;
}
for (int i = 0, e = func.getNumResults(); i != e; i++) {
if (func.getType().getResults()[i].isa<Basicpy::NoneType>())
return false;
}
return true;
});
// The dynamic legality conditions for call and return are a pain to write...
// Just run the patterns once and call it a day.
//
// Bug for doing this better https://bugs.llvm.org/show_bug.cgi?id=49812
DenseSet<Operation *> opsInOriginalProgram;
func.walk([&](CallOp op) { opsInOriginalProgram.insert(op.getOperation()); });
func.walk(
[&](ReturnOp op) { opsInOriginalProgram.insert(op.getOperation()); });
target.addDynamicallyLegalOp<CallOp>([&](CallOp op) {
return !opsInOriginalProgram.contains(op.getOperation());
});
target.addDynamicallyLegalOp<ReturnOp>([&](ReturnOp op) {
return !opsInOriginalProgram.contains(op.getOperation());
});
target.addLegalOp<Numpy::StaticInfoCastOp>();
target.addLegalOp<Basicpy::SingletonOp>();
// We don't know how to rewrite it, so mark it as illegal.
target.addIllegalOp<CallIndirectOp>();
if (failed(applyPartialConversion(func.getOperation(), target,
std::move(patterns))))
return failure();
return success();
}
namespace {
class AdjustCallingConventionsPass
: public AdjustCallingConventionsBase<AdjustCallingConventionsPass> {
void runOnOperation() override {
auto module = getOperation();
TypeBoundMap typeBoundMap;
for (auto func : module.getOps<FuncOp>()) {
for (int i = 0, e = func.getNumArguments(); i != e; i++) {
auto typeBoundAttr =
func.getArgAttrOfType<TypeAttr>(i, "torch.type_bound");
if (!typeBoundAttr)
continue;
typeBoundMap[{func.getName(), i}] = typeBoundAttr.getValue();
}
}
for (auto func : module.getOps<FuncOp>()) {
if (failed(adjustCallingConventions(func, typeBoundMap)))
return signalPassFailure();
}
}
};
} // namespace
std::unique_ptr<OperationPass<ModuleOp>>
mlir::NPCOMP::Torch::createAdjustCallingConventionsPass() {
return std::make_unique<AdjustCallingConventionsPass>();
}

View File

@ -1,4 +1,5 @@
add_npcomp_conversion_library(NPCOMPTorchPasses add_npcomp_conversion_library(NPCOMPTorchPasses
AdjustCallingConventions.cpp
Passes.cpp Passes.cpp
GlobalizeObjectGraph.cpp GlobalizeObjectGraph.cpp
PrepareForGlobalizeObjectGraph.cpp PrepareForGlobalizeObjectGraph.cpp

View File

@ -27,6 +27,7 @@ OBJECT_GRAPH_LOWERING_PASSES = (
# bothersome because we don't currently have a lowering for them. # bothersome because we don't currently have a lowering for them.
# TODO: Support global slots in backends. # TODO: Support global slots in backends.
"symbol-dce", "symbol-dce",
"torch-adjust-calling-conventions",
) )
TORCH_TO_TCF_PASSES = ( TORCH_TO_TCF_PASSES = (

View File

@ -17,3 +17,14 @@ func @ndarray_tensor_bridging(%arg0: !numpy.ndarray<[2,3]:f32>, %arg1: !numpy.nd
numpy.overwrite_array %arg2 overwrites %arg0 : tensor<2x3xf32>, !numpy.ndarray<[2,3]:f32> numpy.overwrite_array %arg2 overwrites %arg0 : tensor<2x3xf32>, !numpy.ndarray<[2,3]:f32>
return return
} }
// CHECK-LABEL: @static_info_cast
func @static_info_cast(%arg0: !numpy.ndarray<[2,3]:f32>, %arg1: !numpy.ndarray<[?,3]:f32>, %arg2: !numpy.ndarray<*:f32>) {
// CHECK-NEXT: numpy.static_info_cast %arg0 : !numpy.ndarray<[2,3]:f32> to !numpy.ndarray<*:!numpy.any_dtype>
%0 = numpy.static_info_cast %arg0 : !numpy.ndarray<[2,3]:f32> to !numpy.ndarray<*:!numpy.any_dtype>
// CHECK-NEXT: numpy.static_info_cast %arg1 : !numpy.ndarray<[?,3]:f32> to !numpy.ndarray<[7,3]:f32>
%1 = numpy.static_info_cast %arg1 : !numpy.ndarray<[?,3]:f32> to !numpy.ndarray<[7,3]:f32>
// CHECK-NEXT: numpy.static_info_cast %arg2 : !numpy.ndarray<*:f32> to !numpy.ndarray<[?,?]:f32>
%2 = numpy.static_info_cast %arg2 : !numpy.ndarray<*:f32> to !numpy.ndarray<[?,?]:f32>
return
}

View File

@ -0,0 +1,46 @@
// RUN: npcomp-opt -torch-adjust-calling-conventions -allow-unregistered-dialect -split-input-file %s | FileCheck %s
// CHECK-LABEL: func @basic(
// CHECK-SAME: %[[ARG:.*]]: !numpy.ndarray<[2,3,?]:f32>) -> !numpy.ndarray<*:!numpy.any_dtype> {
// CHECK: %[[RET:.*]] = numpy.static_info_cast %[[ARG]] : !numpy.ndarray<[2,3,?]:f32> to !numpy.ndarray<*:!numpy.any_dtype>
// CHECK: return %[[RET]] : !numpy.ndarray<*:!numpy.any_dtype>
func @basic(%arg0: !numpy.ndarray<*:!numpy.any_dtype> {torch.type_bound = !numpy.ndarray<[2,3,?]:f32>}) -> !numpy.ndarray<*:!numpy.any_dtype> {
return %arg0 : !numpy.ndarray<*:!numpy.any_dtype>
}
// CHECK-LABEL: func @no_type_bound(
// CHECK-SAME: %[[ARG:.*]]: !numpy.ndarray<*:!numpy.any_dtype>) -> !numpy.ndarray<*:!numpy.any_dtype> {
// CHECK: return %[[ARG]] : !numpy.ndarray<*:!numpy.any_dtype>
func @no_type_bound(%arg0: !numpy.ndarray<*:!numpy.any_dtype>) -> !numpy.ndarray<*:!numpy.any_dtype> {
return %arg0 : !numpy.ndarray<*:!numpy.any_dtype>
}
// CHECK-LABEL: func @call(
// CHECK-SAME: %[[ARG:.*]]: !numpy.ndarray<[2,3,?]:f32>) -> !numpy.ndarray<*:!numpy.any_dtype> {
// CHECK: %[[SHAPE_ERASED:.*]] = numpy.static_info_cast %[[ARG]] : !numpy.ndarray<[2,3,?]:f32> to !numpy.ndarray<*:!numpy.any_dtype>
// CHECK: %[[SHAPED:.*]] = numpy.static_info_cast %[[SHAPE_ERASED]] : !numpy.ndarray<*:!numpy.any_dtype> to !numpy.ndarray<[2,3,?]:f32>
// CHECK: %[[RES:.*]] = call @call(%[[SHAPED]]) : (!numpy.ndarray<[2,3,?]:f32>) -> !numpy.ndarray<*:!numpy.any_dtype>
// CHECK: return %[[SHAPE_ERASED]] : !numpy.ndarray<*:!numpy.any_dtype>
func @call(%arg0: !numpy.ndarray<*:!numpy.any_dtype> {torch.type_bound = !numpy.ndarray<[2,3,?]:f32>}) -> !numpy.ndarray<*:!numpy.any_dtype> {
%0 = call @call(%arg0) : (!numpy.ndarray<*:!numpy.any_dtype>) -> !numpy.ndarray<*:!numpy.any_dtype>
return %arg0 : !numpy.ndarray<*:!numpy.any_dtype>
}
// CHECK-LABEL: func @none_return() {
// CHECK: %[[NONE:.*]] = basicpy.singleton : !basicpy.NoneType
// CHECK: return
func @none_return() -> !basicpy.NoneType {
%1 = basicpy.singleton : !basicpy.NoneType
return %1 : !basicpy.NoneType
}
// CHECK-LABEL: func @none_call_return() {
// CHECK: call @none_return() : () -> ()
// CHECK: %[[NONE:.*]] = basicpy.singleton : !basicpy.NoneType
// CHECK: "test.use"(%[[NONE]]) : (!basicpy.NoneType) -> ()
// CHECK: return
func @none_call_return() {
%0 = call @none_return() : () -> !basicpy.NoneType
"test.use"(%0) : (!basicpy.NoneType) -> ()
return
}