mirror of https://github.com/llvm/torch-mlir
[MLIR][TORCH] Add E2E support for aten.new_empty op
This commit decomposes `aten.new_empty` op into `aten.empty.memory_format` op. This commit also made a dtype fix to the constant tensor allocation like ops. Earlier the dtype for the result was inferred from the result type; now, it's being evaluated as per the original definition of the op. Signed-Off By: Vivek Khandelwal <vivek@nod-labs.com>pull/658/head snapshot-20220330.357
parent
140babd952
commit
2597c481f6
|
@ -4204,6 +4204,34 @@ def Torch_AtenEmptyLikeOp : Torch_Op<"aten.empty_like", [
|
|||
}];
|
||||
}
|
||||
|
||||
def Torch_AtenNewEmptyOp : Torch_Op<"aten.new_empty", [
|
||||
AllowsTypeRefinement,
|
||||
HasValueSemantics,
|
||||
ReadOnly
|
||||
]> {
|
||||
let summary = "Generated op for `aten::new_empty : (Tensor, int[], int?, int?, Device?, bool?) -> (Tensor)`";
|
||||
let arguments = (ins
|
||||
AnyTorchTensorType:$self,
|
||||
TorchIntListType:$size,
|
||||
TorchOptionalIntType:$dtype,
|
||||
TorchOptionalIntType:$layout,
|
||||
TorchOptionalDeviceType:$device,
|
||||
TorchOptionalBoolType:$pin_memory
|
||||
);
|
||||
let results = (outs
|
||||
AnyTorchTensorType:$result
|
||||
);
|
||||
let hasCustomAssemblyFormat = 1;
|
||||
let extraClassDefinition = [{
|
||||
ParseResult AtenNewEmptyOp::parse(OpAsmParser &parser, OperationState &result) {
|
||||
return parseDefaultTorchOp(parser, result, 6, 1);
|
||||
}
|
||||
void AtenNewEmptyOp::print(OpAsmPrinter &printer) {
|
||||
printDefaultTorchOp(printer, *this, 6, 1);
|
||||
}
|
||||
}];
|
||||
}
|
||||
|
||||
def Torch_AtenZerosLikeOp : Torch_Op<"aten.zeros_like", [
|
||||
AllowsTypeRefinement,
|
||||
HasValueSemantics,
|
||||
|
|
|
@ -27,6 +27,11 @@ bool getListConstructElements(Value v, SmallVectorImpl<Value> &elems);
|
|||
llvm::Optional<int64_t> matchLegalConstantIndexIntoListOfSize(Value v,
|
||||
int64_t length);
|
||||
torch_upstream::ScalarType getScalarTypeForType(Type type);
|
||||
Type getTypeForScalarType(
|
||||
MLIRContext *context, torch_upstream::ScalarType dtypeInt,
|
||||
mlir::IntegerType::SignednessSemantics signedness = IntegerType::Signed);
|
||||
Value getDtypeIntValueForType(PatternRewriter &rewriter, Location loc,
|
||||
Type dtype);
|
||||
// Helper to convert a tensor to a specific scalar type.
|
||||
Value convertTensorToDtype(PatternRewriter &rewriter, Location loc, Value input,
|
||||
Type dtype);
|
||||
|
|
|
@ -148,13 +148,24 @@ public:
|
|||
|
||||
auto resultType = typeConverter->convertType(op.getType())
|
||||
.template cast<RankedTensorType>();
|
||||
Type outElemType = resultType.getElementType();
|
||||
Type resultElementType;
|
||||
if (op.dtype().getType().template isa<Torch::NoneType>()) {
|
||||
resultElementType = resultType.getElementType();
|
||||
} else {
|
||||
int64_t dtypeInt;
|
||||
if (!matchPattern(op.dtype(), m_TorchConstantInt(&dtypeInt)))
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "unimplemented: dtype must be a constant integer or none");
|
||||
resultElementType = getTypeForScalarType(
|
||||
op->getContext(), (torch_upstream::ScalarType)dtypeInt,
|
||||
IntegerType::Signless);
|
||||
}
|
||||
|
||||
// Create an uninitialized tensor of `resultSize` shape and fill it with
|
||||
// value `fillVal`.
|
||||
Value constVal = getConstant(rewriter, loc, fillVal, outElemType);
|
||||
Value constVal = getConstant(rewriter, loc, fillVal, resultElementType);
|
||||
Value outputTensor =
|
||||
createInitTensor(rewriter, loc, resultSizeIndex, outElemType, constVal);
|
||||
createInitTensor(rewriter, loc, resultSizeIndex, resultElementType, constVal);
|
||||
rewriter.replaceOpWithNewOp<tensor::CastOp>(op, resultType, outputTensor);
|
||||
return success();
|
||||
}
|
||||
|
@ -207,11 +218,24 @@ public:
|
|||
for (auto size : resultSize)
|
||||
resultSizeIndex.push_back(castIntToIndex(rewriter, loc, size));
|
||||
|
||||
auto resultType = typeConverter->convertType(op.getType())
|
||||
.template cast<RankedTensorType>();
|
||||
auto resultType =
|
||||
typeConverter->convertType(op.getType()).cast<RankedTensorType>();
|
||||
Type resultElementType;
|
||||
if (op.dtype().getType().isa<Torch::NoneType>()) {
|
||||
resultElementType = resultType.getElementType();
|
||||
} else {
|
||||
int64_t dtypeInt;
|
||||
if (!matchPattern(op.dtype(), m_TorchConstantInt(&dtypeInt)))
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "unimplemented: dtype must be a constant integer or none");
|
||||
resultElementType = getTypeForScalarType(
|
||||
op->getContext(), (torch_upstream::ScalarType)dtypeInt,
|
||||
IntegerType::Signless);
|
||||
}
|
||||
|
||||
// Create an uninitialized tensor of `resultSize` shape.
|
||||
Value initTensor = rewriter.create<linalg::InitTensorOp>(
|
||||
loc, resultSizeIndex, resultType.getElementType());
|
||||
loc, resultSizeIndex, resultElementType);
|
||||
rewriter.replaceOpWithNewOp<tensor::CastOp>(op, resultType, initTensor);
|
||||
return success();
|
||||
}
|
||||
|
|
|
@ -1442,8 +1442,15 @@ class DecomposeConstantTensorNewLikeOp : public OpRewritePattern<OpTy> {
|
|||
using OpRewritePattern<OpTy>::OpRewritePattern;
|
||||
LogicalResult matchAndRewrite(OpTy op,
|
||||
PatternRewriter &rewriter) const override {
|
||||
rewriter.replaceOpWithNewOp<NewOpTy>(op, op.getType(), op.size(),
|
||||
op.dtype(), op.layout(), op.device(),
|
||||
Value dtype = op.dtype();
|
||||
if (dtype.getType().isa<Torch::NoneType>()) {
|
||||
BaseTensorType tensorType =
|
||||
op.self().getType().template cast<BaseTensorType>();
|
||||
dtype =
|
||||
getDtypeIntValueForType(rewriter, op.getLoc(), tensorType.getDtype());
|
||||
}
|
||||
rewriter.replaceOpWithNewOp<NewOpTy>(op, op.getType(), op.size(), dtype,
|
||||
op.layout(), op.device(),
|
||||
op.pin_memory());
|
||||
return success();
|
||||
}
|
||||
|
@ -1536,6 +1543,27 @@ public:
|
|||
};
|
||||
} // namespace
|
||||
|
||||
namespace {
|
||||
// Decompose `aten.new_empty` op into `aten.empty.memory_format` op.
|
||||
class DecomposeAtenNewEmptyOp : public OpRewritePattern<AtenNewEmptyOp> {
|
||||
using OpRewritePattern::OpRewritePattern;
|
||||
LogicalResult matchAndRewrite(AtenNewEmptyOp op,
|
||||
PatternRewriter &rewriter) const override {
|
||||
Value noneVal = rewriter.create<ConstantNoneOp>(op.getLoc());
|
||||
Value dtype = op.dtype();
|
||||
if (dtype.getType().isa<Torch::NoneType>()) {
|
||||
BaseTensorType tensorType = op.self().getType().cast<BaseTensorType>();
|
||||
dtype =
|
||||
getDtypeIntValueForType(rewriter, op.getLoc(), tensorType.getDtype());
|
||||
}
|
||||
rewriter.replaceOpWithNewOp<AtenEmptyMemoryFormatOp>(
|
||||
op, op.getType(), op.size(), dtype, op.layout(), op.device(),
|
||||
op.pin_memory(), /*memory_format=*/noneVal);
|
||||
return success();
|
||||
}
|
||||
};
|
||||
} // namespace
|
||||
|
||||
namespace {
|
||||
class DecomposeComplexOpsPass
|
||||
: public DecomposeComplexOpsBase<DecomposeComplexOpsPass> {
|
||||
|
@ -1651,6 +1679,8 @@ class DecomposeComplexOpsPass
|
|||
target.addIllegalOp<Aten_ToCopyOp>();
|
||||
patterns.add<DecomposeAtenDropoutOp>(context);
|
||||
target.addIllegalOp<AtenDropoutOp>();
|
||||
target.addIllegalOp<AtenNewEmptyOp>();
|
||||
patterns.add<DecomposeAtenNewEmptyOp>(context);
|
||||
|
||||
if (failed(applyPartialConversion(getOperation(), target,
|
||||
std::move(patterns)))) {
|
||||
|
|
|
@ -78,24 +78,6 @@ using namespace mlir::torch::Torch;
|
|||
// Analysis.
|
||||
// -----------------------------------------------------------------------------
|
||||
|
||||
static Type getTypeForScalarType(MLIRContext *context,
|
||||
torch_upstream::ScalarType dtypeInt) {
|
||||
switch (dtypeInt) {
|
||||
case torch_upstream::ScalarType::Float:
|
||||
return Float32Type::get(context);
|
||||
case torch_upstream::ScalarType::Double:
|
||||
return Float64Type::get(context);
|
||||
case torch_upstream::ScalarType::Long:
|
||||
return IntegerType::get(context, 64, IntegerType::Signed);
|
||||
case torch_upstream::ScalarType::Int:
|
||||
return IntegerType::get(context, 32, IntegerType::Signed);
|
||||
case torch_upstream::ScalarType::Bool:
|
||||
return IntegerType::get(context, 1);
|
||||
default:
|
||||
return Type();
|
||||
}
|
||||
}
|
||||
|
||||
static Type getTypeForDTypeInteger(MLIRContext *context, int64_t dtypeInt) {
|
||||
return getTypeForScalarType(context, (torch_upstream::ScalarType)dtypeInt);
|
||||
}
|
||||
|
@ -759,6 +741,8 @@ ChangeResult TypeAnalyzer::visitOperation(
|
|||
return visitConstantTensorNewLikeOp<AtenNewZerosOp>(newZeros, operands);
|
||||
} else if (auto newOnes = dyn_cast<AtenNewOnesOp>(op)) {
|
||||
return visitConstantTensorNewLikeOp<AtenNewOnesOp>(newOnes, operands);
|
||||
} else if (auto newEmpty = dyn_cast<AtenNewEmptyOp>(op)) {
|
||||
return visitConstantTensorNewLikeOp<AtenNewEmptyOp>(newEmpty, operands);
|
||||
} else if (auto randLike = dyn_cast<AtenRandLikeOp>(op)) {
|
||||
return visitConstantTensorAllocLikeOp<AtenRandLikeOp>(randLike, operands);
|
||||
} else if (auto toCopy = dyn_cast<Aten_ToCopyOp>(op)) {
|
||||
|
|
|
@ -1809,6 +1809,9 @@ module {
|
|||
func @"__torch_mlir_shape_fn.aten.new_ones"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.optional<int>, %arg3: !torch.optional<int>, %arg4: !torch.optional<Device>, %arg5: !torch.optional<bool>) -> !torch.list<int> {
|
||||
return %arg1 : !torch.list<int>
|
||||
}
|
||||
func @"__torch_mlir_shape_fn.aten.new_empty"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.optional<int>, %arg3: !torch.optional<int>, %arg4: !torch.optional<Device>, %arg5: !torch.optional<bool>) -> !torch.list<int> {
|
||||
return %arg1 : !torch.list<int>
|
||||
}
|
||||
func @"__torch_mlir_shape_fn.aten._to_copy"(%arg0: !torch.list<int>, %arg1: !torch.optional<int>, %arg2: !torch.optional<int>, %arg3: !torch.optional<Device>, %arg4: !torch.optional<bool>, %arg5: !torch.bool, %arg6: !torch.optional<int>) -> !torch.list<int> {
|
||||
%0 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.upstream_shape_helpers.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>
|
||||
return %0 : !torch.list<int>
|
||||
|
|
|
@ -55,7 +55,26 @@ torch_upstream::ScalarType Torch::getScalarTypeForType(Type type) {
|
|||
llvm::report_fatal_error("unhandled type for getScalarTypeForType");
|
||||
}
|
||||
|
||||
static Value getDtypeIntValueForType(PatternRewriter &rewriter, Location loc,
|
||||
Type Torch::getTypeForScalarType(
|
||||
MLIRContext *context, torch_upstream::ScalarType dtypeInt,
|
||||
mlir::IntegerType::SignednessSemantics signedness) {
|
||||
switch (dtypeInt) {
|
||||
case torch_upstream::ScalarType::Float:
|
||||
return Float32Type::get(context);
|
||||
case torch_upstream::ScalarType::Double:
|
||||
return Float64Type::get(context);
|
||||
case torch_upstream::ScalarType::Long:
|
||||
return IntegerType::get(context, 64, signedness);
|
||||
case torch_upstream::ScalarType::Int:
|
||||
return IntegerType::get(context, 32, signedness);
|
||||
case torch_upstream::ScalarType::Bool:
|
||||
return IntegerType::get(context, 1);
|
||||
default:
|
||||
return Type();
|
||||
}
|
||||
}
|
||||
|
||||
Value Torch::getDtypeIntValueForType(PatternRewriter &rewriter, Location loc,
|
||||
Type dtype) {
|
||||
int intType = (int)getScalarTypeForType(dtype);
|
||||
return rewriter.create<ConstantIntOp>(loc,
|
||||
|
|
|
@ -605,6 +605,9 @@ def aten〇new_zeros(self: List[int], size: List[int], dtype: Optional[int] = No
|
|||
def aten〇new_ones(self: List[int], size: List[int], dtype: Optional[int] = None, layout: Optional[int] = None, device: Optional[device] = None, pin_memory: Optional[bool] = None) -> List[int]:
|
||||
return size
|
||||
|
||||
def aten〇new_empty(self: List[int], size: List[int], dtype: Optional[int] = None, layout: Optional[int] = None, device: Optional[device] = None, pin_memory: Optional[bool] = None) -> List[int]:
|
||||
return size
|
||||
|
||||
def aten〇_to_copy(self: List[int], dtype: Optional[int] = None, layout: Optional[int] = None, device: Optional[device] = None, pin_memory: Optional[bool] = None, non_blocking: bool = False, memory_format: Optional[int] = None) -> List[int]:
|
||||
return upstream_shape_helpers.unary(self)
|
||||
|
||||
|
|
|
@ -383,6 +383,7 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
|
|||
emit("aten::detach : (Tensor) -> (Tensor)")
|
||||
emit("aten::embedding : (Tensor, Tensor, int, bool, bool) -> (Tensor)")
|
||||
emit("aten::empty_like : (Tensor, int?, int?, Device?, bool?, int?) -> (Tensor)")
|
||||
emit("aten::new_empty : (Tensor, int[], int?, int?, Device?, bool?) -> (Tensor)")
|
||||
emit("aten::zeros_like : (Tensor, int?, int?, Device?, bool?, int?) -> (Tensor)")
|
||||
emit("aten::ones_like : (Tensor, int?, int?, Device?, bool?, int?) -> (Tensor)")
|
||||
emit("aten::empty.memory_format : (int[], int?, int?, Device?, bool?, int?) -> (Tensor)")
|
||||
|
|
|
@ -1006,3 +1006,140 @@ class ZeroInt64Module(torch.nn.Module):
|
|||
@register_test_case(module_factory=lambda: ZeroInt64Module())
|
||||
def ZeroInt64Module_basic(module, tu: TestUtils):
|
||||
module.forward(torch.randint(100, (10, 4)))
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
class NewEmptyModuleDefaultDtype(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-1, -1], torch.float32, True),
|
||||
])
|
||||
def forward(self, a):
|
||||
return torch.ops.aten.new_empty(a, [3, 4]).fill_(0)
|
||||
|
||||
@register_test_case(module_factory=lambda: NewEmptyModuleDefaultDtype())
|
||||
def NewEmptyModuleDefaultDtype_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(2, 3))
|
||||
|
||||
|
||||
class NewEmptyModuleInt2D(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-1, -1, -1], torch.float32, True),
|
||||
])
|
||||
def forward(self, a):
|
||||
return torch.ops.aten.new_empty(a, [3, 4], dtype=torch.int64).fill_(0)
|
||||
|
||||
@register_test_case(module_factory=lambda: NewEmptyModuleInt2D())
|
||||
def NewEmptyModuleInt2D_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(2, 3, 4))
|
||||
|
||||
|
||||
class NewEmptyModuleInt3D(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-1, -1], torch.float32, True),
|
||||
])
|
||||
def forward(self, a):
|
||||
return torch.ops.aten.new_empty(a, [3, 4, 5], dtype=torch.int64).fill_(0)
|
||||
|
||||
@register_test_case(module_factory=lambda: NewEmptyModuleInt3D())
|
||||
def NewEmptyModuleInt3D_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(2, 3))
|
||||
|
||||
|
||||
class NewEmptyModuleFloat2D(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-1, -1, -1], torch.int64, True),
|
||||
])
|
||||
def forward(self, a):
|
||||
return torch.ops.aten.new_empty(a, [3, 4], dtype=torch.float32).fill_(0)
|
||||
|
||||
@register_test_case(module_factory=lambda: NewEmptyModuleFloat2D())
|
||||
def NewEmptyModuleFloat2D_basic(module, tu: TestUtils):
|
||||
module.forward(torch.randint(10, (2, 3, 4)))
|
||||
|
||||
|
||||
class NewEmptyModuleFloat3D(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-1, -1], torch.int64, True),
|
||||
])
|
||||
def forward(self, a):
|
||||
return torch.ops.aten.new_empty(a, [3, 4, 5], dtype=torch.float32).fill_(0)
|
||||
|
||||
@register_test_case(module_factory=lambda: NewEmptyModuleFloat3D())
|
||||
def NewEmptyModuleFloat3D_basic(module, tu: TestUtils):
|
||||
module.forward(torch.randint(10, (2, 3)))
|
||||
|
||||
|
||||
class NewEmptyModuleFalsePinMemory(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-1, -1], torch.int64, True),
|
||||
])
|
||||
def forward(self, a):
|
||||
return torch.ops.aten.new_empty(a, [3, 4], dtype=torch.float32, pin_memory=False).fill_(0)
|
||||
|
||||
@register_test_case(module_factory=lambda: NewEmptyModuleFalsePinMemory())
|
||||
def NewEmptyModuleFalsePinMemory_basic(module, tu: TestUtils):
|
||||
module.forward(torch.randint(10, (2, 3)))
|
||||
|
||||
|
||||
class NewEmptyModuleNonDefaultFloatDtype(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-1, -1], torch.float64, True),
|
||||
])
|
||||
def forward(self, a):
|
||||
return torch.ops.aten.new_empty(a, [3, 4]).fill_(0)
|
||||
|
||||
@register_test_case(module_factory=lambda: NewEmptyModuleNonDefaultFloatDtype())
|
||||
def NewEmptyModuleNonDefaultFloatDtype_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(2, 3).to(torch.float64))
|
||||
|
||||
|
||||
class NewEmptyModuleNonDefaultIntDtype(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-1, -1], torch.int32, True),
|
||||
])
|
||||
def forward(self, a):
|
||||
return torch.ops.aten.new_empty(a, [3, 4]).fill_(0)
|
||||
|
||||
@register_test_case(module_factory=lambda: NewEmptyModuleNonDefaultIntDtype())
|
||||
def NewEmptyModuleNonDefaultIntDtype_basic(module, tu: TestUtils):
|
||||
module.forward(torch.randint(10, (2, 3)).to(torch.int32))
|
||||
|
|
|
@ -582,7 +582,8 @@ func @torch.aten.hardtanh(%arg0: !torch.vtensor<[?],f32>, %min: !torch.float, %m
|
|||
// CHECK: %[[INT2:.*]] = torch.constant.int 2
|
||||
// CHECK: %[[INT3:.*]] = torch.constant.int 3
|
||||
// CHECK: %[[SIZE:.*]] = torch.prim.ListConstruct %[[INT2]], %[[INT3]] : (!torch.int, !torch.int) -> !torch.list<int>
|
||||
// CHECK: %[[RES:.*]] = torch.aten.zeros %[[SIZE]], %[[NONE]], %[[NONE]], %[[NONE]], %[[NONE]] : !torch.list<int>, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[2,3],f32>
|
||||
// CHECK: %[[INT6:.*]] = torch.constant.int 6
|
||||
// CHECK: %[[RES:.*]] = torch.aten.zeros %[[SIZE]], %[[INT6]], %[[NONE]], %[[NONE]], %[[NONE]] : !torch.list<int>, !torch.int, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[2,3],f32>
|
||||
// CHECK: return %[[RES]] : !torch.vtensor<[2,3],f32>
|
||||
// CHECK: }
|
||||
func @torch.aten.new_zeros(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[2,3],f32> {
|
||||
|
@ -601,7 +602,8 @@ func @torch.aten.new_zeros(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[
|
|||
// CHECK: %[[INT3:.*]] = torch.constant.int 3
|
||||
// CHECK: %[[INT4:.*]] = torch.constant.int 4
|
||||
// CHECK: %[[SIZE:.*]] = torch.prim.ListConstruct %[[INT3]], %[[INT4]] : (!torch.int, !torch.int) -> !torch.list<int>
|
||||
// CHECK: %[[RES:.*]] = torch.aten.ones %[[SIZE]], %[[NONE]], %[[NONE]], %[[NONE]], %[[NONE]] : !torch.list<int>, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[3,4],si64>
|
||||
// CHECK: %[[INT4_0:.*]] = torch.constant.int 4
|
||||
// CHECK: %[[RES:.*]] = torch.aten.ones %[[SIZE]], %[[INT4_0]], %[[NONE]], %[[NONE]], %[[NONE]] : !torch.list<int>, !torch.int, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[3,4],si64>
|
||||
// CHECK: return %[[RES]] : !torch.vtensor<[3,4],si64>
|
||||
// CHECK: }
|
||||
func @torch.aten.new_ones(%arg0: !torch.vtensor<[?,?],si64>) -> !torch.vtensor<[3,4],si64> {
|
||||
|
@ -779,3 +781,22 @@ func @torch.valsem.aten.zero(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor
|
|||
return %0 : !torch.vtensor<[?,?],f32>
|
||||
}
|
||||
|
||||
// -----
|
||||
// CHECK-LABEL: func @torch.aten.new_empty
|
||||
// CHECK-SAME: %[[INP:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[2,3],f32> {
|
||||
// CHECK: %[[NONE:.*]] = torch.constant.none
|
||||
// CHECK: %[[INT2:.*]] = torch.constant.int 2
|
||||
// CHECK: %[[INT3:.*]] = torch.constant.int 3
|
||||
// CHECK: %[[SIZE:.*]] = torch.prim.ListConstruct %[[INT2]], %[[INT3]] : (!torch.int, !torch.int) -> !torch.list<int>
|
||||
// CHECK: %[[NONE_0:.*]] = torch.constant.none
|
||||
// CHECK: %[[INT6:.*]] = torch.constant.int 6
|
||||
// CHECK: %[[RES:.*]] = torch.aten.empty.memory_format %[[SIZE]], %[[INT6]], %[[NONE]], %[[NONE]], %[[NONE]], %[[NONE_0]] : !torch.list<int>, !torch.int, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[2,3],f32>
|
||||
// CHECK: return %[[RES]] : !torch.vtensor<[2,3],f32>
|
||||
func @torch.aten.new_empty(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[2,3],f32> {
|
||||
%none = torch.constant.none
|
||||
%int2 = torch.constant.int 2
|
||||
%int3 = torch.constant.int 3
|
||||
%0 = torch.prim.ListConstruct %int2, %int3 : (!torch.int, !torch.int) -> !torch.list<int>
|
||||
%1 = torch.aten.new_empty %arg0, %0, %none, %none, %none, %none : !torch.vtensor<[?,?],f32>, !torch.list<int>, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[2,3],f32>
|
||||
return %1 : !torch.vtensor<[2,3],f32>
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue