mirror of https://github.com/llvm/torch-mlir
Add torch-adjust-calling-conventions pass.
This pass incorporates torch.type_bound info and also removes NoneType
returns (eventually it will rewrite tuple types too, but can't yet because
!basicpy.TupleType doesn't track element types).

Recommend looking at adjust-calling-conventions.mlir first to see what it is
doing, and holding your nose for the implementation of the pass. I decided to
implement this with the conversion framework, because it gives us *some*
goodies for type conversion -- mainly avoiding large amounts of tricky RAUW
dances. Unfortunately, the conversion framework isn't a perfect fit for a
couple of reasons:

- the incorporation of torch.type_bound is a context-sensitive rewrite
  (it requires looking at the arg attr, not just the type).
- NoneType conversion is 1->0, which requires some special handling.
- (not implemented yet) 1->N tuple type conversions require special handling.

It's a little bit scary, but on balance doing it the other way would have its
own downsides.
parent 464feacba9
commit 30356c41c8
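To make the net effect concrete, here is the @basic case from the new
adjust-calling-conventions.mlir test added in this commit (SSA names are
illustrative). Before the pass, the type bound lives in an arg attribute:

    func @basic(%arg0: !numpy.ndarray<*:!numpy.any_dtype> {torch.type_bound = !numpy.ndarray<[2,3,?]:f32>}) -> !numpy.ndarray<*:!numpy.any_dtype> {
      return %arg0 : !numpy.ndarray<*:!numpy.any_dtype>
    }

After the pass, the bound becomes the argument type, and a
numpy.static_info_cast reconciles in-body uses of the old, less precise type:

    func @basic(%arg0: !numpy.ndarray<[2,3,?]:f32>) -> !numpy.ndarray<*:!numpy.any_dtype> {
      %0 = numpy.static_info_cast %arg0 : !numpy.ndarray<[2,3,?]:f32> to !numpy.ndarray<*:!numpy.any_dtype>
      return %0 : !numpy.ndarray<*:!numpy.any_dtype>
    }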
@@ -15,6 +15,7 @@
 #include "mlir/IR/FunctionSupport.h"
 #include "mlir/IR/OpDefinition.h"
 #include "mlir/IR/SymbolTable.h"
+#include "mlir/Interfaces/CastInterfaces.h"
 #include "mlir/Interfaces/SideEffectInterfaces.h"
 #include "npcomp/Typing/Analysis/CPA/Interfaces.h"
 
@@ -12,6 +12,7 @@
 include "npcomp/Dialect/Numpy/IR/NumpyDialect.td"
 include "npcomp/Typing/Analysis/CPA/Interfaces.td"
 include "mlir/Interfaces/SideEffectInterfaces.td"
+include "mlir/Interfaces/CastInterfaces.td"
 include "mlir/IR/SymbolInterfaces.td"
 
 //----------------------------------------------------------------------------//
@@ -35,6 +36,25 @@ def Numpy_NarrowOp : Numpy_Op<"narrow", []> {
   }];
 }
 
+def Numpy_StaticInfoCastOp : Numpy_Op<"static_info_cast", [
+    DeclareOpInterfaceMethods<CastOpInterface>,
+    NoSideEffect]> {
+  let summary = "Adds/removes static information from an array type.";
+  let description = [{
+    This op does not imply any runtime code. Semantically it is an identity
+    function.
+  }];
+  let arguments = (ins
+    Numpy_AnyArray:$operand
+  );
+  let results = (outs
+    Numpy_AnyArray:$result
+  );
+  let assemblyFormat = [{
+    $operand attr-dict `:` type($operand) `to` type($result)
+  }];
+}
 
 //----------------------------------------------------------------------------//
 // NdArray type handling
 //----------------------------------------------------------------------------//
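For reference, the assembly format declared above prints like this (example
adapted from the ops.mlir test updated later in this commit):

    %0 = numpy.static_info_cast %arg0 : !numpy.ndarray<[2,3]:f32> to !numpy.ndarray<*:!numpy.any_dtype>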
@@ -26,6 +26,8 @@ createPrepareForGlobalizeObjectGraphPass();
 /// See the documentation on torch-globalize-object-graph for more details.
 void createGlobalizePipeline(OpPassManager &pm);
 
+std::unique_ptr<OperationPass<ModuleOp>> createAdjustCallingConventionsPass();
+
 } // namespace Torch
 
 /// Registers all Torch transformation passes.
@@ -101,4 +101,28 @@ def PrepareForGlobalizeObjectGraph
   }];
 }
 
+def AdjustCallingConventions
+  : Pass<"torch-adjust-calling-conventions", "ModuleOp"> {
+  let summary = "Adjust the calling conventions of functions";
+  let constructor = "mlir::NPCOMP::Torch::createAdjustCallingConventionsPass()";
+  let description = [{
+    Adjusts the calling conventions of functions in the module, with the aim of
+    preparing them for backends and further lowering passes. As this changes
+    the module calling convention, it should be considered a legalization
+    step towards reaching IR that is suitable for an appropriate backend.
+    All transformations are context-free and suitable for documenting
+    at the user level if needed to clarify the eventual calling convention
+    of compiled artifacts.
+    This is not an optimization.
+
+    The transformations performed are:
+    - `torch.type_bound` annotations are incorporated into the type of the
+      function arguments, which should be `!numpy.ndarray<...>`'s.
+    - Python-isms are rewritten to MLIR-isms:
+      - NoneType return is rewritten to the absence of a return value.
+      - (Not implemented yet) Tuple return is rewritten to multiple return
+        values.
+  }];
+}
 
 #endif // NPCOMP_TORCH_PASSES
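A sketch of the NoneType rewrite described above, mirroring the @none_return
test case added in this commit: a function

    func @none_return() -> !basicpy.NoneType {
      %0 = basicpy.singleton : !basicpy.NoneType
      return %0 : !basicpy.NoneType
    }

is rewritten to return nothing, keeping the basicpy.singleton around for any
remaining in-body uses:

    func @none_return() {
      %0 = basicpy.singleton : !basicpy.NoneType
      return
    }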
@@ -11,6 +11,7 @@
 #include "mlir/IR/FunctionImplementation.h"
 #include "mlir/IR/OpImplementation.h"
 #include "mlir/IR/PatternMatch.h"
+#include "mlir/IR/TypeUtilities.h"
 #include "npcomp/Dialect/Basicpy/IR/BasicpyDialect.h"
 #include "npcomp/Dialect/Numpy/IR/NumpyDialect.h"
 
@@ -58,6 +59,24 @@ void BuiltinUfuncCallOp::addCPAConstraints(Typing::CPA::Context &context) {
   }
 }
 
+//----------------------------------------------------------------------------//
+// StaticInfoCast
+//----------------------------------------------------------------------------//
+
+bool StaticInfoCastOp::areCastCompatible(mlir::TypeRange inputs,
+                                         mlir::TypeRange outputs) {
+  auto input = inputs[0].cast<NdArrayType>();
+  auto output = outputs[0].cast<NdArrayType>();
+  if (input.getOptionalShape() && output.getOptionalShape()) {
+    if (failed(verifyCompatibleShape(*input.getOptionalShape(),
+                                     *output.getOptionalShape())))
+      return false;
+  }
+  return input.getDtype() == output.getDtype() ||
+         input.getDtype().isa<AnyDtypeType>() ||
+         output.getDtype().isa<AnyDtypeType>();
+}
 
 //----------------------------------------------------------------------------//
 // CreateArrayFromTensorOp
 //----------------------------------------------------------------------------//
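In words: two ndarray types are cast-compatible if, whenever both shapes are
known, the shapes are compatible, and the dtypes either match exactly or at
least one side is the wildcard !numpy.any_dtype. For example (the first two
are from the tests in this commit; the last is a hypothetical rejected case):

    // OK: erases both shape and dtype information.
    %0 = numpy.static_info_cast %a : !numpy.ndarray<[2,3]:f32> to !numpy.ndarray<*:!numpy.any_dtype>
    // OK: [?,3] and [7,3] are compatible shapes; dtypes match.
    %1 = numpy.static_info_cast %b : !numpy.ndarray<[?,3]:f32> to !numpy.ndarray<[7,3]:f32>
    // Rejected by areCastCompatible: f32 vs. i32, and neither side is !numpy.any_dtype.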
@@ -0,0 +1,250 @@
+//===- AdjustCallingConventions.cpp --------------------------*- C++ -*-===//
+//
+// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "PassDetail.h"
+
+#include "mlir/Dialect/StandardOps/IR/Ops.h"
+#include "mlir/IR/BlockAndValueMapping.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/BuiltinOps.h"
+#include "mlir/Transforms/DialectConversion.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "npcomp/Dialect/Basicpy/IR/BasicpyDialect.h"
+#include "npcomp/Dialect/Basicpy/IR/BasicpyOps.h"
+#include "npcomp/Dialect/Numpy/IR/NumpyDialect.h"
+#include "npcomp/Dialect/Numpy/IR/NumpyOps.h"
+#include "npcomp/Dialect/Torch/IR/TorchDialect.h"
+#include "npcomp/Dialect/Torch/IR/TorchOps.h"
+#include "npcomp/Dialect/Torch/Transforms/Passes.h"
+
+using namespace mlir;
+using namespace mlir::NPCOMP;
+using namespace mlir::NPCOMP::Torch;
+
+// Map from func name and arg index to the type bound for that arg.
+// This is needed because to rewrite calls, we need the non-local information
+// from the func definition.
+// We also benefit from populating this all at once, which avoids ordering
+// issues between rewriting of func ops vs call ops.
+using TypeBoundMap = DenseMap<std::pair<StringRef, int>, Type>;
+
+namespace {
+class AdjustCallingConventionForFunc : public OpConversionPattern<FuncOp> {
+public:
+  using OpConversionPattern::OpConversionPattern;
+  LogicalResult
+  matchAndRewrite(FuncOp func, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const override {
+    MLIRContext *context = func.getContext();
+    auto typeBoundIdent = Identifier::get("torch.type_bound", context);
+    TypeConverter::SignatureConversion conversion(func.getNumArguments());
+
+    // The TypeConverter hooks for type conversion are "context free", so we
+    // cannot use the usual helpers here for populating SignatureConversion and
+    // new result types.
+    //
+    // The incorporation of the torch.type_bound arg attr is context-dependent.
+
+    for (auto type : llvm::enumerate(func.getArgumentTypes())) {
+      if (auto ndarray = type.value().dyn_cast<Numpy::NdArrayType>()) {
+        auto typeBoundAttr =
+            func.getArgAttrOfType<TypeAttr>(type.index(), typeBoundIdent);
+        conversion.addInputs(type.index(), typeBoundAttr
+                                               ? typeBoundAttr.getValue()
+                                               : type.value());
+        continue;
+        // type is attached to ndarray type.
+        // TODO: check if more specific?
+      } else if (auto none = type.value().dyn_cast<Basicpy::NoneType>()) {
+        continue;
+      }
+      // TODO: add tuple type.
+      conversion.addInputs(type.index(), type.value());
+    }
+    SmallVector<Type> newResultTypes;
+    for (auto type : func.getType().getResults()) {
+      if (auto none = type.dyn_cast<Basicpy::NoneType>()) {
+        continue;
+      }
+      newResultTypes.push_back(type);
+    }
+    rewriter.applySignatureConversion(&func.getBody(), conversion,
+                                      typeConverter);
+    rewriter.updateRootInPlace(func, [&] {
+      func.setType(FunctionType::get(
+          getContext(), conversion.getConvertedTypes(), newResultTypes));
+      // Clear out the type bounds, now that the type incorporates them.
+      for (int i = 0, e = func.getNumArguments(); i != e; i++)
+        func.removeArgAttr(i, typeBoundIdent);
+    });
+    return success();
+  }
+};
+} // namespace
+
+namespace {
+class AdjustCallingConventionForCall : public OpConversionPattern<CallOp> {
+public:
+  AdjustCallingConventionForCall(TypeConverter &converter, MLIRContext *context,
+                                 TypeBoundMap &typeBoundMap)
+      : OpConversionPattern<CallOp>(converter, context),
+        typeBoundMap(typeBoundMap) {}
+  LogicalResult
+  matchAndRewrite(CallOp call, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const override {
+    SmallVector<Type> convertedResults;
+    if (failed(typeConverter->convertTypes(call.getResultTypes(),
+                                           convertedResults)))
+      return failure();
+
+    SmallVector<Value> newOperands;
+    for (auto operand : llvm::enumerate(operands)) {
+      if (operand.value().getType().isa<Basicpy::NoneType>())
+        continue;
+      auto it = typeBoundMap.find({call.callee(), operand.index()});
+      if (it != typeBoundMap.end()) {
+        newOperands.push_back(rewriter.create<Numpy::StaticInfoCastOp>(
+            call.getLoc(), it->second, operand.value()));
+        continue;
+      }
+      newOperands.push_back(operand.value());
+    }
+
+    CallOp newCall = rewriter.create<CallOp>(call.getLoc(), call.callee(),
+                                             convertedResults, newOperands);
+    int newOpResultIdx = 0;
+    SmallVector<Value> newResults;
+    for (auto type : call.getResultTypes()) {
+      if (type.isa<Basicpy::NoneType>()) {
+        newResults.push_back(
+            rewriter.create<Basicpy::SingletonOp>(call.getLoc(), type));
+        continue;
+      }
+      newResults.push_back(newCall.getResult(newOpResultIdx++));
+    }
+    rewriter.replaceOp(call, newResults);
+    return success();
+  }
+
+private:
+  TypeBoundMap &typeBoundMap;
+};
+} // namespace
+
+namespace {
+class AdjustCallingConventionForReturn : public OpConversionPattern<ReturnOp> {
+public:
+  using OpConversionPattern::OpConversionPattern;
+  LogicalResult
+  matchAndRewrite(ReturnOp op, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const override {
+    SmallVector<Value> newOperands;
+    for (auto operand : llvm::enumerate(operands)) {
+      if (!operand.value())
+        continue;
+      if (operand.value().getType().isa<Basicpy::NoneType>())
+        continue;
+      newOperands.push_back(operand.value());
+    }
+    rewriter.replaceOpWithNewOp<ReturnOp>(op, newOperands);
+    return success();
+  }
+};
+} // namespace
+
+static LogicalResult adjustCallingConventions(FuncOp func,
+                                              TypeBoundMap &typeBoundMap) {
+  MLIRContext *context = func.getContext();
+  RewritePatternSet patterns(context);
+  // TODO: TupleTypes
+  TypeConverter typeConverter;
+  typeConverter.addConversion([](Type type) { return type; });
+  typeConverter.addConversion(
+      [](Basicpy::NoneType type,
+         SmallVectorImpl<Type> &types) -> Optional<LogicalResult> {
+        return success();
+      });
+
+  typeConverter.addArgumentMaterialization(
+      [](OpBuilder &builder, Numpy::NdArrayType type, ValueRange inputs,
+         Location loc) -> Value {
+        assert(inputs.size() == 1);
+        assert(inputs[0].getType().isa<Numpy::NdArrayType>());
+        return builder.create<Numpy::StaticInfoCastOp>(loc, type, inputs[0]);
+      });
+  patterns.add<AdjustCallingConventionForFunc>(typeConverter, context);
+  patterns.add<AdjustCallingConventionForCall>(typeConverter, context,
+                                               typeBoundMap);
+  patterns.add<AdjustCallingConventionForReturn>(typeConverter, context);
+
+  ConversionTarget target(*context);
+  target.addDynamicallyLegalOp<FuncOp>([](FuncOp func) {
+    for (int i = 0, e = func.getNumArguments(); i != e; i++) {
+      if (func.getArgAttr(i, "torch.type_bound"))
+        return false;
+      if (func.getArgumentTypes()[i].isa<Basicpy::NoneType>())
+        return false;
+    }
+    for (int i = 0, e = func.getNumResults(); i != e; i++) {
+      if (func.getType().getResults()[i].isa<Basicpy::NoneType>())
+        return false;
+    }
+    return true;
+  });
+  // The dynamic legality conditions for call and return are a pain to write...
+  // Just run the patterns once and call it a day.
+  //
+  // Bug for doing this better https://bugs.llvm.org/show_bug.cgi?id=49812
+  DenseSet<Operation *> opsInOriginalProgram;
+  func.walk([&](CallOp op) { opsInOriginalProgram.insert(op.getOperation()); });
+  func.walk(
+      [&](ReturnOp op) { opsInOriginalProgram.insert(op.getOperation()); });
+  target.addDynamicallyLegalOp<CallOp>([&](CallOp op) {
+    return !opsInOriginalProgram.contains(op.getOperation());
+  });
+  target.addDynamicallyLegalOp<ReturnOp>([&](ReturnOp op) {
+    return !opsInOriginalProgram.contains(op.getOperation());
+  });
+  target.addLegalOp<Numpy::StaticInfoCastOp>();
+  target.addLegalOp<Basicpy::SingletonOp>();
+  // We don't know how to rewrite it, so mark it as illegal.
+  target.addIllegalOp<CallIndirectOp>();
+  if (failed(applyPartialConversion(func.getOperation(), target,
+                                    std::move(patterns))))
+    return failure();
+  return success();
+}
+
+namespace {
+class AdjustCallingConventionsPass
+    : public AdjustCallingConventionsBase<AdjustCallingConventionsPass> {
+  void runOnOperation() override {
+    auto module = getOperation();
+    TypeBoundMap typeBoundMap;
+    for (auto func : module.getOps<FuncOp>()) {
+      for (int i = 0, e = func.getNumArguments(); i != e; i++) {
+        auto typeBoundAttr =
+            func.getArgAttrOfType<TypeAttr>(i, "torch.type_bound");
+        if (!typeBoundAttr)
+          continue;
+        typeBoundMap[{func.getName(), i}] = typeBoundAttr.getValue();
+      }
+    }
+    for (auto func : module.getOps<FuncOp>()) {
+      if (failed(adjustCallingConventions(func, typeBoundMap)))
+        return signalPassFailure();
+    }
+  }
+};
+} // namespace
+
+std::unique_ptr<OperationPass<ModuleOp>>
+mlir::NPCOMP::Torch::createAdjustCallingConventionsPass() {
+  return std::make_unique<AdjustCallingConventionsPass>();
+}
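The interplay between AdjustCallingConventionForCall and the argument
materialization above is easiest to see on the recursive @call test case
below: after the pass, the shape-erased value in the body is cast back to the
callee's new, more precise argument type at the call site:

    %0 = numpy.static_info_cast %arg0 : !numpy.ndarray<[2,3,?]:f32> to !numpy.ndarray<*:!numpy.any_dtype>
    %1 = numpy.static_info_cast %0 : !numpy.ndarray<*:!numpy.any_dtype> to !numpy.ndarray<[2,3,?]:f32>
    %2 = call @call(%1) : (!numpy.ndarray<[2,3,?]:f32>) -> !numpy.ndarray<*:!numpy.any_dtype>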
@@ -1,4 +1,5 @@
 add_npcomp_conversion_library(NPCOMPTorchPasses
+  AdjustCallingConventions.cpp
   Passes.cpp
   GlobalizeObjectGraph.cpp
   PrepareForGlobalizeObjectGraph.cpp
@@ -27,6 +27,7 @@ OBJECT_GRAPH_LOWERING_PASSES = (
     # bothersome because we don't currently have a lowering for them.
     # TODO: Support global slots in backends.
     "symbol-dce",
+    "torch-adjust-calling-conventions",
 )
 
 TORCH_TO_TCF_PASSES = (
@@ -17,3 +17,14 @@ func @ndarray_tensor_bridging(%arg0: !numpy.ndarray<[2,3]:f32>, %arg1: !numpy.nd
   numpy.overwrite_array %arg2 overwrites %arg0 : tensor<2x3xf32>, !numpy.ndarray<[2,3]:f32>
   return
 }
+
+// CHECK-LABEL: @static_info_cast
+func @static_info_cast(%arg0: !numpy.ndarray<[2,3]:f32>, %arg1: !numpy.ndarray<[?,3]:f32>, %arg2: !numpy.ndarray<*:f32>) {
+  // CHECK-NEXT: numpy.static_info_cast %arg0 : !numpy.ndarray<[2,3]:f32> to !numpy.ndarray<*:!numpy.any_dtype>
+  %0 = numpy.static_info_cast %arg0 : !numpy.ndarray<[2,3]:f32> to !numpy.ndarray<*:!numpy.any_dtype>
+  // CHECK-NEXT: numpy.static_info_cast %arg1 : !numpy.ndarray<[?,3]:f32> to !numpy.ndarray<[7,3]:f32>
+  %1 = numpy.static_info_cast %arg1 : !numpy.ndarray<[?,3]:f32> to !numpy.ndarray<[7,3]:f32>
+  // CHECK-NEXT: numpy.static_info_cast %arg2 : !numpy.ndarray<*:f32> to !numpy.ndarray<[?,?]:f32>
+  %2 = numpy.static_info_cast %arg2 : !numpy.ndarray<*:f32> to !numpy.ndarray<[?,?]:f32>
+  return
+}
@@ -0,0 +1,46 @@
+// RUN: npcomp-opt -torch-adjust-calling-conventions -allow-unregistered-dialect -split-input-file %s | FileCheck %s
+
+// CHECK-LABEL: func @basic(
+// CHECK-SAME: %[[ARG:.*]]: !numpy.ndarray<[2,3,?]:f32>) -> !numpy.ndarray<*:!numpy.any_dtype> {
+// CHECK: %[[RET:.*]] = numpy.static_info_cast %[[ARG]] : !numpy.ndarray<[2,3,?]:f32> to !numpy.ndarray<*:!numpy.any_dtype>
+// CHECK: return %[[RET]] : !numpy.ndarray<*:!numpy.any_dtype>
+func @basic(%arg0: !numpy.ndarray<*:!numpy.any_dtype> {torch.type_bound = !numpy.ndarray<[2,3,?]:f32>}) -> !numpy.ndarray<*:!numpy.any_dtype> {
+  return %arg0 : !numpy.ndarray<*:!numpy.any_dtype>
+}
+
+// CHECK-LABEL: func @no_type_bound(
+// CHECK-SAME: %[[ARG:.*]]: !numpy.ndarray<*:!numpy.any_dtype>) -> !numpy.ndarray<*:!numpy.any_dtype> {
+// CHECK: return %[[ARG]] : !numpy.ndarray<*:!numpy.any_dtype>
+func @no_type_bound(%arg0: !numpy.ndarray<*:!numpy.any_dtype>) -> !numpy.ndarray<*:!numpy.any_dtype> {
+  return %arg0 : !numpy.ndarray<*:!numpy.any_dtype>
+}
+
+// CHECK-LABEL: func @call(
+// CHECK-SAME: %[[ARG:.*]]: !numpy.ndarray<[2,3,?]:f32>) -> !numpy.ndarray<*:!numpy.any_dtype> {
+// CHECK: %[[SHAPE_ERASED:.*]] = numpy.static_info_cast %[[ARG]] : !numpy.ndarray<[2,3,?]:f32> to !numpy.ndarray<*:!numpy.any_dtype>
+// CHECK: %[[SHAPED:.*]] = numpy.static_info_cast %[[SHAPE_ERASED]] : !numpy.ndarray<*:!numpy.any_dtype> to !numpy.ndarray<[2,3,?]:f32>
+// CHECK: %[[RES:.*]] = call @call(%[[SHAPED]]) : (!numpy.ndarray<[2,3,?]:f32>) -> !numpy.ndarray<*:!numpy.any_dtype>
+// CHECK: return %[[SHAPE_ERASED]] : !numpy.ndarray<*:!numpy.any_dtype>
+func @call(%arg0: !numpy.ndarray<*:!numpy.any_dtype> {torch.type_bound = !numpy.ndarray<[2,3,?]:f32>}) -> !numpy.ndarray<*:!numpy.any_dtype> {
+  %0 = call @call(%arg0) : (!numpy.ndarray<*:!numpy.any_dtype>) -> !numpy.ndarray<*:!numpy.any_dtype>
+  return %arg0 : !numpy.ndarray<*:!numpy.any_dtype>
+}
+
+// CHECK-LABEL: func @none_return() {
+// CHECK: %[[NONE:.*]] = basicpy.singleton : !basicpy.NoneType
+// CHECK: return
+func @none_return() -> !basicpy.NoneType {
+  %1 = basicpy.singleton : !basicpy.NoneType
+  return %1 : !basicpy.NoneType
+}
+
+// CHECK-LABEL: func @none_call_return() {
+// CHECK: call @none_return() : () -> ()
+// CHECK: %[[NONE:.*]] = basicpy.singleton : !basicpy.NoneType
+// CHECK: "test.use"(%[[NONE]]) : (!basicpy.NoneType) -> ()
+// CHECK: return
+func @none_call_return() {
+  %0 = call @none_return() : () -> !basicpy.NoneType
+  "test.use"(%0) : (!basicpy.NoneType) -> ()
+  return
+}