mirror of https://github.com/llvm/torch-mlir
[TORCH][MLIR] Add and templatize lowering of [`aten.zeros|aten.ones|aten.empty`] ops
- Templatize `aten.zeros` and `aten.ones` ops lowering.
- Add E2E support for `aten.empty` op.
- Add Integer type support in `aten.mul.Scalar` op lowering.

Signed-off-by: Gaurav Shukla <gaurav@nod-labs.com>

Branch: pull/480/head
parent 528354de84
commit 8d4879feb0
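For orientation, a minimal PyTorch-level sketch (not part of this commit) of the behaviors whose linalg lowerings this change adds or templatizes:

    import torch

    zeros = torch.zeros(3, 4, dtype=torch.int64)  # aten.zeros
    ones = torch.ones(3, 4)                       # aten.ones
    empty = torch.empty(3, 4)                     # aten.empty: contents are uninitialized

    # aten.mul.Scalar on an integer tensor, newly supported by the lowering.
    ints = torch.randint(10, (3, 4))              # int64 tensor
    print(torch.mul(ints, 4).dtype)               # torch.int64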
@@ -637,6 +637,57 @@ def OnesModuleFalsePinMemory_basic(module, tu: TestUtils):

 # ==============================================================================


+class EmptyIntModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+    ])
+    def forward(self):
+        return 0 * torch.empty((3, 4), dtype=torch.int64)
+
+
+@register_test_case(module_factory=lambda: EmptyIntModule())
+def EmptyModule_int(module, tu: TestUtils):
+    module.forward()
+
+
+# ==============================================================================
+
+
+class EmptyFloatModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+    ])
+    def forward(self):
+        return torch.abs(torch.empty((3, 4), dtype=torch.float32)) > -1.0
+
+
+@register_test_case(module_factory=lambda: EmptyFloatModule())
+def EmptyModule_float(module, tu: TestUtils):
+    module.forward()
+
+
+class EmptyFalsePinMemoryModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+    ])
+    def forward(self):
+        return torch.abs(torch.empty((3, 4), dtype=torch.float32,
+                                     pin_memory=False)) > -1.0
+
+
+@register_test_case(module_factory=lambda: EmptyFalsePinMemoryModule())
+def EmptyModule_falsePinMemory(module, tu: TestUtils):
+    module.forward()
+
+
+# ==============================================================================
+
+
 class ContiguousModule(torch.nn.Module):
     def __init__(self):
         super().__init__()
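A note on the tests above: `aten.empty` returns uninitialized memory, so the tests never compare its raw contents; each forward maps the result to a deterministic value first. A standalone sketch of that trick, assuming only stock PyTorch:

    import torch

    # Multiplying integer garbage by 0 always yields 0 ...
    det_int = 0 * torch.empty((3, 4), dtype=torch.int64)
    assert det_int.eq(0).all()

    # ... and abs(x) > -1.0 is True for any (non-NaN) float contents.
    det_bool = torch.abs(torch.empty((3, 4), dtype=torch.float32)) > -1.0
    print(det_bool)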
@@ -443,8 +443,24 @@ def RsubModule_noalpha_basic(module, tu: TestUtils):

 # ==============================================================================

-class ElementwiseMulScalarModule(torch.nn.Module):
+class ElementwiseMulScalarIntModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([-1, -1], torch.int64, True),
+    ])
+    def forward(self, x):
+        return torch.mul(x, 4)
+
+
+@register_test_case(module_factory=lambda: ElementwiseMulScalarIntModule())
+def ElementwiseMulScalarModule_int(module, tu: TestUtils):
+    module.forward(torch.randint(10, (3, 4)))
+
+
+class ElementwiseMulScalarFloatModule(torch.nn.Module):
     def __init__(self):
         super().__init__()
@@ -456,11 +472,27 @@ class ElementwiseMulScalarModule(torch.nn.Module):
     def forward(self, x):
         return torch.mul(x, 100.0)

-@register_test_case(module_factory=lambda: ElementwiseMulScalarModule())
-def ElementwiseMulScalarModule_basic(module, tu: TestUtils):
+@register_test_case(module_factory=lambda: ElementwiseMulScalarFloatModule())
+def ElementwiseMulScalarModule_float(module, tu: TestUtils):
     module.forward(tu.rand(3, 4))


+class ElementwiseMulScalarModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([-1, -1], torch.int64, True),
+    ])
+    def forward(self, x):
+        return torch.mul(x, 8.0)
+
+
+@register_test_case(module_factory=lambda: ElementwiseMulScalarModule())
+def ElementwiseMulScalarModule_basic(module, tu: TestUtils):
+    module.forward(torch.randint(10, (3, 4)))
+
+
 class ElementwiseMulTensorFloatModule(torch.nn.Module):
     def __init__(self):
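The three modules above pin down the scalar-type combinations the lowering must handle: int tensor times int scalar stays integer, while int tensor times float scalar promotes to float. A small sketch of the resulting dtypes (illustration only, not part of the commit):

    import torch

    x = torch.randint(10, (3, 4))                     # int64
    print(torch.mul(x, 4).dtype)                      # torch.int64   -> arith.muli
    print(torch.mul(x, 8.0).dtype)                    # torch.float32 -> arith.mulf
    print(torch.mul(torch.rand(3, 4), 100.0).dtype)   # torch.float32 -> arith.mulf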
@@ -189,6 +189,21 @@ static Value createInitTensor(OpBuilder &b, Location loc, ValueRange sizes,
   Value initTensor = b.create<linalg::InitTensorOp>(loc, sizes, elemTy);
   return b.create<linalg::FillOp>(loc, initElem, initTensor).getResult(0);
 }

+// Creates a constant of type `elemType` with value `val`.
+static Value getConstant(OpBuilder &b, Location loc, int64_t val,
+                         Type elemType) {
+  Attribute attr = {};
+  if (elemType.isa<mlir::FloatType>())
+    attr = b.getFloatAttr(elemType, val);
+  if (elemType.isa<mlir::IndexType>())
+    attr = b.getIndexAttr(val);
+  if (elemType.isa<mlir::IntegerType>())
+    attr = b.getIntegerAttr(
+        elemType, APInt(elemType.cast<IntegerType>().getWidth(), val));
+  if (!attr)
+    return nullptr;
+  return b.create<arith::ConstantOp>(loc, elemType, attr);
+}
+
 // Helper function to calculate the output tensor dims for convolution-like ops.
 // Along each dim:
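The new `getConstant` helper dispatches on the element type and signals unsupported types by returning a null `Value` rather than asserting. A hedged Python analogue of that dispatch (function name and type strings are hypothetical, purely illustrative):

    def get_constant(val, elem_type):
        """Mimic getConstant: build a typed constant, or return None if unsupported."""
        if elem_type in ("f32", "f64"):       # FloatType -> float attribute
            return float(val)
        if elem_type == "index":              # IndexType -> index attribute
            return val
        if elem_type.startswith("i"):         # IntegerType -> value truncated to width
            width = int(elem_type[1:])
            return val & ((1 << width) - 1)
        return None                           # caller must handle unsupported types

    print(get_constant(1, "f32"), get_constant(0, "i64"), get_constant(1, "!torch.vtensor"))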
@@ -1828,13 +1843,14 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
     Type dtype = converter->convertType(mulScalar.getType())
                      .cast<RankedTensorType>()
                      .getElementType();
-    if (!dtype.isa<mlir::FloatType>()) {
-      mulScalar.emitError("unimplemented: non-floating point dtype");
-      return nullptr;
-    }
-    Value self = payloadArgs[0];
-    Value other = convertScalarToDtype(b, loc, operands[1], dtype);
-    return b.create<arith::MulFOp>(loc, self, other);
+    Value lhs = convertScalarToDtype(b, loc, payloadArgs[0], dtype);
+    Value rhs = convertScalarToDtype(b, loc, operands[1], dtype);
+    if (dtype.isa<mlir::FloatType>())
+      return b.create<arith::MulFOp>(loc, lhs, rhs);
+    if (dtype.isa<mlir::IntegerType>())
+      return b.create<arith::MulIOp>(loc, lhs, rhs);
+    mulScalar.emitError("unimplemented: Only integer/float dtype supported");
+    return nullptr;
   }
   if (auto atenToDtype = dyn_cast<AtenToDtypeOp>(op)) {
     Value input = payloadArgs[0];
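The rewritten payload first coerces both operands to the converted result dtype and only then selects `arith.mulf` or `arith.muli`. A Python-level emulation of that order of operations (a sketch only; `mul_scalar` is a hypothetical stand-in, not the C++ code):

    import torch

    def mul_scalar(tensor, scalar, result_dtype):
        lhs = tensor.to(result_dtype)                   # convertScalarToDtype(payloadArgs[0])
        rhs = torch.tensor(scalar, dtype=result_dtype)  # convertScalarToDtype(operands[1])
        return lhs * rhs                                # muli for ints, mulf for floats

    print(mul_scalar(torch.randint(10, (2, 2)), 4, torch.int64))
    print(mul_scalar(torch.randint(10, (2, 2)), 8.0, torch.float32))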
@@ -3417,83 +3433,68 @@ public:
 } // namespace

 namespace {
-// Converts AtenOnesOp and AtenZerosOp.
-struct ConvertAtenOnesZerosOp : ConversionPattern {
-  ConvertAtenOnesZerosOp(TypeConverter &typeConverter, MLIRContext *context)
-      : ConversionPattern(typeConverter, MatchAnyOpTypeTag(), /*benefit=*/1,
-                          context) {}
-
+// Converts constant tensor allocation like ops.
+template <typename OpTy>
+class ConvertConstantTensorAllocOp : public OpConversionPattern<OpTy> {
+public:
+  using OpConversionPattern<OpTy>::OpConversionPattern;
   LogicalResult
-  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+  matchAndRewrite(OpTy op, typename OpTy::Adaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
-    if (!isa<AtenOnesOp, AtenZerosOp>(op))
-      return rewriter.notifyMatchFailure(op,
-                                         "not a supported ones or zeros op");
-
     if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
       return failure();
-    Location loc = op->getLoc();
-
-    Value size, layout, pin_memory;
-    int64_t elementValue;
-
-    if (AtenOnesOp onesOp = dyn_cast<AtenOnesOp>(op)) {
-      size = onesOp.size();
-      layout = onesOp.layout();
-      pin_memory = onesOp.pin_memory();
-      elementValue = 1;
-    } else if (AtenZerosOp zerosOp = dyn_cast<AtenZerosOp>(op)) {
-      size = zerosOp.size();
-      layout = zerosOp.layout();
-      pin_memory = zerosOp.pin_memory();
-      elementValue = 0;
-    }
-
-    // We ignore device, but add simple asserts for unimplemented kwargs
-    if (!layout.getType().isa<Torch::NoneType>())
-      return rewriter.notifyMatchFailure(op,
-                                         "only default layout is supported");
-
-    bool pinMemory = false;
-    if (!pin_memory.getType().isa<Torch::NoneType>() &&
-        !matchPattern(pin_memory, m_TorchConstantBool(&pinMemory))) {
+
+    // Currently memory pinning and layout features are not supported.
+    if (!op.layout().getType().template isa<Torch::NoneType>())
+      return rewriter.notifyMatchFailure(
+          op, "unimplemented: only default layout is supported");
+    bool pinMemory;
+    if (!op.pin_memory().getType().template isa<Torch::NoneType>() &&
+        (!matchPattern(op.pin_memory(), m_TorchConstantBool(&pinMemory)) ||
+         pinMemory)) {
       return rewriter.notifyMatchFailure(
-          op, "pin_memory must be constant bool or None");
+          op, "unimplemented: pin_memory must be either None or false");
     }
-    if (pinMemory)
-      return rewriter.notifyMatchFailure(op, "memory pinning not supported");
-
-    SmallVector<Value> sizes, sizeIndex;
-    if (!getListConstructElements(size, sizes)) {
-      return rewriter.notifyMatchFailure(
-          op, "size must be created by ListConstruct");
-    }
-    sizes = getTypeConvertedValues(rewriter, loc, getTypeConverter(), sizes);
-    for (size_t i = 0; i < sizes.size(); i++)
-      sizeIndex.push_back(castIntToIndex(rewriter, loc, sizes[i]));
-
-    RankedTensorType newResultType =
-        getTypeConverter()
-            ->convertType(op->getResult(0).getType())
-            .cast<RankedTensorType>();
-    Type outElementType = newResultType.getElementType();
-
-    Value constantOp = rewriter.create<arith::ConstantOp>(
-        loc, outElementType,
-        (outElementType.isa<mlir::FloatType>()
-             ? rewriter.getFloatAttr(outElementType, elementValue)
-                   .cast<mlir::Attribute>()
-             : rewriter.getIntegerAttr(outElementType, elementValue)
-                   .cast<mlir::Attribute>()));
-    Value outTensor = rewriter
-                          .create<linalg::InitTensorOp>(
-                              loc, sizeIndex, newResultType.getElementType())
-                          .getResult();
-    Value fillOp = rewriter.create<linalg::FillOp>(loc, constantOp, outTensor)
-                       .getResult(0);
-
-    rewriter.replaceOpWithNewOp<tensor::CastOp>(op, newResultType, fillOp);
+
+    // Memory formats are not supported in the case of `AtenEmptyMemoryFormat`.
+    if constexpr (std::is_same<OpTy, AtenEmptyMemoryFormatOp>::value) {
+      if (!op.memory_format().getType().template isa<Torch::NoneType>())
+        return rewriter.notifyMatchFailure(
+            op, "unimplemented: only default memory format is supported");
+    }
+
+    Location loc = op.getLoc();
+    TypeConverter *typeConverter = this->getTypeConverter();
+    SmallVector<Value> resultSizeTorchInt, resultSize, resultSizeIndex;
+    if (!getListConstructElements(op.size(), resultSizeTorchInt)) {
+      return rewriter.notifyMatchFailure(
+          op, "unimplemented: size must be constructed using ListConstruct");
+    }
+    resultSize = getTypeConvertedValues(rewriter, loc, typeConverter,
+                                        resultSizeTorchInt);
+    for (auto size : resultSize)
+      resultSizeIndex.push_back(castIntToIndex(rewriter, loc, size));
+
+    auto resultType = typeConverter->convertType(op.getType())
+                          .template cast<RankedTensorType>();
+    Type outElemType = resultType.getElementType();
+
+    // Create an uninitialized tensor of `resultSize` shape. It will be
+    // returned without initialization/filling in the case of
+    // `AtenEmptyMemoryFormatOp`.
+    Value outputTensor = rewriter.create<linalg::InitTensorOp>(
+        loc, resultSizeIndex, outElemType);
+
+    // `AtenZeros` and `AtenOnes` ops will be filled with corresponding values.
+    if (std::is_same<OpTy, AtenZerosOp>::value) {
+      Value zero = getConstant(rewriter, loc, 0, outElemType);
+      outputTensor =
+          rewriter.create<linalg::FillOp>(loc, zero, outputTensor).getResult(0);
+    } else if (std::is_same<OpTy, AtenOnesOp>::value) {
+      Value one = getConstant(rewriter, loc, 1, outElemType);
+      outputTensor =
+          rewriter.create<linalg::FillOp>(loc, one, outputTensor).getResult(0);
+    }
+    rewriter.replaceOpWithNewOp<tensor::CastOp>(op, resultType, outputTensor);
     return success();
   }
 };
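The template collapses what would otherwise be three near-identical patterns into one: the allocation path is shared, and only the fill step depends on the op, with `aten.empty` left unfilled. A rough Python analogue of that structure (hypothetical names, illustration only):

    import torch

    def lower_constant_tensor_alloc(op_kind, size, dtype):
        out = torch.empty(size, dtype=dtype)  # linalg.init_tensor equivalent
        if op_kind == "aten.zeros":
            out = out.fill_(0)                # linalg.fill with getConstant(0)
        elif op_kind == "aten.ones":
            out = out.fill_(1)                # linalg.fill with getConstant(1)
        return out                            # aten.empty: returned uninitialized

    print(lower_constant_tensor_alloc("aten.zeros", (2, 3), torch.int64))
    print(lower_constant_tensor_alloc("aten.ones", (2, 3), torch.float32))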
@@ -3704,8 +3705,15 @@ public:
     patterns.add<ConvertAtenSizeIntOp>(typeConverter, context);
     target.addIllegalOp<AtenEmbeddingOp>();
     patterns.add<ConvertAtenEmbeddingOp>(typeConverter, context);
-    target.addIllegalOp<AtenOnesOp, AtenZerosOp>();
-    patterns.add<ConvertAtenOnesZerosOp>(typeConverter, context);
+    target.addIllegalOp<AtenEmptyMemoryFormatOp>();
+    patterns.add<ConvertConstantTensorAllocOp<AtenEmptyMemoryFormatOp>>(
+        typeConverter, context);
+    target.addIllegalOp<AtenZerosOp>();
+    patterns.add<ConvertConstantTensorAllocOp<AtenZerosOp>>(typeConverter,
+                                                            context);
+    target.addIllegalOp<AtenOnesOp>();
+    patterns.add<ConvertConstantTensorAllocOp<AtenOnesOp>>(typeConverter,
+                                                           context);
     target.addIllegalOp<AtenContiguousOp>();
     patterns.add<ConvertAtenContiguousOp>(typeConverter, context);
     target.addIllegalOp<AtenIntTensorOp>();