mirror of https://github.com/llvm/torch-mlir
[MLIR][TORCH] Add E2E support for aten.zeros op
This commit adds lowering of `aten.zeros` op. Signed-Off By: Vivek Khandelwal <vivek@nod-labs.com>pull/465/head
parent
977b1b03ea
commit
9958cf08b6
|
@ -900,7 +900,6 @@ class NumelModule(torch.nn.Module):
|
||||||
None,
|
None,
|
||||||
([-1, -1, -1], torch.float32, True),
|
([-1, -1, -1], torch.float32, True),
|
||||||
])
|
])
|
||||||
|
|
||||||
def forward(self, input):
|
def forward(self, input):
|
||||||
return torch.numel(input)
|
return torch.numel(input)
|
||||||
|
|
||||||
|
@ -918,10 +917,89 @@ class NumelZeroRankModule(torch.nn.Module):
|
||||||
None,
|
None,
|
||||||
([], torch.int64, True),
|
([], torch.int64, True),
|
||||||
])
|
])
|
||||||
|
|
||||||
def forward(self, input):
|
def forward(self, input):
|
||||||
return torch.numel(input)
|
return torch.numel(input)
|
||||||
|
|
||||||
@register_test_case(module_factory=lambda: NumelZeroRankModule())
|
@register_test_case(module_factory=lambda: NumelZeroRankModule())
|
||||||
def NumelZeroRankModule_basic(module, tu: TestUtils):
|
def NumelZeroRankModule_basic(module, tu: TestUtils):
|
||||||
module.forward(torch.randint(10,[]))
|
module.forward(torch.randint(10,[]))
|
||||||
|
|
||||||
|
|
||||||
|
class ZerosModuleInt2D(torch.nn.Module):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
@export
|
||||||
|
@annotate_args([
|
||||||
|
None,
|
||||||
|
])
|
||||||
|
def forward(self):
|
||||||
|
return torch.zeros(3, 4, dtype=torch.int64)
|
||||||
|
|
||||||
|
@register_test_case(module_factory=lambda: ZerosModuleInt2D())
|
||||||
|
def ZerosModuleInt2D_basic(module, tu: TestUtils):
|
||||||
|
module.forward()
|
||||||
|
|
||||||
|
|
||||||
|
class ZerosModuleInt3D(torch.nn.Module):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
@export
|
||||||
|
@annotate_args([
|
||||||
|
None,
|
||||||
|
])
|
||||||
|
def forward(self):
|
||||||
|
return torch.zeros(3, 4, 5, dtype=torch.int64)
|
||||||
|
|
||||||
|
@register_test_case(module_factory=lambda: ZerosModuleInt3D())
|
||||||
|
def ZerosModuleInt3D_basic(module, tu: TestUtils):
|
||||||
|
module.forward()
|
||||||
|
|
||||||
|
|
||||||
|
class ZerosModuleFloat2D(torch.nn.Module):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
@export
|
||||||
|
@annotate_args([
|
||||||
|
None,
|
||||||
|
])
|
||||||
|
def forward(self):
|
||||||
|
return torch.zeros(3, 4, dtype=torch.float32)
|
||||||
|
|
||||||
|
@register_test_case(module_factory=lambda: ZerosModuleFloat2D())
|
||||||
|
def ZerosModuleFloat2D_basic(module, tu: TestUtils):
|
||||||
|
module.forward()
|
||||||
|
|
||||||
|
|
||||||
|
class ZerosModuleFloat3D(torch.nn.Module):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
@export
|
||||||
|
@annotate_args([
|
||||||
|
None,
|
||||||
|
])
|
||||||
|
def forward(self):
|
||||||
|
return torch.zeros(3, 4, 5, dtype=torch.float32)
|
||||||
|
|
||||||
|
@register_test_case(module_factory=lambda: ZerosModuleFloat3D())
|
||||||
|
def ZerosModuleFloat3D_basic(module, tu: TestUtils):
|
||||||
|
module.forward()
|
||||||
|
|
||||||
|
|
||||||
|
class ZerosModuleFalsePinMemory(torch.nn.Module):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
@export
|
||||||
|
@annotate_args([
|
||||||
|
None,
|
||||||
|
])
|
||||||
|
def forward(self):
|
||||||
|
return torch.zeros(3, 4, dtype=torch.float32, pin_memory=False)
|
||||||
|
|
||||||
|
@register_test_case(module_factory=lambda: ZerosModuleFalsePinMemory())
|
||||||
|
def ZerosModuleFalsePinMemory_basic(module, tu: TestUtils):
|
||||||
|
module.forward()
|
||||||
|
|
|
@ -3269,24 +3269,44 @@ public:
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
class ConvertAtenOnesOp : public OpConversionPattern<AtenOnesOp> {
|
// Converts AtenOnesOp and AtenZerosOp.
|
||||||
public:
|
struct ConvertAtenOnesZerosOp : ConversionPattern {
|
||||||
using OpConversionPattern::OpConversionPattern;
|
ConvertAtenOnesZerosOp(TypeConverter &typeConverter, MLIRContext *context)
|
||||||
|
: ConversionPattern(typeConverter, MatchAnyOpTypeTag(), /*benefit=*/1,
|
||||||
|
context) {}
|
||||||
|
|
||||||
LogicalResult
|
LogicalResult
|
||||||
matchAndRewrite(AtenOnesOp op, OpAdaptor adaptor,
|
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
|
||||||
ConversionPatternRewriter &rewriter) const override {
|
ConversionPatternRewriter &rewriter) const override {
|
||||||
|
if (!isa<AtenOnesOp, AtenZerosOp>(op))
|
||||||
|
return rewriter.notifyMatchFailure(op,
|
||||||
|
"not a supported ones or zeros op");
|
||||||
|
|
||||||
if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
|
if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
|
||||||
return failure();
|
return failure();
|
||||||
Location loc = op.getLoc();
|
Location loc = op->getLoc();
|
||||||
|
|
||||||
|
SmallVector<Value, 3> opArguments;
|
||||||
|
int64_t elementValue;
|
||||||
|
|
||||||
|
if (AtenOnesOp onesOp = dyn_cast<AtenOnesOp>(op)) {
|
||||||
|
opArguments.insert(opArguments.end(),
|
||||||
|
{onesOp.size(), onesOp.layout(), onesOp.pin_memory()});
|
||||||
|
elementValue = 1;
|
||||||
|
} else if (AtenZerosOp zerosOp = dyn_cast<AtenZerosOp>(op)) {
|
||||||
|
opArguments.insert(opArguments.end(), {zerosOp.size(), zerosOp.layout(),
|
||||||
|
zerosOp.pin_memory()});
|
||||||
|
elementValue = 0;
|
||||||
|
}
|
||||||
|
|
||||||
// We ignore device, but add simple asserts for unimplemented kwargs
|
// We ignore device, but add simple asserts for unimplemented kwargs
|
||||||
if (!op.layout().getType().isa<Torch::NoneType>())
|
if (!opArguments[1].getType().isa<Torch::NoneType>())
|
||||||
return rewriter.notifyMatchFailure(op,
|
return rewriter.notifyMatchFailure(op,
|
||||||
"only default layout is supported");
|
"only default layout is supported");
|
||||||
|
|
||||||
bool pinMemory = false;
|
bool pinMemory = false;
|
||||||
if (!op.pin_memory().getType().isa<Torch::NoneType>() &&
|
if (!opArguments[2].getType().isa<Torch::NoneType>() &&
|
||||||
!matchPattern(op.pin_memory(), m_TorchConstantBool(&pinMemory))) {
|
!matchPattern(opArguments[2], m_TorchConstantBool(&pinMemory))) {
|
||||||
return rewriter.notifyMatchFailure(
|
return rewriter.notifyMatchFailure(
|
||||||
op, "pin_memory must be constant bool or None");
|
op, "pin_memory must be constant bool or None");
|
||||||
}
|
}
|
||||||
|
@ -3294,7 +3314,7 @@ public:
|
||||||
return rewriter.notifyMatchFailure(op, "memory pinning not supported");
|
return rewriter.notifyMatchFailure(op, "memory pinning not supported");
|
||||||
|
|
||||||
SmallVector<Value> size, sizeIndex;
|
SmallVector<Value> size, sizeIndex;
|
||||||
if (!getListConstructElements(op.size(), size)) {
|
if (!getListConstructElements(opArguments[0], size)) {
|
||||||
return rewriter.notifyMatchFailure(
|
return rewriter.notifyMatchFailure(
|
||||||
op, "size must be created by ListConstruct");
|
op, "size must be created by ListConstruct");
|
||||||
}
|
}
|
||||||
|
@ -3303,21 +3323,24 @@ public:
|
||||||
sizeIndex.push_back(castIntToIndex(rewriter, loc, size[i]));
|
sizeIndex.push_back(castIntToIndex(rewriter, loc, size[i]));
|
||||||
|
|
||||||
RankedTensorType newResultType =
|
RankedTensorType newResultType =
|
||||||
getTypeConverter()->convertType(op.getType()).cast<RankedTensorType>();
|
getTypeConverter()
|
||||||
|
->convertType(op->getResult(0).getType())
|
||||||
|
.cast<RankedTensorType>();
|
||||||
Type outElementType = newResultType.getElementType();
|
Type outElementType = newResultType.getElementType();
|
||||||
|
|
||||||
Value one = rewriter.create<arith::ConstantOp>(
|
Value constantOp = rewriter.create<arith::ConstantOp>(
|
||||||
loc, outElementType,
|
loc, outElementType,
|
||||||
(outElementType.isa<mlir::FloatType>()
|
(outElementType.isa<mlir::FloatType>()
|
||||||
? rewriter.getFloatAttr(outElementType, 1).cast<mlir::Attribute>()
|
? rewriter.getFloatAttr(outElementType, elementValue)
|
||||||
: rewriter.getIntegerAttr(outElementType, 1)
|
.cast<mlir::Attribute>()
|
||||||
|
: rewriter.getIntegerAttr(outElementType, elementValue)
|
||||||
.cast<mlir::Attribute>()));
|
.cast<mlir::Attribute>()));
|
||||||
Value outTensor = rewriter
|
Value outTensor = rewriter
|
||||||
.create<linalg::InitTensorOp>(
|
.create<linalg::InitTensorOp>(
|
||||||
loc, sizeIndex, newResultType.getElementType())
|
loc, sizeIndex, newResultType.getElementType())
|
||||||
.getResult();
|
.getResult();
|
||||||
Value fillOp =
|
Value fillOp = rewriter.create<linalg::FillOp>(loc, constantOp, outTensor)
|
||||||
rewriter.create<linalg::FillOp>(loc, one, outTensor).getResult(0);
|
.getResult(0);
|
||||||
|
|
||||||
rewriter.replaceOpWithNewOp<tensor::CastOp>(op, newResultType, fillOp);
|
rewriter.replaceOpWithNewOp<tensor::CastOp>(op, newResultType, fillOp);
|
||||||
|
|
||||||
|
@ -3452,8 +3475,8 @@ public:
|
||||||
patterns.add<ConvertAtenSizeIntOp>(typeConverter, context);
|
patterns.add<ConvertAtenSizeIntOp>(typeConverter, context);
|
||||||
target.addIllegalOp<AtenEmbeddingOp>();
|
target.addIllegalOp<AtenEmbeddingOp>();
|
||||||
patterns.add<ConvertAtenEmbeddingOp>(typeConverter, context);
|
patterns.add<ConvertAtenEmbeddingOp>(typeConverter, context);
|
||||||
target.addIllegalOp<AtenOnesOp>();
|
target.addIllegalOp<AtenOnesOp, AtenZerosOp>();
|
||||||
patterns.add<ConvertAtenOnesOp>(typeConverter, context);
|
patterns.add<ConvertAtenOnesZerosOp>(typeConverter, context);
|
||||||
target.addIllegalOp<AtenContiguousOp>();
|
target.addIllegalOp<AtenContiguousOp>();
|
||||||
patterns.add<ConvertAtenContiguousOp>(typeConverter, context);
|
patterns.add<ConvertAtenContiguousOp>(typeConverter, context);
|
||||||
target.addIllegalOp<AtenIntTensorOp>();
|
target.addIllegalOp<AtenIntTensorOp>();
|
||||||
|
|
|
@ -458,7 +458,6 @@ public:
|
||||||
return visitAtenNllLossForwardOp(nllForwardOp, operands);
|
return visitAtenNllLossForwardOp(nllForwardOp, operands);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
// Otherwise, this is an unknown operation. Just mark all results as
|
// Otherwise, this is an unknown operation. Just mark all results as
|
||||||
// having reached a pessimistic fixpoint.
|
// having reached a pessimistic fixpoint.
|
||||||
return markAllPessimisticFixpoint(op->getResults());
|
return markAllPessimisticFixpoint(op->getResults());
|
||||||
|
|
Loading…
Reference in New Issue