[MLIR][TORCH] Add E2E support for aten.new_empty op

This commit decomposes the `aten.new_empty` op into the `aten.empty.memory_format` op.

This commit also fixes the dtype handling in the constant tensor allocation-like ops.
Previously, the result dtype was inferred from the result type; now it is computed
according to the original definition of the op (for the `new_*` variants, a missing
dtype falls back to the dtype of `self`).
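At the PyTorch level, the `aten.new_empty` rewrite corresponds to the following equivalence (a minimal sketch for illustration only; `ref_new_empty` is a hypothetical name, not part of this patch):

```python
import torch

def ref_new_empty(self, size, dtype=None):
    # `aten.new_empty` becomes `aten.empty.memory_format`; a None dtype is
    # replaced by the dtype of `self` (layout, device and pin_memory are simply
    # forwarded by the decomposition and are omitted here).
    if dtype is None:
        dtype = self.dtype
    return torch.empty(size, dtype=dtype)

a = torch.rand(2, 3, dtype=torch.float64)
b = ref_new_empty(a, [3, 4])
# The contents are uninitialized, so only shape and dtype are meaningful to check.
assert b.shape == torch.Size([3, 4]) and b.dtype == torch.float64
```

In the rewritten IR this shows up as an explicit `torch.constant.int` dtype operand on `aten.empty.memory_format`, which the decomposition test at the end of this commit checks.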

Signed-off-by: Vivek Khandelwal <vivek@nod-labs.com>
pull/658/head snapshot-20220330.357
Vivek Khandelwal 2022-03-24 22:10:21 +05:30
parent 140babd952
commit 2597c481f6
11 changed files with 284 additions and 29 deletions


@@ -4204,6 +4204,34 @@ def Torch_AtenEmptyLikeOp : Torch_Op<"aten.empty_like", [
}];
}
def Torch_AtenNewEmptyOp : Torch_Op<"aten.new_empty", [
AllowsTypeRefinement,
HasValueSemantics,
ReadOnly
]> {
let summary = "Generated op for `aten::new_empty : (Tensor, int[], int?, int?, Device?, bool?) -> (Tensor)`";
let arguments = (ins
AnyTorchTensorType:$self,
TorchIntListType:$size,
TorchOptionalIntType:$dtype,
TorchOptionalIntType:$layout,
TorchOptionalDeviceType:$device,
TorchOptionalBoolType:$pin_memory
);
let results = (outs
AnyTorchTensorType:$result
);
let hasCustomAssemblyFormat = 1;
let extraClassDefinition = [{
ParseResult AtenNewEmptyOp::parse(OpAsmParser &parser, OperationState &result) {
return parseDefaultTorchOp(parser, result, 6, 1);
}
void AtenNewEmptyOp::print(OpAsmPrinter &printer) {
printDefaultTorchOp(printer, *this, 6, 1);
}
}];
}
def Torch_AtenZerosLikeOp : Torch_Op<"aten.zeros_like", [
AllowsTypeRefinement,
HasValueSemantics,
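The generated op mirrors the registered `aten::new_empty` schema; as a sanity check, the same signature can be exercised eagerly (illustration only, not part of the patch):

```python
import torch

# aten::new_empty : (Tensor, int[], int?, int?, Device?, bool?) -> (Tensor)
t = torch.ones(2, 2, dtype=torch.int32)
out = torch.ops.aten.new_empty(t, [4, 5], dtype=torch.float32,
                               layout=torch.strided,
                               device=torch.device("cpu"),
                               pin_memory=False)
assert out.shape == torch.Size([4, 5]) and out.dtype == torch.float32
```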


@@ -27,6 +27,11 @@ bool getListConstructElements(Value v, SmallVectorImpl<Value> &elems);
llvm::Optional<int64_t> matchLegalConstantIndexIntoListOfSize(Value v,
int64_t length);
torch_upstream::ScalarType getScalarTypeForType(Type type);
Type getTypeForScalarType(
MLIRContext *context, torch_upstream::ScalarType dtypeInt,
mlir::IntegerType::SignednessSemantics signedness = IntegerType::Signed);
Value getDtypeIntValueForType(PatternRewriter &rewriter, Location loc,
Type dtype);
// Helper to convert a tensor to a specific scalar type.
Value convertTensorToDtype(PatternRewriter &rewriter, Location loc, Value input,
Type dtype);
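`getTypeForScalarType` translates the integer dtype codes of PyTorch's `ScalarType` enum into MLIR types. For reference, the codes handled by the helper correspond to the following dtypes (the dictionary and its name are only an illustration, not part of the patch):

```python
import torch

# Integer values of PyTorch's ScalarType enum handled by getTypeForScalarType;
# these are the integers that appear as dtype operands in the Torch dialect IR
# (e.g. 6 for float32 and 4 for int64 in the decomposition tests further down).
SCALAR_TYPE_TO_DTYPE = {
    3: torch.int32,    # ScalarType::Int
    4: torch.int64,    # ScalarType::Long
    6: torch.float32,  # ScalarType::Float
    7: torch.float64,  # ScalarType::Double
    11: torch.bool,    # ScalarType::Bool
}
```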


@@ -148,13 +148,24 @@ public:
auto resultType = typeConverter->convertType(op.getType())
.template cast<RankedTensorType>();
Type outElemType = resultType.getElementType();
Type resultElementType;
if (op.dtype().getType().template isa<Torch::NoneType>()) {
resultElementType = resultType.getElementType();
} else {
int64_t dtypeInt;
if (!matchPattern(op.dtype(), m_TorchConstantInt(&dtypeInt)))
return rewriter.notifyMatchFailure(
op, "unimplemented: dtype must be a constant integer or none");
resultElementType = getTypeForScalarType(
op->getContext(), (torch_upstream::ScalarType)dtypeInt,
IntegerType::Signless);
}
// Create an uninitialized tensor of `resultSize` shape and fill it with
// value `fillVal`.
Value constVal = getConstant(rewriter, loc, fillVal, outElemType);
Value constVal = getConstant(rewriter, loc, fillVal, resultElementType);
Value outputTensor =
createInitTensor(rewriter, loc, resultSizeIndex, outElemType, constVal);
createInitTensor(rewriter, loc, resultSizeIndex, resultElementType, constVal);
rewriter.replaceOpWithNewOp<tensor::CastOp>(op, resultType, outputTensor);
return success();
}
@@ -207,11 +218,24 @@ public:
for (auto size : resultSize)
resultSizeIndex.push_back(castIntToIndex(rewriter, loc, size));
auto resultType = typeConverter->convertType(op.getType())
.template cast<RankedTensorType>();
auto resultType =
typeConverter->convertType(op.getType()).cast<RankedTensorType>();
Type resultElementType;
if (op.dtype().getType().isa<Torch::NoneType>()) {
resultElementType = resultType.getElementType();
} else {
int64_t dtypeInt;
if (!matchPattern(op.dtype(), m_TorchConstantInt(&dtypeInt)))
return rewriter.notifyMatchFailure(
op, "unimplemented: dtype must be a constant integer or none");
resultElementType = getTypeForScalarType(
op->getContext(), (torch_upstream::ScalarType)dtypeInt,
IntegerType::Signless);
}
// Create an uninitialized tensor of `resultSize` shape.
Value initTensor = rewriter.create<linalg::InitTensorOp>(
loc, resultSizeIndex, resultType.getElementType());
loc, resultSizeIndex, resultElementType);
rewriter.replaceOpWithNewOp<tensor::CastOp>(op, resultType, initTensor);
return success();
}
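The change above makes the lowering take the element type from an explicit dtype operand when one is present, instead of relying only on the converted result type. The eager behavior it needs to match is simply (illustrative check, not part of the patch):

```python
import torch

# The eager ops honor an explicit dtype; the lowering now allocates the init
# tensor with that element type as well.
x = torch.empty([2, 3], dtype=torch.int64)
y = torch.zeros([2, 3], dtype=torch.float64)
assert x.dtype == torch.int64 and y.dtype == torch.float64
```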


@@ -1442,8 +1442,15 @@ class DecomposeConstantTensorNewLikeOp : public OpRewritePattern<OpTy> {
using OpRewritePattern<OpTy>::OpRewritePattern;
LogicalResult matchAndRewrite(OpTy op,
PatternRewriter &rewriter) const override {
rewriter.replaceOpWithNewOp<NewOpTy>(op, op.getType(), op.size(),
op.dtype(), op.layout(), op.device(),
Value dtype = op.dtype();
if (dtype.getType().isa<Torch::NoneType>()) {
BaseTensorType tensorType =
op.self().getType().template cast<BaseTensorType>();
dtype =
getDtypeIntValueForType(rewriter, op.getLoc(), tensorType.getDtype());
}
rewriter.replaceOpWithNewOp<NewOpTy>(op, op.getType(), op.size(), dtype,
op.layout(), op.device(),
op.pin_memory());
return success();
}
@@ -1536,6 +1543,27 @@ public:
};
} // namespace
namespace {
// Decompose `aten.new_empty` op into `aten.empty.memory_format` op.
class DecomposeAtenNewEmptyOp : public OpRewritePattern<AtenNewEmptyOp> {
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(AtenNewEmptyOp op,
PatternRewriter &rewriter) const override {
Value noneVal = rewriter.create<ConstantNoneOp>(op.getLoc());
Value dtype = op.dtype();
if (dtype.getType().isa<Torch::NoneType>()) {
BaseTensorType tensorType = op.self().getType().cast<BaseTensorType>();
dtype =
getDtypeIntValueForType(rewriter, op.getLoc(), tensorType.getDtype());
}
rewriter.replaceOpWithNewOp<AtenEmptyMemoryFormatOp>(
op, op.getType(), op.size(), dtype, op.layout(), op.device(),
op.pin_memory(), /*memory_format=*/noneVal);
return success();
}
};
} // namespace
namespace {
class DecomposeComplexOpsPass
: public DecomposeComplexOpsBase<DecomposeComplexOpsPass> {
@@ -1651,6 +1679,8 @@ class DecomposeComplexOpsPass
target.addIllegalOp<Aten_ToCopyOp>();
patterns.add<DecomposeAtenDropoutOp>(context);
target.addIllegalOp<AtenDropoutOp>();
target.addIllegalOp<AtenNewEmptyOp>();
patterns.add<DecomposeAtenNewEmptyOp>(context);
if (failed(applyPartialConversion(getOperation(), target,
std::move(patterns)))) {
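Both the `new_*`-like pattern and the new `aten.new_empty` decomposition materialize a concrete dtype from `self` when the optional dtype operand is `none`, which matches the eager behavior of the `new_*` family (illustrative check, not part of the patch):

```python
import torch

# With no dtype argument, new_zeros/new_ones/new_empty inherit the dtype of
# `self`; an explicit dtype overrides it.
src = torch.zeros(2, 3, dtype=torch.int32)
assert src.new_zeros([3, 4]).dtype == torch.int32
assert src.new_ones([3, 4]).dtype == torch.int32
assert src.new_empty([3, 4]).dtype == torch.int32
assert src.new_empty([3, 4], dtype=torch.float32).dtype == torch.float32
```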


@@ -78,24 +78,6 @@ using namespace mlir::torch::Torch;
// Analysis.
// -----------------------------------------------------------------------------
static Type getTypeForScalarType(MLIRContext *context,
torch_upstream::ScalarType dtypeInt) {
switch (dtypeInt) {
case torch_upstream::ScalarType::Float:
return Float32Type::get(context);
case torch_upstream::ScalarType::Double:
return Float64Type::get(context);
case torch_upstream::ScalarType::Long:
return IntegerType::get(context, 64, IntegerType::Signed);
case torch_upstream::ScalarType::Int:
return IntegerType::get(context, 32, IntegerType::Signed);
case torch_upstream::ScalarType::Bool:
return IntegerType::get(context, 1);
default:
return Type();
}
}
static Type getTypeForDTypeInteger(MLIRContext *context, int64_t dtypeInt) {
return getTypeForScalarType(context, (torch_upstream::ScalarType)dtypeInt);
}
@@ -759,6 +741,8 @@ ChangeResult TypeAnalyzer::visitOperation(
return visitConstantTensorNewLikeOp<AtenNewZerosOp>(newZeros, operands);
} else if (auto newOnes = dyn_cast<AtenNewOnesOp>(op)) {
return visitConstantTensorNewLikeOp<AtenNewOnesOp>(newOnes, operands);
} else if (auto newEmpty = dyn_cast<AtenNewEmptyOp>(op)) {
return visitConstantTensorNewLikeOp<AtenNewEmptyOp>(newEmpty, operands);
} else if (auto randLike = dyn_cast<AtenRandLikeOp>(op)) {
return visitConstantTensorAllocLikeOp<AtenRandLikeOp>(randLike, operands);
} else if (auto toCopy = dyn_cast<Aten_ToCopyOp>(op)) {


@@ -1809,6 +1809,9 @@ module {
func @"__torch_mlir_shape_fn.aten.new_ones"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.optional<int>, %arg3: !torch.optional<int>, %arg4: !torch.optional<Device>, %arg5: !torch.optional<bool>) -> !torch.list<int> {
return %arg1 : !torch.list<int>
}
func @"__torch_mlir_shape_fn.aten.new_empty"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.optional<int>, %arg3: !torch.optional<int>, %arg4: !torch.optional<Device>, %arg5: !torch.optional<bool>) -> !torch.list<int> {
return %arg1 : !torch.list<int>
}
func @"__torch_mlir_shape_fn.aten._to_copy"(%arg0: !torch.list<int>, %arg1: !torch.optional<int>, %arg2: !torch.optional<int>, %arg3: !torch.optional<Device>, %arg4: !torch.optional<bool>, %arg5: !torch.bool, %arg6: !torch.optional<int>) -> !torch.list<int> {
%0 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.upstream_shape_helpers.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>
return %0 : !torch.list<int>


@@ -55,7 +55,26 @@ torch_upstream::ScalarType Torch::getScalarTypeForType(Type type) {
llvm::report_fatal_error("unhandled type for getScalarTypeForType");
}
static Value getDtypeIntValueForType(PatternRewriter &rewriter, Location loc,
Type Torch::getTypeForScalarType(
MLIRContext *context, torch_upstream::ScalarType dtypeInt,
mlir::IntegerType::SignednessSemantics signedness) {
switch (dtypeInt) {
case torch_upstream::ScalarType::Float:
return Float32Type::get(context);
case torch_upstream::ScalarType::Double:
return Float64Type::get(context);
case torch_upstream::ScalarType::Long:
return IntegerType::get(context, 64, signedness);
case torch_upstream::ScalarType::Int:
return IntegerType::get(context, 32, signedness);
case torch_upstream::ScalarType::Bool:
return IntegerType::get(context, 1);
default:
return Type();
}
}
Value Torch::getDtypeIntValueForType(PatternRewriter &rewriter, Location loc,
Type dtype) {
int intType = (int)getScalarTypeForType(dtype);
return rewriter.create<ConstantIntOp>(loc,


@@ -605,6 +605,9 @@ def aten〇new_zeros(self: List[int], size: List[int], dtype: Optional[int] = No
def aten〇new_ones(self: List[int], size: List[int], dtype: Optional[int] = None, layout: Optional[int] = None, device: Optional[device] = None, pin_memory: Optional[bool] = None) -> List[int]:
    return size

def aten〇new_empty(self: List[int], size: List[int], dtype: Optional[int] = None, layout: Optional[int] = None, device: Optional[device] = None, pin_memory: Optional[bool] = None) -> List[int]:
    return size

def aten〇_to_copy(self: List[int], dtype: Optional[int] = None, layout: Optional[int] = None, device: Optional[device] = None, pin_memory: Optional[bool] = None, non_blocking: bool = False, memory_format: Optional[int] = None) -> List[int]:
    return upstream_shape_helpers.unary(self)
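The shape function only depends on the `size` list, hence the straight `return size`; at the PyTorch level (illustration only):

```python
import torch

# The result shape of new_empty is exactly the requested size, independent of
# the shape of `self`.
a = torch.rand(2, 3)
assert a.new_empty([5, 7, 9]).shape == torch.Size([5, 7, 9])
```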


@@ -383,6 +383,7 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
emit("aten::detach : (Tensor) -> (Tensor)")
emit("aten::embedding : (Tensor, Tensor, int, bool, bool) -> (Tensor)")
emit("aten::empty_like : (Tensor, int?, int?, Device?, bool?, int?) -> (Tensor)")
emit("aten::new_empty : (Tensor, int[], int?, int?, Device?, bool?) -> (Tensor)")
emit("aten::zeros_like : (Tensor, int?, int?, Device?, bool?, int?) -> (Tensor)")
emit("aten::ones_like : (Tensor, int?, int?, Device?, bool?, int?) -> (Tensor)")
emit("aten::empty.memory_format : (int[], int?, int?, Device?, bool?, int?) -> (Tensor)")


@@ -1006,3 +1006,140 @@ class ZeroInt64Module(torch.nn.Module):
@register_test_case(module_factory=lambda: ZeroInt64Module())
def ZeroInt64Module_basic(module, tu: TestUtils):
    module.forward(torch.randint(100, (10, 4)))

# ==============================================================================

class NewEmptyModuleDefaultDtype(torch.nn.Module):
    def __init__(self):
        super().__init__()

    @export
    @annotate_args([
        None,
        ([-1, -1], torch.float32, True),
    ])
    def forward(self, a):
        return torch.ops.aten.new_empty(a, [3, 4]).fill_(0)

@register_test_case(module_factory=lambda: NewEmptyModuleDefaultDtype())
def NewEmptyModuleDefaultDtype_basic(module, tu: TestUtils):
    module.forward(tu.rand(2, 3))

class NewEmptyModuleInt2D(torch.nn.Module):
    def __init__(self):
        super().__init__()

    @export
    @annotate_args([
        None,
        ([-1, -1, -1], torch.float32, True),
    ])
    def forward(self, a):
        return torch.ops.aten.new_empty(a, [3, 4], dtype=torch.int64).fill_(0)

@register_test_case(module_factory=lambda: NewEmptyModuleInt2D())
def NewEmptyModuleInt2D_basic(module, tu: TestUtils):
    module.forward(tu.rand(2, 3, 4))

class NewEmptyModuleInt3D(torch.nn.Module):
    def __init__(self):
        super().__init__()

    @export
    @annotate_args([
        None,
        ([-1, -1], torch.float32, True),
    ])
    def forward(self, a):
        return torch.ops.aten.new_empty(a, [3, 4, 5], dtype=torch.int64).fill_(0)

@register_test_case(module_factory=lambda: NewEmptyModuleInt3D())
def NewEmptyModuleInt3D_basic(module, tu: TestUtils):
    module.forward(tu.rand(2, 3))

class NewEmptyModuleFloat2D(torch.nn.Module):
    def __init__(self):
        super().__init__()

    @export
    @annotate_args([
        None,
        ([-1, -1, -1], torch.int64, True),
    ])
    def forward(self, a):
        return torch.ops.aten.new_empty(a, [3, 4], dtype=torch.float32).fill_(0)

@register_test_case(module_factory=lambda: NewEmptyModuleFloat2D())
def NewEmptyModuleFloat2D_basic(module, tu: TestUtils):
    module.forward(torch.randint(10, (2, 3, 4)))

class NewEmptyModuleFloat3D(torch.nn.Module):
    def __init__(self):
        super().__init__()

    @export
    @annotate_args([
        None,
        ([-1, -1], torch.int64, True),
    ])
    def forward(self, a):
        return torch.ops.aten.new_empty(a, [3, 4, 5], dtype=torch.float32).fill_(0)

@register_test_case(module_factory=lambda: NewEmptyModuleFloat3D())
def NewEmptyModuleFloat3D_basic(module, tu: TestUtils):
    module.forward(torch.randint(10, (2, 3)))

class NewEmptyModuleFalsePinMemory(torch.nn.Module):
    def __init__(self):
        super().__init__()

    @export
    @annotate_args([
        None,
        ([-1, -1], torch.int64, True),
    ])
    def forward(self, a):
        return torch.ops.aten.new_empty(a, [3, 4], dtype=torch.float32, pin_memory=False).fill_(0)

@register_test_case(module_factory=lambda: NewEmptyModuleFalsePinMemory())
def NewEmptyModuleFalsePinMemory_basic(module, tu: TestUtils):
    module.forward(torch.randint(10, (2, 3)))

class NewEmptyModuleNonDefaultFloatDtype(torch.nn.Module):
    def __init__(self):
        super().__init__()

    @export
    @annotate_args([
        None,
        ([-1, -1], torch.float64, True),
    ])
    def forward(self, a):
        return torch.ops.aten.new_empty(a, [3, 4]).fill_(0)

@register_test_case(module_factory=lambda: NewEmptyModuleNonDefaultFloatDtype())
def NewEmptyModuleNonDefaultFloatDtype_basic(module, tu: TestUtils):
    module.forward(tu.rand(2, 3).to(torch.float64))

class NewEmptyModuleNonDefaultIntDtype(torch.nn.Module):
    def __init__(self):
        super().__init__()

    @export
    @annotate_args([
        None,
        ([-1, -1], torch.int32, True),
    ])
    def forward(self, a):
        return torch.ops.aten.new_empty(a, [3, 4]).fill_(0)

@register_test_case(module_factory=lambda: NewEmptyModuleNonDefaultIntDtype())
def NewEmptyModuleNonDefaultIntDtype_basic(module, tu: TestUtils):
    module.forward(torch.randint(10, (2, 3)).to(torch.int32))
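The tests above call `.fill_(0)` because `new_empty` returns uninitialized memory; without the fill, the compiled result could not be compared against the eager reference. A minimal eager illustration (not part of the test suite):

```python
import torch

# new_empty gives uninitialized storage, so two runs may observe different
# values; filling the tensor makes the output deterministic for comparison.
a = torch.rand(2, 3)
b = torch.ops.aten.new_empty(a, [3, 4]).fill_(0)
assert torch.equal(b, torch.zeros(3, 4, dtype=a.dtype))
```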


@@ -582,7 +582,8 @@ func @torch.aten.hardtanh(%arg0: !torch.vtensor<[?],f32>, %min: !torch.float, %m
// CHECK: %[[INT2:.*]] = torch.constant.int 2
// CHECK: %[[INT3:.*]] = torch.constant.int 3
// CHECK: %[[SIZE:.*]] = torch.prim.ListConstruct %[[INT2]], %[[INT3]] : (!torch.int, !torch.int) -> !torch.list<int>
// CHECK: %[[RES:.*]] = torch.aten.zeros %[[SIZE]], %[[NONE]], %[[NONE]], %[[NONE]], %[[NONE]] : !torch.list<int>, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[2,3],f32>
// CHECK: %[[INT6:.*]] = torch.constant.int 6
// CHECK: %[[RES:.*]] = torch.aten.zeros %[[SIZE]], %[[INT6]], %[[NONE]], %[[NONE]], %[[NONE]] : !torch.list<int>, !torch.int, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[2,3],f32>
// CHECK: return %[[RES]] : !torch.vtensor<[2,3],f32>
// CHECK: }
func @torch.aten.new_zeros(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[2,3],f32> {
@@ -601,7 +602,8 @@ func @torch.aten.new_zeros(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[
// CHECK: %[[INT3:.*]] = torch.constant.int 3
// CHECK: %[[INT4:.*]] = torch.constant.int 4
// CHECK: %[[SIZE:.*]] = torch.prim.ListConstruct %[[INT3]], %[[INT4]] : (!torch.int, !torch.int) -> !torch.list<int>
// CHECK: %[[RES:.*]] = torch.aten.ones %[[SIZE]], %[[NONE]], %[[NONE]], %[[NONE]], %[[NONE]] : !torch.list<int>, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[3,4],si64>
// CHECK: %[[INT4_0:.*]] = torch.constant.int 4
// CHECK: %[[RES:.*]] = torch.aten.ones %[[SIZE]], %[[INT4_0]], %[[NONE]], %[[NONE]], %[[NONE]] : !torch.list<int>, !torch.int, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[3,4],si64>
// CHECK: return %[[RES]] : !torch.vtensor<[3,4],si64>
// CHECK: }
func @torch.aten.new_ones(%arg0: !torch.vtensor<[?,?],si64>) -> !torch.vtensor<[3,4],si64> {
@@ -779,3 +781,22 @@ func @torch.valsem.aten.zero(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor
return %0 : !torch.vtensor<[?,?],f32>
}
// -----
// CHECK-LABEL: func @torch.aten.new_empty
// CHECK-SAME: %[[INP:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[2,3],f32> {
// CHECK: %[[NONE:.*]] = torch.constant.none
// CHECK: %[[INT2:.*]] = torch.constant.int 2
// CHECK: %[[INT3:.*]] = torch.constant.int 3
// CHECK: %[[SIZE:.*]] = torch.prim.ListConstruct %[[INT2]], %[[INT3]] : (!torch.int, !torch.int) -> !torch.list<int>
// CHECK: %[[NONE_0:.*]] = torch.constant.none
// CHECK: %[[INT6:.*]] = torch.constant.int 6
// CHECK: %[[RES:.*]] = torch.aten.empty.memory_format %[[SIZE]], %[[INT6]], %[[NONE]], %[[NONE]], %[[NONE]], %[[NONE_0]] : !torch.list<int>, !torch.int, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[2,3],f32>
// CHECK: return %[[RES]] : !torch.vtensor<[2,3],f32>
func @torch.aten.new_empty(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[2,3],f32> {
%none = torch.constant.none
%int2 = torch.constant.int 2
%int3 = torch.constant.int 3
%0 = torch.prim.ListConstruct %int2, %int3 : (!torch.int, !torch.int) -> !torch.list<int>
%1 = torch.aten.new_empty %arg0, %0, %none, %none, %none, %none : !torch.vtensor<[?,?],f32>, !torch.list<int>, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[2,3],f32>
return %1 : !torch.vtensor<[2,3],f32>
}