[MLIR][TORCH] Add E2E support for aten.zeros op

This commit adds lowering of `aten.zeros` op.

Signed-Off By: Vivek Khandelwal <vivek@nod-labs.com>
pull/465/head
Vivek Khandelwal 2021-12-03 22:31:01 +05:30
parent 977b1b03ea
commit 9958cf08b6
3 changed files with 121 additions and 21 deletions

View File

@ -900,7 +900,6 @@ class NumelModule(torch.nn.Module):
None,
([-1, -1, -1], torch.float32, True),
])
def forward(self, input):
return torch.numel(input)
@ -918,10 +917,89 @@ class NumelZeroRankModule(torch.nn.Module):
None,
([], torch.int64, True),
])
def forward(self, input):
return torch.numel(input)
@register_test_case(module_factory=lambda: NumelZeroRankModule())
def NumelZeroRankModule_basic(module, tu: TestUtils):
module.forward(torch.randint(10,[]))
class ZerosModuleInt2D(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
])
def forward(self):
return torch.zeros(3, 4, dtype=torch.int64)
@register_test_case(module_factory=lambda: ZerosModuleInt2D())
def ZerosModuleInt2D_basic(module, tu: TestUtils):
module.forward()
class ZerosModuleInt3D(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
])
def forward(self):
return torch.zeros(3, 4, 5, dtype=torch.int64)
@register_test_case(module_factory=lambda: ZerosModuleInt3D())
def ZerosModuleInt3D_basic(module, tu: TestUtils):
module.forward()
class ZerosModuleFloat2D(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
])
def forward(self):
return torch.zeros(3, 4, dtype=torch.float32)
@register_test_case(module_factory=lambda: ZerosModuleFloat2D())
def ZerosModuleFloat2D_basic(module, tu: TestUtils):
module.forward()
class ZerosModuleFloat3D(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
])
def forward(self):
return torch.zeros(3, 4, 5, dtype=torch.float32)
@register_test_case(module_factory=lambda: ZerosModuleFloat3D())
def ZerosModuleFloat3D_basic(module, tu: TestUtils):
module.forward()
class ZerosModuleFalsePinMemory(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
])
def forward(self):
return torch.zeros(3, 4, dtype=torch.float32, pin_memory=False)
@register_test_case(module_factory=lambda: ZerosModuleFalsePinMemory())
def ZerosModuleFalsePinMemory_basic(module, tu: TestUtils):
module.forward()

View File

@ -3269,24 +3269,44 @@ public:
} // namespace
namespace {
class ConvertAtenOnesOp : public OpConversionPattern<AtenOnesOp> {
public:
using OpConversionPattern::OpConversionPattern;
// Converts AtenOnesOp and AtenZerosOp.
struct ConvertAtenOnesZerosOp : ConversionPattern {
ConvertAtenOnesZerosOp(TypeConverter &typeConverter, MLIRContext *context)
: ConversionPattern(typeConverter, MatchAnyOpTypeTag(), /*benefit=*/1,
context) {}
LogicalResult
matchAndRewrite(AtenOnesOp op, OpAdaptor adaptor,
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
if (!isa<AtenOnesOp, AtenZerosOp>(op))
return rewriter.notifyMatchFailure(op,
"not a supported ones or zeros op");
if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
return failure();
Location loc = op.getLoc();
Location loc = op->getLoc();
SmallVector<Value, 3> opArguments;
int64_t elementValue;
if (AtenOnesOp onesOp = dyn_cast<AtenOnesOp>(op)) {
opArguments.insert(opArguments.end(),
{onesOp.size(), onesOp.layout(), onesOp.pin_memory()});
elementValue = 1;
} else if (AtenZerosOp zerosOp = dyn_cast<AtenZerosOp>(op)) {
opArguments.insert(opArguments.end(), {zerosOp.size(), zerosOp.layout(),
zerosOp.pin_memory()});
elementValue = 0;
}
// We ignore device, but add simple asserts for unimplemented kwargs
if (!op.layout().getType().isa<Torch::NoneType>())
if (!opArguments[1].getType().isa<Torch::NoneType>())
return rewriter.notifyMatchFailure(op,
"only default layout is supported");
bool pinMemory = false;
if (!op.pin_memory().getType().isa<Torch::NoneType>() &&
!matchPattern(op.pin_memory(), m_TorchConstantBool(&pinMemory))) {
if (!opArguments[2].getType().isa<Torch::NoneType>() &&
!matchPattern(opArguments[2], m_TorchConstantBool(&pinMemory))) {
return rewriter.notifyMatchFailure(
op, "pin_memory must be constant bool or None");
}
@ -3294,7 +3314,7 @@ public:
return rewriter.notifyMatchFailure(op, "memory pinning not supported");
SmallVector<Value> size, sizeIndex;
if (!getListConstructElements(op.size(), size)) {
if (!getListConstructElements(opArguments[0], size)) {
return rewriter.notifyMatchFailure(
op, "size must be created by ListConstruct");
}
@ -3303,21 +3323,24 @@ public:
sizeIndex.push_back(castIntToIndex(rewriter, loc, size[i]));
RankedTensorType newResultType =
getTypeConverter()->convertType(op.getType()).cast<RankedTensorType>();
getTypeConverter()
->convertType(op->getResult(0).getType())
.cast<RankedTensorType>();
Type outElementType = newResultType.getElementType();
Value one = rewriter.create<arith::ConstantOp>(
Value constantOp = rewriter.create<arith::ConstantOp>(
loc, outElementType,
(outElementType.isa<mlir::FloatType>()
? rewriter.getFloatAttr(outElementType, 1).cast<mlir::Attribute>()
: rewriter.getIntegerAttr(outElementType, 1)
? rewriter.getFloatAttr(outElementType, elementValue)
.cast<mlir::Attribute>()
: rewriter.getIntegerAttr(outElementType, elementValue)
.cast<mlir::Attribute>()));
Value outTensor = rewriter
.create<linalg::InitTensorOp>(
loc, sizeIndex, newResultType.getElementType())
.getResult();
Value fillOp =
rewriter.create<linalg::FillOp>(loc, one, outTensor).getResult(0);
Value fillOp = rewriter.create<linalg::FillOp>(loc, constantOp, outTensor)
.getResult(0);
rewriter.replaceOpWithNewOp<tensor::CastOp>(op, newResultType, fillOp);
@ -3452,8 +3475,8 @@ public:
patterns.add<ConvertAtenSizeIntOp>(typeConverter, context);
target.addIllegalOp<AtenEmbeddingOp>();
patterns.add<ConvertAtenEmbeddingOp>(typeConverter, context);
target.addIllegalOp<AtenOnesOp>();
patterns.add<ConvertAtenOnesOp>(typeConverter, context);
target.addIllegalOp<AtenOnesOp, AtenZerosOp>();
patterns.add<ConvertAtenOnesZerosOp>(typeConverter, context);
target.addIllegalOp<AtenContiguousOp>();
patterns.add<ConvertAtenContiguousOp>(typeConverter, context);
target.addIllegalOp<AtenIntTensorOp>();

View File

@ -458,7 +458,6 @@ public:
return visitAtenNllLossForwardOp(nllForwardOp, operands);
}
// Otherwise, this is an unknown operation. Just mark all results as
// having reached a pessimistic fixpoint.
return markAllPessimisticFixpoint(op->getResults());