mirror of https://github.com/llvm/torch-mlir
Add flatten op recognition + shape refinement.
This op has complex aliasing semantics, so it is kept mutable for now. With this, ResNet18 is reduced to a single basic block with all aten operators having rank + dtype: https://gist.github.com/silvasean/2fcb1c6e4d4ae27461204a43ae9c5031
parent 122cae2ee3
commit 3d08c83580
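
For orientation, here is a minimal before/after sketch of what this change does, assembled from the test cases added below (types abbreviated; kernel-call attributes and pass flags omitted). The importer's generic torch.kernel_call is recognized into the new aten.flatten op, and the type refinement pass then narrows the !numpy.ndarray result type where start_dim/end_dim are constants:

// Before recognition: generic kernel call, unknown rank/dtype.
%0 = torch.kernel_call "aten::flatten" %arg0, %c1_i64, %c-1_i64
    : (!numpy.ndarray<*:!numpy.any_dtype>, i64, i64) -> !numpy.ndarray<*:!numpy.any_dtype>
// After recognition + refinement, e.g. for a [3,2,?,5] f32 input with
// start_dim = 1, end_dim = -2 (dims 1..2 collapse into one unknown dim):
%1 = "aten.flatten"(%arg1, %start, %end)
    : (!numpy.ndarray<[3,2,?,5]:f32>, i64, i64) -> !numpy.ndarray<[3,?,5]:f32>
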
@@ -164,6 +164,10 @@ def generate_ops(g: "OpGenerator"):
       "aten::nll_loss2d_backward(Tensor,Tensor,Tensor,Tensor?,int,int,Tensor)",
       "NllLoss2dBackwardOp", "nll_loss2d_backward")
 
+  g.print_banner("Mutable/view-like ops")
+  g.ordinary_mutable_op("aten::flatten(Tensor,int,int)",
+                        "FlattenOp",
+                        "flatten")
   # One-off in-place ops (note that many in-place arithmetic ops are handled
   # as a transformation from their immutable forms).
   g.ordinary_inplace_op("aten::copy_(Tensor,Tensor,bool)",

@@ -298,6 +302,36 @@ class OpGenerator:
     )
     opdef.emit()
 
+  def ordinary_mutable_op(self,
+                          kernel_sig: str,
+                          ods_name: str,
+                          op_name: str,
+                          traits: Sequence[str] = (),
+                          **kwargs):
+    """An ordinary mutable-tensor based op."""
+    opdef = self.define_op(
+        kernel_sig=kernel_sig,
+        ods_name=ods_name,
+        op_name=op_name,
+        traits=list(traits),
+        **kwargs)
+    opdef.transforms(
+        type_transforms={
+            "Tensor": "AnyTorchMutableTensor",
+            "Tensor?": "AnyTorchOptionalMutableTensor",
+            "int": "AnyTorchIntType",
+            "int[]": "AnyTorchIntListType",
+            "bool": "AnyTorchBoolType",
+            "bool[]": "AnyTorchBoolListType",
+            "float": "AnyFloat",
+        },
+        flag_transforms={
+            "Tensor": ["kMutableTensor"],
+            "Tensor?": ["kMutableTensor"],
+        },
+    )
+    opdef.emit()
+
   def ordinary_primitive_op(self,
                             kernel_sig: str,
                             ods_name: str,

@@ -39,23 +39,6 @@ def aten_ConstantOp: aten_Op<"constant", [NoSideEffect]>,
 }
 
-def aten_FlattenOp: aten_Op<"flatten", [NoSideEffect, StatisticsOpInterface]>,
-    Results<(outs AnyTensor)> {
-  let arguments = (
-    ins AnyType:$arg0,
-    AnyType:$arg1,
-    AnyType:$arg2
-  );
-
-  let summary = "Flatten operator";
-  let description = [{
-    Flatten operator
-  }];
-  let extraClassDeclaration = [{
-    std::map<std::string, uint64_t> getStatistics();
-  }];
-}
-
 def aten_TypeCastOp : aten_Op<"type_cast", [NoSideEffect]>,
     Results<(outs AnyType)> {
   let summary = "TypeCast operator";

@@ -1291,6 +1291,28 @@ const Torch::BuildKernelMetadata &NllLoss2dBackwardOp::getTorchBuildKernelMetadata() {
   return metadata;
 }
 
+// -----------------------------------------------------------------------------
+// Mutable/view-like ops
+// -----------------------------------------------------------------------------
+
+Torch::KernelMetadata FlattenOp::getTorchKernelMetadata() {
+  return getTorchBuildKernelMetadata();
+}
+
+const Torch::BuildKernelMetadata &FlattenOp::getTorchBuildKernelMetadata() {
+  using KVC = Torch::KernelValueConversion::BitMask;
+  static Torch::BuildKernelMetadata metadata = ([]() {
+    Torch::BuildKernelMetadata m;
+    m.kernelName = "aten::flatten";
+    m.addArgTypes({"Tensor", "int", "int"});
+    m.addArgConversions({KVC::kMutableTensor, KVC::kNone, KVC::kNone});
+    m.addReturnTypes({"Tensor"});
+    m.addReturnConversions({KVC::kMutableTensor});
+    return m;
+  })();
+  return metadata;
+}
+
 Torch::KernelMetadata CopyInplaceOp::getTorchKernelMetadata() {
   return getTorchBuildKernelMetadata();
 }

@@ -759,6 +759,22 @@ def aten_NllLoss2dBackwardOp: aten_Op<"nll_loss2d_backward", [NoSideEffect, Decl
   );
 }
 
+// -----------------------------------------------------------------------------
+// Mutable/view-like ops
+// -----------------------------------------------------------------------------
+
+def aten_FlattenOp: aten_Op<"flatten", [DeclareOpInterfaceMethods<TorchBuildableKernelOpInterface>, DeclareOpInterfaceMethods<TorchKernelOpInterface>, AllowsTypeRefinement]> {
+  let summary = "Recognized op for kernel aten::flatten";
+  let arguments = (ins
+    AnyTorchMutableTensor:$self,
+    AnyTorchIntType:$start_dim,
+    AnyTorchIntType:$end_dim
+  );
+  let results = (outs
+    AnyTorchMutableTensor
+  );
+}
+
 def aten_CopyInplaceOp: aten_Op<"copy.inplace", [DeclareOpInterfaceMethods<TorchBuildableKernelOpInterface>, DeclareOpInterfaceMethods<TorchKernelOpInterface>, AllowsTypeRefinement]> {
   let summary = "Recognized op for kernel aten::copy_";
   let arguments = (ins

@@ -157,14 +157,6 @@ std::map<std::string, uint64_t> ExpandOp::getStatistics() {
   return toReturn;
 }
 
-// flatten can be zero overhead
-std::map<std::string, uint64_t> FlattenOp::getStatistics() {
-  std::map<std::string, uint64_t> toReturn;
-  toReturn["reads"] = toReturn["operand:0:activation_in"] = 0;
-  toReturn["writes"] = toReturn["result:0:activation_out"] = 0;
-  return toReturn;
-}
-
 std::map<std::string, uint64_t> GatherOp::getStatistics() {
   std::map<std::string, uint64_t> toReturn;
   // FIXME: unimplemented

@@ -85,6 +85,15 @@ convertTorchArgType(StringRef sourceTorchType, StringRef targetTorchType,
     return None;
   }
 
+  if (flag & KVC::kMutableTensor) {
+    if (!isTorchTensorType(sourceTorchType) ||
+        !isTorchTensorType(targetTorchType))
+      return None;
+    // If the type is already mutable, passthrough.
+    if (sourceMlirType.isa<Numpy::NdArrayType>())
+      return TypeConversion{sourceMlirType, nullptr};
+  }
+
   // TODO: Special case promotions and conversions.
   return None;
 }

@@ -145,6 +154,20 @@ convertTorchReturnType(StringRef sourceTorchType, StringRef targetTorchType,
     }
   }
 
+  if (flag & KVC::kMutableTensor) {
+    if (!isTorchTensorType(sourceTorchType) ||
+        !isTorchTensorType(targetTorchType)) {
+      LLVM_DEBUG(llvm::dbgs()
+                 << " * Source or target not a Tensor type\n");
+      return None;
+    }
+    // If the type is already mutable, passthrough.
+    if (sourceMlirType.isa<Numpy::NdArrayType>()) {
+      LLVM_DEBUG(llvm::dbgs() << " * Source is already mutable\n");
+      return TypeConversion{sourceMlirType, nullptr};
+    }
+  }
+
   LLVM_DEBUG(llvm::dbgs() << " * Return type conversion fallthrough\n");
   return None;
 }

@@ -13,6 +13,7 @@
 #include "mlir/IR/BlockAndValueMapping.h"
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/BuiltinOps.h"
+#include "mlir/IR/Matchers.h"
 #include "mlir/Transforms/DialectConversion.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
 #include "npcomp/Dialect/ATen/IR/ATenDialect.h"

@@ -76,6 +77,9 @@ struct ValueKnowledge {
       result.elementType = tensorType.getElementType();
       return result;
     }
+    if (auto ndArrayType = type.dyn_cast<Numpy::NdArrayType>()) {
+      return getKnowledgeFromType(ndArrayType.toTensorType());
+    }
     return result;
   }
 

@@ -262,6 +266,39 @@ public:
       knowledge.elementType =
          joinElementTypes(lhs.elementType, rhs.elementType);
       return getLatticeElement(op->getResult(0)).join(knowledge);
+    } else if (auto flatten = dyn_cast<aten::FlattenOp>(op)) {
+      APInt startDimAP, endDimAP;
+      auto operand = operands[0]->getValue();
+      auto knowledge =
+          ValueKnowledge::getPessimisticValueState(op->getContext());
+      knowledge.elementType = operand.elementType;
+      if (operand.hasRank && operand.sizes.size() == 0) {
+        // Rank 0 is special and flattens to rank 1.
+        knowledge.hasRank = true;
+        knowledge.sizes.push_back(kUnknownSize);
+      } else if (operand.hasRank &&
+                 matchPattern(flatten.start_dim(),
+                              m_ConstantInt(&startDimAP)) &&
+                 matchPattern(flatten.end_dim(), m_ConstantInt(&endDimAP))) {
+        int64_t inputRank = operand.sizes.size();
+        int64_t startDim = startDimAP.getSExtValue();
+        int64_t endDim = endDimAP.getSExtValue();
+        if (startDim < 0)
+          startDim += inputRank;
+        if (endDim < 0)
+          endDim += inputRank;
+        // Careful: dimension numbers might be out of bounds.
+        if (0 <= startDim && startDim <= (inputRank - 1) && 0 <= endDim &&
+            endDim <= (inputRank - 1) && startDim <= endDim) {
+          knowledge.hasRank = true;
+          for (auto i = 0; i < startDim; i++)
+            knowledge.sizes.push_back(operand.sizes[i]);
+          knowledge.sizes.push_back(kUnknownSize);
+          for (auto i = endDim + 1; i < inputRank; i++)
+            knowledge.sizes.push_back(operand.sizes[i]);
+        }
+      }
+      return getLatticeElement(op->getResult(0)).join(knowledge);
     }
     // Otherwise, this is an unknown operation. Just mark all results as having
     // reached a pessimistic fixpoint.

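To make the transfer function above concrete, a worked example (it mirrors the flatten_some test case added at the end of this change): for an operand with sizes [3, 2, ?, 5], start_dim = 1 and end_dim = -2, endDim normalizes to 2, the bounds check passes, and the refined sizes become [3, ?, 5]. Note that the collapsed range is conservatively recorded as kUnknownSize even when every input extent is static.

%0 = "aten.flatten"(%arg0, %start, %end)
    : (!numpy.ndarray<[3,2,?,5]:f32>, i64, i64) -> !numpy.ndarray<[3,?,5]:f32>
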
@@ -287,13 +324,29 @@ getTensorTypeFromKnowledge(MLIRContext *context,
   return RankedTensorType::get(value.sizes, value.elementType);
 }
 
+// Get the most refined Numpy::NdArrayType compatible with ValueKnowledge.
+static Numpy::NdArrayType
+getNdArrayTypeFromKnowledge(MLIRContext *context,
+                            LatticeElement<ValueKnowledge> *knowledge) {
+  if (!knowledge)
+    return Numpy::NdArrayType::get(Numpy::AnyDtypeType::get(context));
+
+  const ValueKnowledge &value = knowledge->getValue();
+  if (!value.hasRank)
+    return Numpy::NdArrayType::get(value.elementType);
+  return Numpy::NdArrayType::get(value.elementType,
+                                 llvm::makeArrayRef(value.sizes));
+}
+
 // Get the most refined type compatible with ValueKnowledge, or null if that
 // is not possible.
 static Type getMostRefinedStaticType(Value v, TypeAnalyzer &analyzer) {
   if (v.getType().isa<TensorType>())
     return getTensorTypeFromKnowledge(v.getContext(),
                                       analyzer.lookupLatticeElement(v));
-  // TODO: Support !numpy.ndarray type.
+  if (v.getType().isa<Numpy::NdArrayType>())
+    return getNdArrayTypeFromKnowledge(v.getContext(),
+                                       analyzer.lookupLatticeElement(v));
   return nullptr;
 }
 

@@ -308,15 +361,28 @@ void optimize(FuncOp func, TypeAnalyzer &analyzer) {
       // Type is same as existing one? Nothing to do.
       if (refinedType == originalType)
        continue;
-      // If the type is a TensorType, then we have numpy.tensor_static_info_cast
-      // to add/remove the static information. We make sure to always embed the
-      // static information in the IR, and insert the minimal number of casts
-      // needed to do so.
-      // TODO: This logic should generalize easily to other types. We just
-      // need to know which op allows us to add static information and which op
-      // allows us to remove static information (in this case, one op allows
-      // both).
+      // If we have an op that allows adding/removing static information from
+      // this type, then we can rewrite. We make sure to always embed the static
+      // information in the IR, and insert the minimal number of casts needed to
+      // do so.
+      // TODO: For some types, we will need 2 ops here: one to add static
+      // information, and the other to remove static information.
+      // (for example, torch.unchecked_cast / torch.derefine for torch.optional
+      // types).
+      std::function<Value(Location, Type, Value)> createStaticInfoCast;
+      OpBuilder b(op->getBlock(), std::next(op->getIterator()));
       if (originalType.isa<TensorType>()) {
+        createStaticInfoCast = [&](Location loc, Type newType,
+                                   Value v) -> Value {
+          return b.create<Numpy::TensorStaticInfoCastOp>(loc, newType, v);
+        };
+      } else if (originalType.isa<Numpy::NdArrayType>()) {
+        createStaticInfoCast = [&](Location loc, Type newType,
+                                   Value v) -> Value {
+          return b.create<Numpy::StaticInfoCastOp>(loc, newType, v);
+        };
+      }
+      if (createStaticInfoCast) {
         // Save off the original uses to avoid iterator invalidation issues
         // or other unexpected behavior since we are creating new ops here that
         // use the value.

|
||||||
OpBuilder b(op->getBlock(), std::next(op->getIterator()));
|
OpBuilder b(op->getBlock(), std::next(op->getIterator()));
|
||||||
Value newTypedValue;
|
Value newTypedValue;
|
||||||
// Always make sure that the new static information is reflected in the
|
// Always make sure that the new static information is reflected in the
|
||||||
// IR, either by updating the type in place, or inserting a
|
// IR, either by updating the type in place, or inserting a static info
|
||||||
// numpy.tensor_static_info_cast op.
|
// cast.
|
||||||
if (allowsTypeRefinement(op)) {
|
if (allowsTypeRefinement(op)) {
|
||||||
newTypedValue = v;
|
newTypedValue = v;
|
||||||
v.setType(refinedType);
|
v.setType(refinedType);
|
||||||
} else {
|
} else {
|
||||||
newTypedValue = b.create<Numpy::TensorStaticInfoCastOp>(
|
newTypedValue = createStaticInfoCast(op->getLoc(), refinedType, v);
|
||||||
op->getLoc(), refinedType, v);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
Value oldTypedValue;
|
Value oldTypedValue;
|
||||||
|
@@ -345,8 +410,8 @@ void optimize(FuncOp func, TypeAnalyzer &analyzer) {
          // If needed, create a value of the original type to appease users
          // that cannot accept the new type.
          if (!oldTypedValue) {
-           oldTypedValue = b.create<Numpy::TensorStaticInfoCastOp>(
-               op->getLoc(), originalType, newTypedValue);
+           oldTypedValue =
+               createStaticInfoCast(op->getLoc(), originalType, newTypedValue);
          }
          use->set(oldTypedValue);
        }

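The createStaticInfoCast callback above exists because the cast op differs by type: tensor values use numpy.tensor_static_info_cast, while !numpy.ndarray values use numpy.static_info_cast. A minimal sketch of the ndarray case, matching what the flatten_all test below checks: after the result type is refined in place, a cast back to the original unrefined type is inserted for any user that cannot accept the new type.

%flattened = "aten.flatten"(%arg0, %start, %end)
    : (!numpy.ndarray<[3,2,?,5]:f32>, i64, i64) -> !numpy.ndarray<[?]:f32>
%erased = numpy.static_info_cast %flattened
    : !numpy.ndarray<[?]:f32> to !numpy.ndarray<*:!numpy.any_dtype>
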
@@ -147,3 +147,18 @@ func @inplace_variant(%arg0: !numpy.ndarray<[2,2]:f32>, %arg1: !numpy.ndarray<[2
   // CHECK: return %[[LHS_OUT]], %[[LHS_OUT]] : !numpy.ndarray<[2,2]:f32>, !numpy.ndarray<[2,2]:f32>
   return %0, %arg0 : !numpy.ndarray<[2,2]:f32>, !numpy.ndarray<[2,2]:f32>
 }
+
+// -----
+
+// CHECK-LABEL: func @mutable_tensor(
+// CHECK-SAME:    %[[ARG:.*]]: !numpy.ndarray<*:!numpy.any_dtype>) -> !numpy.ndarray<*:!numpy.any_dtype> {
+// CHECK:         %[[CM1:.*]] = constant -1 : i64
+// CHECK:         %[[C1:.*]] = constant 1 : i64
+// CHECK:         %[[RET:.*]] = "aten.flatten"(%[[ARG]], %[[C1]], %[[CM1]]) : (!numpy.ndarray<*:!numpy.any_dtype>, i64, i64) -> !numpy.ndarray<*:!numpy.any_dtype>
+// CHECK:         return %[[RET]] : !numpy.ndarray<*:!numpy.any_dtype>
+func @mutable_tensor(%arg0: !numpy.ndarray<*:!numpy.any_dtype>) -> !numpy.ndarray<*:!numpy.any_dtype> {
+  %c-1_i64 = constant -1 : i64
+  %c1_i64 = constant 1 : i64
+  %0 = torch.kernel_call "aten::flatten" %arg0, %c1_i64, %c-1_i64 : (!numpy.ndarray<*:!numpy.any_dtype>, i64, i64) -> !numpy.ndarray<*:!numpy.any_dtype> {sigArgTypes = ["Tensor", "int", "int"], sigIsMutable = false, sigIsVararg = false, sigIsVarret = false, sigRetTypes = ["Tensor"]}
+  return %0 : !numpy.ndarray<*:!numpy.any_dtype>
+}

@@ -109,6 +109,38 @@ func @f(%arg0: tensor<?x?x?x?xf32>) -> tensor<*x!numpy.any_dtype> {
 
 // -----
 
+// Also test cast insertion for array types.
+// CHECK-LABEL: func @flatten_all(
+// CHECK: %[[FLATTENED:.*]] = "aten.flatten"{{.*}}-> !numpy.ndarray<[?]:f32>
+// CHECK: %[[SHAPE_ERASED:.*]] = numpy.static_info_cast %[[FLATTENED]] : !numpy.ndarray<[?]:f32> to !numpy.ndarray<*:!numpy.any_dtype>
+// CHECK: return %[[SHAPE_ERASED]]
+func @flatten_all(%arg0: !numpy.ndarray<[3,2,?,5]:f32>) -> !numpy.ndarray<*:!numpy.any_dtype> {
+  %end = constant -1 : i64
+  %start = constant 0 : i64
+  %0 = "aten.flatten"(%arg0, %start, %end) : (!numpy.ndarray<[3,2,?,5]:f32>, i64, i64) -> !numpy.ndarray<*:!numpy.any_dtype>
+  return %0 : !numpy.ndarray<*:!numpy.any_dtype>
+}
+
+// CHECK-LABEL: func @flatten_some(
+// CHECK: "aten.flatten"{{.*}}-> !numpy.ndarray<[3,?,5]:f32>
+func @flatten_some(%arg0: !numpy.ndarray<[3,2,?,5]:f32>) -> !numpy.ndarray<*:!numpy.any_dtype> {
+  %end = constant -2 : i64
+  %start = constant 1 : i64
+  %0 = "aten.flatten"(%arg0, %start, %end) : (!numpy.ndarray<[3,2,?,5]:f32>, i64, i64) -> !numpy.ndarray<*:!numpy.any_dtype>
+  return %0 : !numpy.ndarray<*:!numpy.any_dtype>
+}
+
+// CHECK-LABEL: func @flatten_rank0(
+// CHECK: "aten.flatten"{{.*}}-> !numpy.ndarray<[?]:f32>
+func @flatten_rank0(%arg0: !numpy.ndarray<[]:f32>) -> !numpy.ndarray<*:!numpy.any_dtype> {
+  %end = constant -1 : i64
+  %start = constant 0 : i64
+  %0 = "aten.flatten"(%arg0, %start, %end) : (!numpy.ndarray<[]:f32>, i64, i64) -> !numpy.ndarray<*:!numpy.any_dtype>
+  return %0 : !numpy.ndarray<*:!numpy.any_dtype>
+}
+
+// -----
+
 // CHECK-LABEL: func @f
 func @f(%arg0: tensor<4x6x3xf32>, %arg1: tensor<1x1x3xf32>, %arg2: tensor<?x3xf32>) {
   %c1_i64 = constant 1 : i64