diff --git a/include/npcomp/Dialect/Numpy/IR/NumpyOps.h b/include/npcomp/Dialect/Numpy/IR/NumpyOps.h
index 3700274c5..701841143 100644
--- a/include/npcomp/Dialect/Numpy/IR/NumpyOps.h
+++ b/include/npcomp/Dialect/Numpy/IR/NumpyOps.h
@@ -15,6 +15,7 @@
 #include "mlir/IR/FunctionSupport.h"
 #include "mlir/IR/OpDefinition.h"
 #include "mlir/IR/SymbolTable.h"
+#include "mlir/Interfaces/CastInterfaces.h"
 #include "mlir/Interfaces/SideEffectInterfaces.h"
 #include "npcomp/Typing/Analysis/CPA/Interfaces.h"
diff --git a/include/npcomp/Dialect/Numpy/IR/NumpyOps.td b/include/npcomp/Dialect/Numpy/IR/NumpyOps.td
index 06093e80d..caa795b8f 100644
--- a/include/npcomp/Dialect/Numpy/IR/NumpyOps.td
+++ b/include/npcomp/Dialect/Numpy/IR/NumpyOps.td
@@ -12,6 +12,7 @@
 include "npcomp/Dialect/Numpy/IR/NumpyDialect.td"
 include "npcomp/Typing/Analysis/CPA/Interfaces.td"
 include "mlir/Interfaces/SideEffectInterfaces.td"
+include "mlir/Interfaces/CastInterfaces.td"
 include "mlir/IR/SymbolInterfaces.td"
 
 //----------------------------------------------------------------------------//
@@ -35,6 +36,25 @@ def Numpy_NarrowOp : Numpy_Op<"narrow", []> {
   }];
 }
 
+def Numpy_StaticInfoCastOp : Numpy_Op<"static_info_cast", [
+    DeclareOpInterfaceMethods<CastOpInterface>,
+    NoSideEffect]> {
+  let summary = "Adds/removes static information from an array type.";
+  let description = [{
+    This op does not imply any runtime code. Semantically it is an identity
+    function.
+  }];
+  let arguments = (ins
+    Numpy_AnyArray:$operand
+  );
+  let results = (outs
+    Numpy_AnyArray:$result
+  );
+  let assemblyFormat = [{
+    $operand attr-dict `:` type($operand) `to` type($result)
+  }];
+}
+
 //----------------------------------------------------------------------------//
 // NdArray type handling
 //----------------------------------------------------------------------------//
diff --git a/include/npcomp/Dialect/Torch/Transforms/Passes.h b/include/npcomp/Dialect/Torch/Transforms/Passes.h
index 3d7dbbe01..a38df95d4 100644
--- a/include/npcomp/Dialect/Torch/Transforms/Passes.h
+++ b/include/npcomp/Dialect/Torch/Transforms/Passes.h
@@ -26,6 +26,8 @@ createPrepareForGlobalizeObjectGraphPass();
 /// See the documentation on torch-globalize-object-graph for more details.
 void createGlobalizePipeline(OpPassManager &pm);
 
+std::unique_ptr<OperationPass<ModuleOp>> createAdjustCallingConventionsPass();
+
 } // namespace Torch
 
 /// Registers all Torch transformation passes.
diff --git a/include/npcomp/Dialect/Torch/Transforms/Passes.td b/include/npcomp/Dialect/Torch/Transforms/Passes.td
index e127bbe5c..8c6dcf293 100644
--- a/include/npcomp/Dialect/Torch/Transforms/Passes.td
+++ b/include/npcomp/Dialect/Torch/Transforms/Passes.td
@@ -101,4 +101,28 @@ def PrepareForGlobalizeObjectGraph
   }];
 }
 
+def AdjustCallingConventions
+    : Pass<"torch-adjust-calling-conventions", "ModuleOp"> {
+  let summary = "Adjust the calling conventions of functions";
+  let constructor = "mlir::NPCOMP::Torch::createAdjustCallingConventionsPass()";
+  let description = [{
+    Adjusts the calling conventions of functions in the module, with the aim
+    of preparing them for backends and further lowering passes. As this
+    changes the module's calling convention, it should be considered a
+    legalization step towards reaching IR that is suitable for an appropriate
+    backend. All transformations are context-free, so they can be documented
+    at the user level if needed to clarify the eventual calling convention
+    of compiled artifacts.
+    This is not an optimization.
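+
+    For example (illustrative; see the tests in
+    test/Dialect/Torch/adjust-calling-conventions.mlir), an argument
+    `%arg0: !numpy.ndarray<*:!numpy.any_dtype>
+    {torch.type_bound = !numpy.ndarray<[2,3,?]:f32>}` becomes
+    `%arg0: !numpy.ndarray<[2,3,?]:f32>`, and a `numpy.static_info_cast`
+    back to the original, less precise type is inserted for uses within
+    the function body.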
+
+    The transformations performed are:
+    - `torch.type_bound` annotations are incorporated into the type of the
+      function arguments, which should be `!numpy.ndarray<...>`'s.
+    - Python-isms are rewritten to MLIR-isms:
+      - A NoneType return is rewritten to the absence of a return value.
+      - (Not implemented yet) A tuple return is rewritten to multiple
+        return values.
+  }];
+}
+
 #endif // NPCOMP_TORCH_PASSES
diff --git a/lib/Dialect/Numpy/IR/NumpyOps.cpp b/lib/Dialect/Numpy/IR/NumpyOps.cpp
index 46683f555..69be8f04a 100644
--- a/lib/Dialect/Numpy/IR/NumpyOps.cpp
+++ b/lib/Dialect/Numpy/IR/NumpyOps.cpp
@@ -11,6 +11,7 @@
 #include "mlir/IR/FunctionImplementation.h"
 #include "mlir/IR/OpImplementation.h"
 #include "mlir/IR/PatternMatch.h"
+#include "mlir/IR/TypeUtilities.h"
 #include "npcomp/Dialect/Basicpy/IR/BasicpyDialect.h"
 #include "npcomp/Dialect/Numpy/IR/NumpyDialect.h"
@@ -58,6 +59,24 @@ void BuiltinUfuncCallOp::addCPAConstraints(Typing::CPA::Context &context) {
   }
 }
 
+//----------------------------------------------------------------------------//
+// StaticInfoCast
+//----------------------------------------------------------------------------//
+
+bool StaticInfoCastOp::areCastCompatible(mlir::TypeRange inputs,
+                                         mlir::TypeRange outputs) {
+  auto input = inputs[0].cast<NdArrayType>();
+  auto output = outputs[0].cast<NdArrayType>();
+  if (input.getOptionalShape() && output.getOptionalShape()) {
+    if (failed(verifyCompatibleShape(*input.getOptionalShape(),
+                                     *output.getOptionalShape())))
+      return false;
+  }
+  return input.getDtype() == output.getDtype() ||
+         input.getDtype().isa<AnyDtypeType>() ||
+         output.getDtype().isa<AnyDtypeType>();
+}
+
 //----------------------------------------------------------------------------//
 // CreateArrayFromTensorOp
 //----------------------------------------------------------------------------//
diff --git a/lib/Dialect/Torch/Transforms/AdjustCallingConventions.cpp b/lib/Dialect/Torch/Transforms/AdjustCallingConventions.cpp
new file mode 100644
index 000000000..dadce74db
--- /dev/null
+++ b/lib/Dialect/Torch/Transforms/AdjustCallingConventions.cpp
@@ -0,0 +1,250 @@
+//===- AdjustCallingConventions.cpp ----------------------------*- C++ -*-===//
+//
+// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "PassDetail.h"
+
+#include "mlir/Dialect/StandardOps/IR/Ops.h"
+#include "mlir/IR/BlockAndValueMapping.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/BuiltinOps.h"
+#include "mlir/Transforms/DialectConversion.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "npcomp/Dialect/Basicpy/IR/BasicpyDialect.h"
+#include "npcomp/Dialect/Basicpy/IR/BasicpyOps.h"
+#include "npcomp/Dialect/Numpy/IR/NumpyDialect.h"
+#include "npcomp/Dialect/Numpy/IR/NumpyOps.h"
+#include "npcomp/Dialect/Torch/IR/TorchDialect.h"
+#include "npcomp/Dialect/Torch/IR/TorchOps.h"
+#include "npcomp/Dialect/Torch/Transforms/Passes.h"
+
+using namespace mlir;
+using namespace mlir::NPCOMP;
+using namespace mlir::NPCOMP::Torch;
+
+// Map from func name and arg index to the type bound for that arg.
+// This is needed because, to rewrite calls, we need non-local information
+// from the func definition.
+// We also benefit from populating this all at once, which avoids ordering
+// issues between the rewriting of func ops vs call ops.
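+//
+// For example (hypothetical function name), given
+//   func @f(%arg0: !numpy.ndarray<*:!numpy.any_dtype>
+//           {torch.type_bound = !numpy.ndarray<[2,3,?]:f32>})
+// this map holds the entry ("f", 0) -> !numpy.ndarray<[2,3,?]:f32>.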
+using TypeBoundMap = DenseMap<std::pair<StringRef, int>, Type>;
+
+namespace {
+class AdjustCallingConventionForFunc : public OpConversionPattern<FuncOp> {
+public:
+  using OpConversionPattern::OpConversionPattern;
+  LogicalResult
+  matchAndRewrite(FuncOp func, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const override {
+    MLIRContext *context = func.getContext();
+    auto typeBoundIdent = Identifier::get("torch.type_bound", context);
+    TypeConverter::SignatureConversion conversion(func.getNumArguments());
+
+    // The TypeConverter hooks for type conversion are "context free", so we
+    // cannot use the usual helpers here for populating SignatureConversion and
+    // new result types.
+    //
+    // The incorporation of the torch.type_bound arg attr is context-dependent.
+
+    for (auto type : llvm::enumerate(func.getArgumentTypes())) {
+      if (auto ndarray = type.value().dyn_cast<Numpy::NdArrayType>()) {
+        auto typeBoundAttr =
+            func.getArgAttrOfType<TypeAttr>(type.index(), typeBoundIdent);
+        // If present, the type bound attached to the ndarray-typed arg
+        // becomes the new argument type.
+        // TODO: Check that the bound is more specific than the arg type.
+        conversion.addInputs(type.index(), typeBoundAttr
+                                               ? typeBoundAttr.getValue()
+                                               : type.value());
+        continue;
+      } else if (auto none = type.value().dyn_cast<Basicpy::NoneType>()) {
+        continue;
+      }
+      // TODO: add tuple type.
+      conversion.addInputs(type.index(), type.value());
+    }
+    SmallVector<Type> newResultTypes;
+    for (auto type : func.getType().getResults()) {
+      if (auto none = type.dyn_cast<Basicpy::NoneType>()) {
+        continue;
+      }
+      newResultTypes.push_back(type);
+    }
+    rewriter.applySignatureConversion(&func.getBody(), conversion,
+                                      typeConverter);
+    rewriter.updateRootInPlace(func, [&] {
+      func.setType(FunctionType::get(
+          getContext(), conversion.getConvertedTypes(), newResultTypes));
+      // Clear out the type bounds, now that the type incorporates them.
+      for (int i = 0, e = func.getNumArguments(); i != e; i++)
+        func.removeArgAttr(i, typeBoundIdent);
+    });
+    return success();
+  }
+};
+} // namespace
+
+namespace {
+class AdjustCallingConventionForCall : public OpConversionPattern<CallOp> {
+public:
+  AdjustCallingConventionForCall(TypeConverter &converter,
+                                 MLIRContext *context,
+                                 TypeBoundMap &typeBoundMap)
+      : OpConversionPattern(converter, context), typeBoundMap(typeBoundMap) {}
+  LogicalResult
+  matchAndRewrite(CallOp call, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const override {
+    SmallVector<Type> convertedResults;
+    if (failed(typeConverter->convertTypes(call.getResultTypes(),
+                                           convertedResults)))
+      return failure();
+
+    SmallVector<Value> newOperands;
+    for (auto operand : llvm::enumerate(operands)) {
+      if (operand.value().getType().isa<Basicpy::NoneType>())
+        continue;
+      auto it = typeBoundMap.find({call.callee(), operand.index()});
+      if (it != typeBoundMap.end()) {
+        newOperands.push_back(rewriter.create<Numpy::StaticInfoCastOp>(
+            call.getLoc(), it->second, operand.value()));
+        continue;
+      }
+      newOperands.push_back(operand.value());
+    }
+
+    CallOp newCall = rewriter.create<CallOp>(call.getLoc(), call.callee(),
+                                             convertedResults, newOperands);
+    int newOpResultIdx = 0;
+    SmallVector<Value> newResults;
+    for (auto type : call.getResultTypes()) {
+      if (type.isa<Basicpy::NoneType>()) {
+        newResults.push_back(
+            rewriter.create<Basicpy::SingletonOp>(call.getLoc(), type));
+        continue;
+      }
+      newResults.push_back(newCall.getResult(newOpResultIdx++));
+    }
+    rewriter.replaceOp(call, newResults);
+    return success();
+  }
+
+private:
+  TypeBoundMap &typeBoundMap;
+};
+} // namespace
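+
+// Illustrative effect of the pattern above (cf. the @call test; function
+// name and types are examples only). Given a callee whose first argument
+// has a torch.type_bound of !numpy.ndarray<[2,3,?]:f32>,
+//   %0 = call @f(%arg0) : (!numpy.ndarray<*:!numpy.any_dtype>) -> ...
+// becomes
+//   %cast = numpy.static_info_cast %arg0 : !numpy.ndarray<*:!numpy.any_dtype>
+//           to !numpy.ndarray<[2,3,?]:f32>
+//   %0 = call @f(%cast) : (!numpy.ndarray<[2,3,?]:f32>) -> ...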
+
+namespace {
+class AdjustCallingConventionForReturn : public OpConversionPattern<ReturnOp> {
+public:
+  using OpConversionPattern::OpConversionPattern;
+  LogicalResult
+  matchAndRewrite(ReturnOp op, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const override {
+    SmallVector<Value> newOperands;
+    for (auto operand : llvm::enumerate(operands)) {
+      if (!operand.value())
+        continue;
+      if (operand.value().getType().isa<Basicpy::NoneType>())
+        continue;
+      newOperands.push_back(operand.value());
+    }
+    rewriter.replaceOpWithNewOp<ReturnOp>(op, newOperands);
+    return success();
+  }
+};
+} // namespace
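+
+// Illustrative effect of the pattern above (cf. the @none_return test):
+// `return %0 : !basicpy.NoneType` is rewritten to a bare `return`.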
+
+static LogicalResult adjustCallingConventions(FuncOp func,
+                                              TypeBoundMap &typeBoundMap) {
+  MLIRContext *context = func.getContext();
+  RewritePatternSet patterns(context);
+  // TODO: TupleTypes
+  TypeConverter typeConverter;
+  typeConverter.addConversion([](Type type) { return type; });
+  typeConverter.addConversion(
+      [](Basicpy::NoneType type,
+         SmallVectorImpl<Type> &types) -> Optional<LogicalResult> {
+        return success();
+      });
+
+  typeConverter.addArgumentMaterialization(
+      [](OpBuilder &builder, Numpy::NdArrayType type, ValueRange inputs,
+         Location loc) -> Value {
+        assert(inputs.size() == 1);
+        assert(inputs[0].getType().isa<Numpy::NdArrayType>());
+        return builder.create<Numpy::StaticInfoCastOp>(loc, type, inputs[0]);
+      });
+  patterns.add<AdjustCallingConventionForFunc>(typeConverter, context);
+  patterns.add<AdjustCallingConventionForCall>(typeConverter, context,
+                                               typeBoundMap);
+  patterns.add<AdjustCallingConventionForReturn>(typeConverter, context);
+
+  ConversionTarget target(*context);
+  target.addDynamicallyLegalOp<FuncOp>([](FuncOp func) {
+    for (int i = 0, e = func.getNumArguments(); i != e; i++) {
+      if (func.getArgAttr(i, "torch.type_bound"))
+        return false;
+      if (func.getArgumentTypes()[i].isa<Basicpy::NoneType>())
+        return false;
+    }
+    for (int i = 0, e = func.getNumResults(); i != e; i++) {
+      if (func.getType().getResults()[i].isa<Basicpy::NoneType>())
+        return false;
+    }
+    return true;
+  });
+  // The dynamic legality conditions for call and return are a pain to write...
+  // Just run the patterns once and call it a day.
+  //
+  // Bug for doing this better: https://bugs.llvm.org/show_bug.cgi?id=49812
+  DenseSet<Operation *> opsInOriginalProgram;
+  func.walk(
+      [&](CallOp op) { opsInOriginalProgram.insert(op.getOperation()); });
+  func.walk(
+      [&](ReturnOp op) { opsInOriginalProgram.insert(op.getOperation()); });
+  target.addDynamicallyLegalOp<CallOp>([&](CallOp op) {
+    return !opsInOriginalProgram.contains(op.getOperation());
+  });
+  target.addDynamicallyLegalOp<ReturnOp>([&](ReturnOp op) {
+    return !opsInOriginalProgram.contains(op.getOperation());
+  });
+  target.addLegalOp<Numpy::StaticInfoCastOp>();
+  target.addLegalOp<Basicpy::SingletonOp>();
+  // We don't know how to rewrite indirect calls, so mark them illegal.
+  target.addIllegalOp<CallIndirectOp>();
+  if (failed(applyPartialConversion(func.getOperation(), target,
+                                    std::move(patterns))))
+    return failure();
+  return success();
+}
+
+namespace {
+class AdjustCallingConventionsPass
+    : public AdjustCallingConventionsBase<AdjustCallingConventionsPass> {
+  void runOnOperation() override {
+    auto module = getOperation();
+    TypeBoundMap typeBoundMap;
+    for (auto func : module.getOps<FuncOp>()) {
+      for (int i = 0, e = func.getNumArguments(); i != e; i++) {
+        auto typeBoundAttr =
+            func.getArgAttrOfType<TypeAttr>(i, "torch.type_bound");
+        if (!typeBoundAttr)
+          continue;
+        typeBoundMap[{func.getName(), i}] = typeBoundAttr.getValue();
+      }
+    }
+    for (auto func : module.getOps<FuncOp>()) {
+      if (failed(adjustCallingConventions(func, typeBoundMap)))
+        return signalPassFailure();
+    }
+  }
+};
+} // namespace
+
+std::unique_ptr<OperationPass<ModuleOp>>
+mlir::NPCOMP::Torch::createAdjustCallingConventionsPass() {
+  return std::make_unique<AdjustCallingConventionsPass>();
+}
diff --git a/lib/Dialect/Torch/Transforms/CMakeLists.txt b/lib/Dialect/Torch/Transforms/CMakeLists.txt
index 28508b898..7349d9a7e 100644
--- a/lib/Dialect/Torch/Transforms/CMakeLists.txt
+++ b/lib/Dialect/Torch/Transforms/CMakeLists.txt
@@ -1,4 +1,5 @@
 add_npcomp_conversion_library(NPCOMPTorchPasses
+  AdjustCallingConventions.cpp
   Passes.cpp
   GlobalizeObjectGraph.cpp
   PrepareForGlobalizeObjectGraph.cpp
diff --git a/python/npcomp/compiler/pytorch/backend/refjit.py b/python/npcomp/compiler/pytorch/backend/refjit.py
index 49781ce71..fc2cc03c5 100644
--- a/python/npcomp/compiler/pytorch/backend/refjit.py
+++ b/python/npcomp/compiler/pytorch/backend/refjit.py
@@ -27,6 +27,7 @@ OBJECT_GRAPH_LOWERING_PASSES = (
     # bothersome because we don't currently have a lowering for them.
     # TODO: Support global slots in backends.
     "symbol-dce",
+    "torch-adjust-calling-conventions",
 )
 
 TORCH_TO_TCF_PASSES = (
diff --git a/test/Dialect/Numpy/ops.mlir b/test/Dialect/Numpy/ops.mlir
index 89f4b0e44..5db4c3d2c 100644
--- a/test/Dialect/Numpy/ops.mlir
+++ b/test/Dialect/Numpy/ops.mlir
@@ -17,3 +17,14 @@ func @ndarray_tensor_bridging(%arg0: !numpy.ndarray<[2,3]:f32>, %arg1: !numpy.nd
   numpy.overwrite_array %arg2 overwrites %arg0 : tensor<2x3xf32>, !numpy.ndarray<[2,3]:f32>
   return
 }
+
+// CHECK-LABEL: @static_info_cast
+func @static_info_cast(%arg0: !numpy.ndarray<[2,3]:f32>, %arg1: !numpy.ndarray<[?,3]:f32>, %arg2: !numpy.ndarray<*:f32>) {
+  // CHECK-NEXT: numpy.static_info_cast %arg0 : !numpy.ndarray<[2,3]:f32> to !numpy.ndarray<*:!numpy.any_dtype>
+  %0 = numpy.static_info_cast %arg0 : !numpy.ndarray<[2,3]:f32> to !numpy.ndarray<*:!numpy.any_dtype>
+  // CHECK-NEXT: numpy.static_info_cast %arg1 : !numpy.ndarray<[?,3]:f32> to !numpy.ndarray<[7,3]:f32>
+  %1 = numpy.static_info_cast %arg1 : !numpy.ndarray<[?,3]:f32> to !numpy.ndarray<[7,3]:f32>
+  // CHECK-NEXT: numpy.static_info_cast %arg2 : !numpy.ndarray<*:f32> to !numpy.ndarray<[?,?]:f32>
+  %2 = numpy.static_info_cast %arg2 : !numpy.ndarray<*:f32> to !numpy.ndarray<[?,?]:f32>
+  return
+}
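+
+// Note (illustrative): casts may add or remove static info in either
+// direction, as above, but contradictory static shapes are rejected by
+// areCastCompatible; e.g. !numpy.ndarray<[2,3]:f32> to
+// !numpy.ndarray<[7,3]:f32> would be an invalid cast.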
diff --git a/test/Dialect/Torch/adjust-calling-conventions.mlir b/test/Dialect/Torch/adjust-calling-conventions.mlir
new file mode 100644
index 000000000..74d2be6fc
--- /dev/null
+++ b/test/Dialect/Torch/adjust-calling-conventions.mlir
@@ -0,0 +1,46 @@
+// RUN: npcomp-opt -torch-adjust-calling-conventions -allow-unregistered-dialect -split-input-file %s | FileCheck %s
+
+// CHECK-LABEL: func @basic(
+// CHECK-SAME:    %[[ARG:.*]]: !numpy.ndarray<[2,3,?]:f32>) -> !numpy.ndarray<*:!numpy.any_dtype> {
+// CHECK:         %[[RET:.*]] = numpy.static_info_cast %[[ARG]] : !numpy.ndarray<[2,3,?]:f32> to !numpy.ndarray<*:!numpy.any_dtype>
+// CHECK:         return %[[RET]] : !numpy.ndarray<*:!numpy.any_dtype>
+func @basic(%arg0: !numpy.ndarray<*:!numpy.any_dtype> {torch.type_bound = !numpy.ndarray<[2,3,?]:f32>}) -> !numpy.ndarray<*:!numpy.any_dtype> {
+  return %arg0 : !numpy.ndarray<*:!numpy.any_dtype>
+}
+
+// CHECK-LABEL: func @no_type_bound(
+// CHECK-SAME:    %[[ARG:.*]]: !numpy.ndarray<*:!numpy.any_dtype>) -> !numpy.ndarray<*:!numpy.any_dtype> {
+// CHECK:         return %[[ARG]] : !numpy.ndarray<*:!numpy.any_dtype>
+func @no_type_bound(%arg0: !numpy.ndarray<*:!numpy.any_dtype>) -> !numpy.ndarray<*:!numpy.any_dtype> {
+  return %arg0 : !numpy.ndarray<*:!numpy.any_dtype>
+}
+
+// CHECK-LABEL: func @call(
+// CHECK-SAME:    %[[ARG:.*]]: !numpy.ndarray<[2,3,?]:f32>) -> !numpy.ndarray<*:!numpy.any_dtype> {
+// CHECK:         %[[SHAPE_ERASED:.*]] = numpy.static_info_cast %[[ARG]] : !numpy.ndarray<[2,3,?]:f32> to !numpy.ndarray<*:!numpy.any_dtype>
+// CHECK:         %[[SHAPED:.*]] = numpy.static_info_cast %[[SHAPE_ERASED]] : !numpy.ndarray<*:!numpy.any_dtype> to !numpy.ndarray<[2,3,?]:f32>
+// CHECK:         %[[RES:.*]] = call @call(%[[SHAPED]]) : (!numpy.ndarray<[2,3,?]:f32>) -> !numpy.ndarray<*:!numpy.any_dtype>
+// CHECK:         return %[[SHAPE_ERASED]] : !numpy.ndarray<*:!numpy.any_dtype>
+func @call(%arg0: !numpy.ndarray<*:!numpy.any_dtype> {torch.type_bound = !numpy.ndarray<[2,3,?]:f32>}) -> !numpy.ndarray<*:!numpy.any_dtype> {
+  %0 = call @call(%arg0) : (!numpy.ndarray<*:!numpy.any_dtype>) -> !numpy.ndarray<*:!numpy.any_dtype>
+  return %arg0 : !numpy.ndarray<*:!numpy.any_dtype>
+}
+
+// CHECK-LABEL: func @none_return() {
+// CHECK:         %[[NONE:.*]] = basicpy.singleton : !basicpy.NoneType
+// CHECK:         return
+func @none_return() -> !basicpy.NoneType {
+  %1 = basicpy.singleton : !basicpy.NoneType
+  return %1 : !basicpy.NoneType
+}
+
+// CHECK-LABEL: func @none_call_return() {
+// CHECK:         call @none_return() : () -> ()
+// CHECK:         %[[NONE:.*]] = basicpy.singleton : !basicpy.NoneType
+// CHECK:         "test.use"(%[[NONE]]) : (!basicpy.NoneType) -> ()
+// CHECK:         return
+func @none_call_return() {
+  %0 = call @none_return() : () -> !basicpy.NoneType
+  "test.use"(%0) : (!basicpy.NoneType) -> ()
+  return
+}