mirror of https://github.com/llvm/torch-mlir
[MLIR][TORCH] Add E2E support for aten.new_empty op
This commit decomposes `aten.new_empty` op into `aten.empty.memory_format` op. This commit also made a dtype fix to the constant tensor allocation like ops. Earlier the dtype for the result was inferred from the result type; now, it's being evaluated as per the original definition of the op. Signed-Off By: Vivek Khandelwal <vivek@nod-labs.com>pull/658/head snapshot-20220330.357
parent
140babd952
commit
2597c481f6
|
@ -4204,6 +4204,34 @@ def Torch_AtenEmptyLikeOp : Torch_Op<"aten.empty_like", [
|
||||||
}];
|
}];
|
||||||
}
|
}
|
||||||
|
|
||||||
|
def Torch_AtenNewEmptyOp : Torch_Op<"aten.new_empty", [
|
||||||
|
AllowsTypeRefinement,
|
||||||
|
HasValueSemantics,
|
||||||
|
ReadOnly
|
||||||
|
]> {
|
||||||
|
let summary = "Generated op for `aten::new_empty : (Tensor, int[], int?, int?, Device?, bool?) -> (Tensor)`";
|
||||||
|
let arguments = (ins
|
||||||
|
AnyTorchTensorType:$self,
|
||||||
|
TorchIntListType:$size,
|
||||||
|
TorchOptionalIntType:$dtype,
|
||||||
|
TorchOptionalIntType:$layout,
|
||||||
|
TorchOptionalDeviceType:$device,
|
||||||
|
TorchOptionalBoolType:$pin_memory
|
||||||
|
);
|
||||||
|
let results = (outs
|
||||||
|
AnyTorchTensorType:$result
|
||||||
|
);
|
||||||
|
let hasCustomAssemblyFormat = 1;
|
||||||
|
let extraClassDefinition = [{
|
||||||
|
ParseResult AtenNewEmptyOp::parse(OpAsmParser &parser, OperationState &result) {
|
||||||
|
return parseDefaultTorchOp(parser, result, 6, 1);
|
||||||
|
}
|
||||||
|
void AtenNewEmptyOp::print(OpAsmPrinter &printer) {
|
||||||
|
printDefaultTorchOp(printer, *this, 6, 1);
|
||||||
|
}
|
||||||
|
}];
|
||||||
|
}
|
||||||
|
|
||||||
def Torch_AtenZerosLikeOp : Torch_Op<"aten.zeros_like", [
|
def Torch_AtenZerosLikeOp : Torch_Op<"aten.zeros_like", [
|
||||||
AllowsTypeRefinement,
|
AllowsTypeRefinement,
|
||||||
HasValueSemantics,
|
HasValueSemantics,
|
||||||
|
|
|
@ -27,6 +27,11 @@ bool getListConstructElements(Value v, SmallVectorImpl<Value> &elems);
|
||||||
llvm::Optional<int64_t> matchLegalConstantIndexIntoListOfSize(Value v,
|
llvm::Optional<int64_t> matchLegalConstantIndexIntoListOfSize(Value v,
|
||||||
int64_t length);
|
int64_t length);
|
||||||
torch_upstream::ScalarType getScalarTypeForType(Type type);
|
torch_upstream::ScalarType getScalarTypeForType(Type type);
|
||||||
|
Type getTypeForScalarType(
|
||||||
|
MLIRContext *context, torch_upstream::ScalarType dtypeInt,
|
||||||
|
mlir::IntegerType::SignednessSemantics signedness = IntegerType::Signed);
|
||||||
|
Value getDtypeIntValueForType(PatternRewriter &rewriter, Location loc,
|
||||||
|
Type dtype);
|
||||||
// Helper to convert a tensor to a specific scalar type.
|
// Helper to convert a tensor to a specific scalar type.
|
||||||
Value convertTensorToDtype(PatternRewriter &rewriter, Location loc, Value input,
|
Value convertTensorToDtype(PatternRewriter &rewriter, Location loc, Value input,
|
||||||
Type dtype);
|
Type dtype);
|
||||||
|
|
|
@ -148,13 +148,24 @@ public:
|
||||||
|
|
||||||
auto resultType = typeConverter->convertType(op.getType())
|
auto resultType = typeConverter->convertType(op.getType())
|
||||||
.template cast<RankedTensorType>();
|
.template cast<RankedTensorType>();
|
||||||
Type outElemType = resultType.getElementType();
|
Type resultElementType;
|
||||||
|
if (op.dtype().getType().template isa<Torch::NoneType>()) {
|
||||||
|
resultElementType = resultType.getElementType();
|
||||||
|
} else {
|
||||||
|
int64_t dtypeInt;
|
||||||
|
if (!matchPattern(op.dtype(), m_TorchConstantInt(&dtypeInt)))
|
||||||
|
return rewriter.notifyMatchFailure(
|
||||||
|
op, "unimplemented: dtype must be a constant integer or none");
|
||||||
|
resultElementType = getTypeForScalarType(
|
||||||
|
op->getContext(), (torch_upstream::ScalarType)dtypeInt,
|
||||||
|
IntegerType::Signless);
|
||||||
|
}
|
||||||
|
|
||||||
// Create an uninitialized tensor of `resultSize` shape and fill it with
|
// Create an uninitialized tensor of `resultSize` shape and fill it with
|
||||||
// value `fillVal`.
|
// value `fillVal`.
|
||||||
Value constVal = getConstant(rewriter, loc, fillVal, outElemType);
|
Value constVal = getConstant(rewriter, loc, fillVal, resultElementType);
|
||||||
Value outputTensor =
|
Value outputTensor =
|
||||||
createInitTensor(rewriter, loc, resultSizeIndex, outElemType, constVal);
|
createInitTensor(rewriter, loc, resultSizeIndex, resultElementType, constVal);
|
||||||
rewriter.replaceOpWithNewOp<tensor::CastOp>(op, resultType, outputTensor);
|
rewriter.replaceOpWithNewOp<tensor::CastOp>(op, resultType, outputTensor);
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
|
@ -207,11 +218,24 @@ public:
|
||||||
for (auto size : resultSize)
|
for (auto size : resultSize)
|
||||||
resultSizeIndex.push_back(castIntToIndex(rewriter, loc, size));
|
resultSizeIndex.push_back(castIntToIndex(rewriter, loc, size));
|
||||||
|
|
||||||
auto resultType = typeConverter->convertType(op.getType())
|
auto resultType =
|
||||||
.template cast<RankedTensorType>();
|
typeConverter->convertType(op.getType()).cast<RankedTensorType>();
|
||||||
|
Type resultElementType;
|
||||||
|
if (op.dtype().getType().isa<Torch::NoneType>()) {
|
||||||
|
resultElementType = resultType.getElementType();
|
||||||
|
} else {
|
||||||
|
int64_t dtypeInt;
|
||||||
|
if (!matchPattern(op.dtype(), m_TorchConstantInt(&dtypeInt)))
|
||||||
|
return rewriter.notifyMatchFailure(
|
||||||
|
op, "unimplemented: dtype must be a constant integer or none");
|
||||||
|
resultElementType = getTypeForScalarType(
|
||||||
|
op->getContext(), (torch_upstream::ScalarType)dtypeInt,
|
||||||
|
IntegerType::Signless);
|
||||||
|
}
|
||||||
|
|
||||||
// Create an uninitialized tensor of `resultSize` shape.
|
// Create an uninitialized tensor of `resultSize` shape.
|
||||||
Value initTensor = rewriter.create<linalg::InitTensorOp>(
|
Value initTensor = rewriter.create<linalg::InitTensorOp>(
|
||||||
loc, resultSizeIndex, resultType.getElementType());
|
loc, resultSizeIndex, resultElementType);
|
||||||
rewriter.replaceOpWithNewOp<tensor::CastOp>(op, resultType, initTensor);
|
rewriter.replaceOpWithNewOp<tensor::CastOp>(op, resultType, initTensor);
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
|
|
|
@ -1442,8 +1442,15 @@ class DecomposeConstantTensorNewLikeOp : public OpRewritePattern<OpTy> {
|
||||||
using OpRewritePattern<OpTy>::OpRewritePattern;
|
using OpRewritePattern<OpTy>::OpRewritePattern;
|
||||||
LogicalResult matchAndRewrite(OpTy op,
|
LogicalResult matchAndRewrite(OpTy op,
|
||||||
PatternRewriter &rewriter) const override {
|
PatternRewriter &rewriter) const override {
|
||||||
rewriter.replaceOpWithNewOp<NewOpTy>(op, op.getType(), op.size(),
|
Value dtype = op.dtype();
|
||||||
op.dtype(), op.layout(), op.device(),
|
if (dtype.getType().isa<Torch::NoneType>()) {
|
||||||
|
BaseTensorType tensorType =
|
||||||
|
op.self().getType().template cast<BaseTensorType>();
|
||||||
|
dtype =
|
||||||
|
getDtypeIntValueForType(rewriter, op.getLoc(), tensorType.getDtype());
|
||||||
|
}
|
||||||
|
rewriter.replaceOpWithNewOp<NewOpTy>(op, op.getType(), op.size(), dtype,
|
||||||
|
op.layout(), op.device(),
|
||||||
op.pin_memory());
|
op.pin_memory());
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
|
@ -1536,6 +1543,27 @@ public:
|
||||||
};
|
};
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
// Decompose `aten.new_empty` op into `aten.empty.memory_format` op.
|
||||||
|
class DecomposeAtenNewEmptyOp : public OpRewritePattern<AtenNewEmptyOp> {
|
||||||
|
using OpRewritePattern::OpRewritePattern;
|
||||||
|
LogicalResult matchAndRewrite(AtenNewEmptyOp op,
|
||||||
|
PatternRewriter &rewriter) const override {
|
||||||
|
Value noneVal = rewriter.create<ConstantNoneOp>(op.getLoc());
|
||||||
|
Value dtype = op.dtype();
|
||||||
|
if (dtype.getType().isa<Torch::NoneType>()) {
|
||||||
|
BaseTensorType tensorType = op.self().getType().cast<BaseTensorType>();
|
||||||
|
dtype =
|
||||||
|
getDtypeIntValueForType(rewriter, op.getLoc(), tensorType.getDtype());
|
||||||
|
}
|
||||||
|
rewriter.replaceOpWithNewOp<AtenEmptyMemoryFormatOp>(
|
||||||
|
op, op.getType(), op.size(), dtype, op.layout(), op.device(),
|
||||||
|
op.pin_memory(), /*memory_format=*/noneVal);
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
} // namespace
|
||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
class DecomposeComplexOpsPass
|
class DecomposeComplexOpsPass
|
||||||
: public DecomposeComplexOpsBase<DecomposeComplexOpsPass> {
|
: public DecomposeComplexOpsBase<DecomposeComplexOpsPass> {
|
||||||
|
@ -1651,6 +1679,8 @@ class DecomposeComplexOpsPass
|
||||||
target.addIllegalOp<Aten_ToCopyOp>();
|
target.addIllegalOp<Aten_ToCopyOp>();
|
||||||
patterns.add<DecomposeAtenDropoutOp>(context);
|
patterns.add<DecomposeAtenDropoutOp>(context);
|
||||||
target.addIllegalOp<AtenDropoutOp>();
|
target.addIllegalOp<AtenDropoutOp>();
|
||||||
|
target.addIllegalOp<AtenNewEmptyOp>();
|
||||||
|
patterns.add<DecomposeAtenNewEmptyOp>(context);
|
||||||
|
|
||||||
if (failed(applyPartialConversion(getOperation(), target,
|
if (failed(applyPartialConversion(getOperation(), target,
|
||||||
std::move(patterns)))) {
|
std::move(patterns)))) {
|
||||||
|
|
|
@ -78,24 +78,6 @@ using namespace mlir::torch::Torch;
|
||||||
// Analysis.
|
// Analysis.
|
||||||
// -----------------------------------------------------------------------------
|
// -----------------------------------------------------------------------------
|
||||||
|
|
||||||
static Type getTypeForScalarType(MLIRContext *context,
|
|
||||||
torch_upstream::ScalarType dtypeInt) {
|
|
||||||
switch (dtypeInt) {
|
|
||||||
case torch_upstream::ScalarType::Float:
|
|
||||||
return Float32Type::get(context);
|
|
||||||
case torch_upstream::ScalarType::Double:
|
|
||||||
return Float64Type::get(context);
|
|
||||||
case torch_upstream::ScalarType::Long:
|
|
||||||
return IntegerType::get(context, 64, IntegerType::Signed);
|
|
||||||
case torch_upstream::ScalarType::Int:
|
|
||||||
return IntegerType::get(context, 32, IntegerType::Signed);
|
|
||||||
case torch_upstream::ScalarType::Bool:
|
|
||||||
return IntegerType::get(context, 1);
|
|
||||||
default:
|
|
||||||
return Type();
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
static Type getTypeForDTypeInteger(MLIRContext *context, int64_t dtypeInt) {
|
static Type getTypeForDTypeInteger(MLIRContext *context, int64_t dtypeInt) {
|
||||||
return getTypeForScalarType(context, (torch_upstream::ScalarType)dtypeInt);
|
return getTypeForScalarType(context, (torch_upstream::ScalarType)dtypeInt);
|
||||||
}
|
}
|
||||||
|
@ -759,6 +741,8 @@ ChangeResult TypeAnalyzer::visitOperation(
|
||||||
return visitConstantTensorNewLikeOp<AtenNewZerosOp>(newZeros, operands);
|
return visitConstantTensorNewLikeOp<AtenNewZerosOp>(newZeros, operands);
|
||||||
} else if (auto newOnes = dyn_cast<AtenNewOnesOp>(op)) {
|
} else if (auto newOnes = dyn_cast<AtenNewOnesOp>(op)) {
|
||||||
return visitConstantTensorNewLikeOp<AtenNewOnesOp>(newOnes, operands);
|
return visitConstantTensorNewLikeOp<AtenNewOnesOp>(newOnes, operands);
|
||||||
|
} else if (auto newEmpty = dyn_cast<AtenNewEmptyOp>(op)) {
|
||||||
|
return visitConstantTensorNewLikeOp<AtenNewEmptyOp>(newEmpty, operands);
|
||||||
} else if (auto randLike = dyn_cast<AtenRandLikeOp>(op)) {
|
} else if (auto randLike = dyn_cast<AtenRandLikeOp>(op)) {
|
||||||
return visitConstantTensorAllocLikeOp<AtenRandLikeOp>(randLike, operands);
|
return visitConstantTensorAllocLikeOp<AtenRandLikeOp>(randLike, operands);
|
||||||
} else if (auto toCopy = dyn_cast<Aten_ToCopyOp>(op)) {
|
} else if (auto toCopy = dyn_cast<Aten_ToCopyOp>(op)) {
|
||||||
|
|
|
@ -1809,6 +1809,9 @@ module {
|
||||||
func @"__torch_mlir_shape_fn.aten.new_ones"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.optional<int>, %arg3: !torch.optional<int>, %arg4: !torch.optional<Device>, %arg5: !torch.optional<bool>) -> !torch.list<int> {
|
func @"__torch_mlir_shape_fn.aten.new_ones"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.optional<int>, %arg3: !torch.optional<int>, %arg4: !torch.optional<Device>, %arg5: !torch.optional<bool>) -> !torch.list<int> {
|
||||||
return %arg1 : !torch.list<int>
|
return %arg1 : !torch.list<int>
|
||||||
}
|
}
|
||||||
|
func @"__torch_mlir_shape_fn.aten.new_empty"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.optional<int>, %arg3: !torch.optional<int>, %arg4: !torch.optional<Device>, %arg5: !torch.optional<bool>) -> !torch.list<int> {
|
||||||
|
return %arg1 : !torch.list<int>
|
||||||
|
}
|
||||||
func @"__torch_mlir_shape_fn.aten._to_copy"(%arg0: !torch.list<int>, %arg1: !torch.optional<int>, %arg2: !torch.optional<int>, %arg3: !torch.optional<Device>, %arg4: !torch.optional<bool>, %arg5: !torch.bool, %arg6: !torch.optional<int>) -> !torch.list<int> {
|
func @"__torch_mlir_shape_fn.aten._to_copy"(%arg0: !torch.list<int>, %arg1: !torch.optional<int>, %arg2: !torch.optional<int>, %arg3: !torch.optional<Device>, %arg4: !torch.optional<bool>, %arg5: !torch.bool, %arg6: !torch.optional<int>) -> !torch.list<int> {
|
||||||
%0 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.upstream_shape_helpers.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>
|
%0 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.upstream_shape_helpers.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>
|
||||||
return %0 : !torch.list<int>
|
return %0 : !torch.list<int>
|
||||||
|
|
|
@ -55,7 +55,26 @@ torch_upstream::ScalarType Torch::getScalarTypeForType(Type type) {
|
||||||
llvm::report_fatal_error("unhandled type for getScalarTypeForType");
|
llvm::report_fatal_error("unhandled type for getScalarTypeForType");
|
||||||
}
|
}
|
||||||
|
|
||||||
static Value getDtypeIntValueForType(PatternRewriter &rewriter, Location loc,
|
Type Torch::getTypeForScalarType(
|
||||||
|
MLIRContext *context, torch_upstream::ScalarType dtypeInt,
|
||||||
|
mlir::IntegerType::SignednessSemantics signedness) {
|
||||||
|
switch (dtypeInt) {
|
||||||
|
case torch_upstream::ScalarType::Float:
|
||||||
|
return Float32Type::get(context);
|
||||||
|
case torch_upstream::ScalarType::Double:
|
||||||
|
return Float64Type::get(context);
|
||||||
|
case torch_upstream::ScalarType::Long:
|
||||||
|
return IntegerType::get(context, 64, signedness);
|
||||||
|
case torch_upstream::ScalarType::Int:
|
||||||
|
return IntegerType::get(context, 32, signedness);
|
||||||
|
case torch_upstream::ScalarType::Bool:
|
||||||
|
return IntegerType::get(context, 1);
|
||||||
|
default:
|
||||||
|
return Type();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
Value Torch::getDtypeIntValueForType(PatternRewriter &rewriter, Location loc,
|
||||||
Type dtype) {
|
Type dtype) {
|
||||||
int intType = (int)getScalarTypeForType(dtype);
|
int intType = (int)getScalarTypeForType(dtype);
|
||||||
return rewriter.create<ConstantIntOp>(loc,
|
return rewriter.create<ConstantIntOp>(loc,
|
||||||
|
|
|
@ -605,6 +605,9 @@ def aten〇new_zeros(self: List[int], size: List[int], dtype: Optional[int] = No
|
||||||
def aten〇new_ones(self: List[int], size: List[int], dtype: Optional[int] = None, layout: Optional[int] = None, device: Optional[device] = None, pin_memory: Optional[bool] = None) -> List[int]:
|
def aten〇new_ones(self: List[int], size: List[int], dtype: Optional[int] = None, layout: Optional[int] = None, device: Optional[device] = None, pin_memory: Optional[bool] = None) -> List[int]:
|
||||||
return size
|
return size
|
||||||
|
|
||||||
|
def aten〇new_empty(self: List[int], size: List[int], dtype: Optional[int] = None, layout: Optional[int] = None, device: Optional[device] = None, pin_memory: Optional[bool] = None) -> List[int]:
|
||||||
|
return size
|
||||||
|
|
||||||
def aten〇_to_copy(self: List[int], dtype: Optional[int] = None, layout: Optional[int] = None, device: Optional[device] = None, pin_memory: Optional[bool] = None, non_blocking: bool = False, memory_format: Optional[int] = None) -> List[int]:
|
def aten〇_to_copy(self: List[int], dtype: Optional[int] = None, layout: Optional[int] = None, device: Optional[device] = None, pin_memory: Optional[bool] = None, non_blocking: bool = False, memory_format: Optional[int] = None) -> List[int]:
|
||||||
return upstream_shape_helpers.unary(self)
|
return upstream_shape_helpers.unary(self)
|
||||||
|
|
||||||
|
|
|
@ -383,6 +383,7 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
|
||||||
emit("aten::detach : (Tensor) -> (Tensor)")
|
emit("aten::detach : (Tensor) -> (Tensor)")
|
||||||
emit("aten::embedding : (Tensor, Tensor, int, bool, bool) -> (Tensor)")
|
emit("aten::embedding : (Tensor, Tensor, int, bool, bool) -> (Tensor)")
|
||||||
emit("aten::empty_like : (Tensor, int?, int?, Device?, bool?, int?) -> (Tensor)")
|
emit("aten::empty_like : (Tensor, int?, int?, Device?, bool?, int?) -> (Tensor)")
|
||||||
|
emit("aten::new_empty : (Tensor, int[], int?, int?, Device?, bool?) -> (Tensor)")
|
||||||
emit("aten::zeros_like : (Tensor, int?, int?, Device?, bool?, int?) -> (Tensor)")
|
emit("aten::zeros_like : (Tensor, int?, int?, Device?, bool?, int?) -> (Tensor)")
|
||||||
emit("aten::ones_like : (Tensor, int?, int?, Device?, bool?, int?) -> (Tensor)")
|
emit("aten::ones_like : (Tensor, int?, int?, Device?, bool?, int?) -> (Tensor)")
|
||||||
emit("aten::empty.memory_format : (int[], int?, int?, Device?, bool?, int?) -> (Tensor)")
|
emit("aten::empty.memory_format : (int[], int?, int?, Device?, bool?, int?) -> (Tensor)")
|
||||||
|
|
|
@ -1006,3 +1006,140 @@ class ZeroInt64Module(torch.nn.Module):
|
||||||
@register_test_case(module_factory=lambda: ZeroInt64Module())
|
@register_test_case(module_factory=lambda: ZeroInt64Module())
|
||||||
def ZeroInt64Module_basic(module, tu: TestUtils):
|
def ZeroInt64Module_basic(module, tu: TestUtils):
|
||||||
module.forward(torch.randint(100, (10, 4)))
|
module.forward(torch.randint(100, (10, 4)))
|
||||||
|
|
||||||
|
# ==============================================================================
|
||||||
|
|
||||||
|
class NewEmptyModuleDefaultDtype(torch.nn.Module):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
@export
|
||||||
|
@annotate_args([
|
||||||
|
None,
|
||||||
|
([-1, -1], torch.float32, True),
|
||||||
|
])
|
||||||
|
def forward(self, a):
|
||||||
|
return torch.ops.aten.new_empty(a, [3, 4]).fill_(0)
|
||||||
|
|
||||||
|
@register_test_case(module_factory=lambda: NewEmptyModuleDefaultDtype())
|
||||||
|
def NewEmptyModuleDefaultDtype_basic(module, tu: TestUtils):
|
||||||
|
module.forward(tu.rand(2, 3))
|
||||||
|
|
||||||
|
|
||||||
|
class NewEmptyModuleInt2D(torch.nn.Module):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
@export
|
||||||
|
@annotate_args([
|
||||||
|
None,
|
||||||
|
([-1, -1, -1], torch.float32, True),
|
||||||
|
])
|
||||||
|
def forward(self, a):
|
||||||
|
return torch.ops.aten.new_empty(a, [3, 4], dtype=torch.int64).fill_(0)
|
||||||
|
|
||||||
|
@register_test_case(module_factory=lambda: NewEmptyModuleInt2D())
|
||||||
|
def NewEmptyModuleInt2D_basic(module, tu: TestUtils):
|
||||||
|
module.forward(tu.rand(2, 3, 4))
|
||||||
|
|
||||||
|
|
||||||
|
class NewEmptyModuleInt3D(torch.nn.Module):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
@export
|
||||||
|
@annotate_args([
|
||||||
|
None,
|
||||||
|
([-1, -1], torch.float32, True),
|
||||||
|
])
|
||||||
|
def forward(self, a):
|
||||||
|
return torch.ops.aten.new_empty(a, [3, 4, 5], dtype=torch.int64).fill_(0)
|
||||||
|
|
||||||
|
@register_test_case(module_factory=lambda: NewEmptyModuleInt3D())
|
||||||
|
def NewEmptyModuleInt3D_basic(module, tu: TestUtils):
|
||||||
|
module.forward(tu.rand(2, 3))
|
||||||
|
|
||||||
|
|
||||||
|
class NewEmptyModuleFloat2D(torch.nn.Module):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
@export
|
||||||
|
@annotate_args([
|
||||||
|
None,
|
||||||
|
([-1, -1, -1], torch.int64, True),
|
||||||
|
])
|
||||||
|
def forward(self, a):
|
||||||
|
return torch.ops.aten.new_empty(a, [3, 4], dtype=torch.float32).fill_(0)
|
||||||
|
|
||||||
|
@register_test_case(module_factory=lambda: NewEmptyModuleFloat2D())
|
||||||
|
def NewEmptyModuleFloat2D_basic(module, tu: TestUtils):
|
||||||
|
module.forward(torch.randint(10, (2, 3, 4)))
|
||||||
|
|
||||||
|
|
||||||
|
class NewEmptyModuleFloat3D(torch.nn.Module):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
@export
|
||||||
|
@annotate_args([
|
||||||
|
None,
|
||||||
|
([-1, -1], torch.int64, True),
|
||||||
|
])
|
||||||
|
def forward(self, a):
|
||||||
|
return torch.ops.aten.new_empty(a, [3, 4, 5], dtype=torch.float32).fill_(0)
|
||||||
|
|
||||||
|
@register_test_case(module_factory=lambda: NewEmptyModuleFloat3D())
|
||||||
|
def NewEmptyModuleFloat3D_basic(module, tu: TestUtils):
|
||||||
|
module.forward(torch.randint(10, (2, 3)))
|
||||||
|
|
||||||
|
|
||||||
|
class NewEmptyModuleFalsePinMemory(torch.nn.Module):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
@export
|
||||||
|
@annotate_args([
|
||||||
|
None,
|
||||||
|
([-1, -1], torch.int64, True),
|
||||||
|
])
|
||||||
|
def forward(self, a):
|
||||||
|
return torch.ops.aten.new_empty(a, [3, 4], dtype=torch.float32, pin_memory=False).fill_(0)
|
||||||
|
|
||||||
|
@register_test_case(module_factory=lambda: NewEmptyModuleFalsePinMemory())
|
||||||
|
def NewEmptyModuleFalsePinMemory_basic(module, tu: TestUtils):
|
||||||
|
module.forward(torch.randint(10, (2, 3)))
|
||||||
|
|
||||||
|
|
||||||
|
class NewEmptyModuleNonDefaultFloatDtype(torch.nn.Module):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
@export
|
||||||
|
@annotate_args([
|
||||||
|
None,
|
||||||
|
([-1, -1], torch.float64, True),
|
||||||
|
])
|
||||||
|
def forward(self, a):
|
||||||
|
return torch.ops.aten.new_empty(a, [3, 4]).fill_(0)
|
||||||
|
|
||||||
|
@register_test_case(module_factory=lambda: NewEmptyModuleNonDefaultFloatDtype())
|
||||||
|
def NewEmptyModuleNonDefaultFloatDtype_basic(module, tu: TestUtils):
|
||||||
|
module.forward(tu.rand(2, 3).to(torch.float64))
|
||||||
|
|
||||||
|
|
||||||
|
class NewEmptyModuleNonDefaultIntDtype(torch.nn.Module):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
@export
|
||||||
|
@annotate_args([
|
||||||
|
None,
|
||||||
|
([-1, -1], torch.int32, True),
|
||||||
|
])
|
||||||
|
def forward(self, a):
|
||||||
|
return torch.ops.aten.new_empty(a, [3, 4]).fill_(0)
|
||||||
|
|
||||||
|
@register_test_case(module_factory=lambda: NewEmptyModuleNonDefaultIntDtype())
|
||||||
|
def NewEmptyModuleNonDefaultIntDtype_basic(module, tu: TestUtils):
|
||||||
|
module.forward(torch.randint(10, (2, 3)).to(torch.int32))
|
||||||
|
|
|
@ -582,7 +582,8 @@ func @torch.aten.hardtanh(%arg0: !torch.vtensor<[?],f32>, %min: !torch.float, %m
|
||||||
// CHECK: %[[INT2:.*]] = torch.constant.int 2
|
// CHECK: %[[INT2:.*]] = torch.constant.int 2
|
||||||
// CHECK: %[[INT3:.*]] = torch.constant.int 3
|
// CHECK: %[[INT3:.*]] = torch.constant.int 3
|
||||||
// CHECK: %[[SIZE:.*]] = torch.prim.ListConstruct %[[INT2]], %[[INT3]] : (!torch.int, !torch.int) -> !torch.list<int>
|
// CHECK: %[[SIZE:.*]] = torch.prim.ListConstruct %[[INT2]], %[[INT3]] : (!torch.int, !torch.int) -> !torch.list<int>
|
||||||
// CHECK: %[[RES:.*]] = torch.aten.zeros %[[SIZE]], %[[NONE]], %[[NONE]], %[[NONE]], %[[NONE]] : !torch.list<int>, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[2,3],f32>
|
// CHECK: %[[INT6:.*]] = torch.constant.int 6
|
||||||
|
// CHECK: %[[RES:.*]] = torch.aten.zeros %[[SIZE]], %[[INT6]], %[[NONE]], %[[NONE]], %[[NONE]] : !torch.list<int>, !torch.int, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[2,3],f32>
|
||||||
// CHECK: return %[[RES]] : !torch.vtensor<[2,3],f32>
|
// CHECK: return %[[RES]] : !torch.vtensor<[2,3],f32>
|
||||||
// CHECK: }
|
// CHECK: }
|
||||||
func @torch.aten.new_zeros(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[2,3],f32> {
|
func @torch.aten.new_zeros(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[2,3],f32> {
|
||||||
|
@ -601,7 +602,8 @@ func @torch.aten.new_zeros(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[
|
||||||
// CHECK: %[[INT3:.*]] = torch.constant.int 3
|
// CHECK: %[[INT3:.*]] = torch.constant.int 3
|
||||||
// CHECK: %[[INT4:.*]] = torch.constant.int 4
|
// CHECK: %[[INT4:.*]] = torch.constant.int 4
|
||||||
// CHECK: %[[SIZE:.*]] = torch.prim.ListConstruct %[[INT3]], %[[INT4]] : (!torch.int, !torch.int) -> !torch.list<int>
|
// CHECK: %[[SIZE:.*]] = torch.prim.ListConstruct %[[INT3]], %[[INT4]] : (!torch.int, !torch.int) -> !torch.list<int>
|
||||||
// CHECK: %[[RES:.*]] = torch.aten.ones %[[SIZE]], %[[NONE]], %[[NONE]], %[[NONE]], %[[NONE]] : !torch.list<int>, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[3,4],si64>
|
// CHECK: %[[INT4_0:.*]] = torch.constant.int 4
|
||||||
|
// CHECK: %[[RES:.*]] = torch.aten.ones %[[SIZE]], %[[INT4_0]], %[[NONE]], %[[NONE]], %[[NONE]] : !torch.list<int>, !torch.int, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[3,4],si64>
|
||||||
// CHECK: return %[[RES]] : !torch.vtensor<[3,4],si64>
|
// CHECK: return %[[RES]] : !torch.vtensor<[3,4],si64>
|
||||||
// CHECK: }
|
// CHECK: }
|
||||||
func @torch.aten.new_ones(%arg0: !torch.vtensor<[?,?],si64>) -> !torch.vtensor<[3,4],si64> {
|
func @torch.aten.new_ones(%arg0: !torch.vtensor<[?,?],si64>) -> !torch.vtensor<[3,4],si64> {
|
||||||
|
@ -779,3 +781,22 @@ func @torch.valsem.aten.zero(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor
|
||||||
return %0 : !torch.vtensor<[?,?],f32>
|
return %0 : !torch.vtensor<[?,?],f32>
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
// CHECK-LABEL: func @torch.aten.new_empty
|
||||||
|
// CHECK-SAME: %[[INP:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[2,3],f32> {
|
||||||
|
// CHECK: %[[NONE:.*]] = torch.constant.none
|
||||||
|
// CHECK: %[[INT2:.*]] = torch.constant.int 2
|
||||||
|
// CHECK: %[[INT3:.*]] = torch.constant.int 3
|
||||||
|
// CHECK: %[[SIZE:.*]] = torch.prim.ListConstruct %[[INT2]], %[[INT3]] : (!torch.int, !torch.int) -> !torch.list<int>
|
||||||
|
// CHECK: %[[NONE_0:.*]] = torch.constant.none
|
||||||
|
// CHECK: %[[INT6:.*]] = torch.constant.int 6
|
||||||
|
// CHECK: %[[RES:.*]] = torch.aten.empty.memory_format %[[SIZE]], %[[INT6]], %[[NONE]], %[[NONE]], %[[NONE]], %[[NONE_0]] : !torch.list<int>, !torch.int, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[2,3],f32>
|
||||||
|
// CHECK: return %[[RES]] : !torch.vtensor<[2,3],f32>
|
||||||
|
func @torch.aten.new_empty(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[2,3],f32> {
|
||||||
|
%none = torch.constant.none
|
||||||
|
%int2 = torch.constant.int 2
|
||||||
|
%int3 = torch.constant.int 3
|
||||||
|
%0 = torch.prim.ListConstruct %int2, %int3 : (!torch.int, !torch.int) -> !torch.list<int>
|
||||||
|
%1 = torch.aten.new_empty %arg0, %0, %none, %none, %none, %none : !torch.vtensor<[?,?],f32>, !torch.list<int>, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[2,3],f32>
|
||||||
|
return %1 : !torch.vtensor<[2,3],f32>
|
||||||
|
}
|
||||||
|
|
Loading…
Reference in New Issue