AtenCumprodOp (#3737)

pull/3755/head
Xida Ren (Cedar) 2024-09-26 18:17:22 -04:00 committed by GitHub
parent 335cf5f6d0
commit 9938abf25e
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
7 changed files with 229 additions and 0 deletions

View File

@ -40,6 +40,8 @@ Value createInitTensor(OpBuilder &b, Location loc, ValueRange sizes,
Value createZeroInitTensor(OpBuilder &b, Location loc, ValueRange sizes, Value createZeroInitTensor(OpBuilder &b, Location loc, ValueRange sizes,
Type elemTy); Type elemTy);
// Creates a tensor with the given `sizes` and element type `elemTy` and fills
// every element with the value one for that element type.
Value createOneInitTensor(OpBuilder &b, Location loc, ValueRange sizes,
Type elemTy);
Value castIntToIndex(OpBuilder &b, Location loc, Value v); Value castIntToIndex(OpBuilder &b, Location loc, Value v);

View File

@ -1497,6 +1497,79 @@ public:
}; };
} // namespace } // namespace
namespace {
/// Lowers `aten.cumprod` to an inclusive `TMTensor` scan that multiplies the
/// elements of the input along a constant dimension.
///
/// The pattern reports a match failure (and emits no IR) when:
///   * the `dtype` operand is anything other than `None`,
///   * `dim` is not a constant integer, or
///   * the normalized `dim` is not valid for the rank of the result.
class ConvertAtenCumprodOp : public OpConversionPattern<AtenCumprodOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(AtenCumprodOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    Value input = adaptor.getSelf();
    auto resultType = cast<RankedTensorType>(
        getTypeConverter()->convertType(op->getResult(0).getType()));
    Type elementType = resultType.getElementType();
    Type inputElementType =
        cast<RankedTensorType>(input.getType()).getElementType();

    // Perform every legality check before creating any operation, so that a
    // match failure is never reported after the IR has already been touched.
    int64_t inputRank = resultType.getRank();
    Value dtype = op.getDtype();
    if (!isa<Torch::NoneType>(dtype.getType()))
      return rewriter.notifyMatchFailure(
          op, "unsupported: dtype argument not supported");

    int64_t dim;
    if (!matchPattern(op.getDim(), m_TorchConstantInt(&dim)))
      return rewriter.notifyMatchFailure(
          op, "unimplemented: only constant dim value is supported");
    dim = toPositiveDim(dim, inputRank);
    if (!isValidDim(dim, inputRank))
      return rewriter.notifyMatchFailure(op, "invalid dim");

    // Convert the input element type to the result's element type.
    // The only expected mismatch is an integer input that is not `si64`, so
    // the input is converted directly to `si64`. All remaining cases are
    // covered by the dtype function defined for this op.
    if (elementType != inputElementType) {
      Value torchInput = convertTensorToDtype(
          rewriter, loc, op.getSelf(),
          rewriter.getIntegerType(64, IntegerType::Signed));
      input = typeConverter->materializeTargetConversion(
          rewriter, loc, typeConverter->convertType(torchInput.getType()),
          torchInput);
    }

    // Output buffer: same sizes as the input, filled with ones and cast to
    // the converted result type.
    SmallVector<Value> sizes = getTensorSizes(rewriter, loc, input);
    Value output = createOneInitTensor(rewriter, loc, sizes, elementType);
    output = rewriter.create<tensor::CastOp>(loc, resultType, output);

    // Accumulator: the input shape with the scanned dimension removed, filled
    // with ones (the value the running product starts from).
    SmallVector<Value> accSizes(sizes);
    accSizes.erase(accSizes.begin() + dim);
    SmallVector<int64_t> accStatic(
        makeShapeTorchCompatible(resultType.getShape()));
    accStatic.erase(accStatic.begin() + dim);
    Value acc = createOneInitTensor(rewriter, loc, accSizes, elementType);
    Type accType =
        RankedTensorType::get(makeShapeLLVMCompatible(accStatic), elementType);
    acc = rewriter.create<tensor::CastOp>(loc, accType, acc);

    // The scan body multiplies the current element with the running product,
    // choosing the float or integer multiply based on the element type.
    Value result = createTMTensorScanOp(
        rewriter, loc, input, output, acc, dim, /*inclusive=*/true,
        [](OpBuilder &b, Location loc, Value elem, Value acc) {
          Value prod =
              (isa<mlir::FloatType>(elem.getType())
                   ? b.create<arith::MulFOp>(loc, elem, acc)->getResult(0)
                   : b.create<arith::MulIOp>(loc, elem, acc)->getResult(0));
          b.create<TMTensor::YieldOp>(loc, prod);
        });

    rewriter.replaceOpWithNewOp<tensor::CastOp>(op, resultType, result);
    return success();
  }
};
} // namespace
namespace { namespace {
class ConvertAtenCumsumOp : public OpConversionPattern<AtenCumsumOp> { class ConvertAtenCumsumOp : public OpConversionPattern<AtenCumsumOp> {
public: public:
@ -2240,6 +2313,8 @@ public:
patterns.add<ConvertAtenSortOp>(typeConverter, context); patterns.add<ConvertAtenSortOp>(typeConverter, context);
target.addIllegalOp<AtenCumsumOp>(); target.addIllegalOp<AtenCumsumOp>();
patterns.add<ConvertAtenCumsumOp>(typeConverter, context); patterns.add<ConvertAtenCumsumOp>(typeConverter, context);
target.addIllegalOp<AtenCumprodOp>();
patterns.add<ConvertAtenCumprodOp>(typeConverter, context);
target.addIllegalOp<AtenScaledDotProductAttentionOp>(); target.addIllegalOp<AtenScaledDotProductAttentionOp>();
patterns.add<ConvertAtenScaledDotProductAttentionOp>(typeConverter, patterns.add<ConvertAtenScaledDotProductAttentionOp>(typeConverter,
context); context);

View File

@ -138,6 +138,16 @@ Value createZeroInitTensor(OpBuilder &b, Location loc, ValueRange sizes,
return b.create<linalg::FillOp>(loc, c0, initTensor).getResult(0); return b.create<linalg::FillOp>(loc, c0, initTensor).getResult(0);
} }
/// Returns a tensor of shape `sizes` and element type `elemTy` in which every
/// element is one.
///
/// An empty tensor is created first and then filled with a constant built
/// from the one-attribute of the tensor's element type.
Value createOneInitTensor(OpBuilder &b, Location loc, ValueRange sizes,
                          Type elemTy) {
  auto shape = getAsOpFoldResult(sizes);
  Value emptyTensor = b.create<tensor::EmptyOp>(loc, shape, elemTy);
  auto emptyType = cast<RankedTensorType>(emptyTensor.getType());
  Value one = b.create<arith::ConstantOp>(
      loc, b.getOneAttr(emptyType.getElementType()));
  auto fill = b.create<linalg::FillOp>(loc, one, emptyTensor);
  return fill.getResult(0);
}
Value castIntToIndex(OpBuilder &b, Location loc, Value v) { Value castIntToIndex(OpBuilder &b, Location loc, Value v) {
assert(isa<IntegerType>(v.getType()) && "must be called with integer type"); assert(isa<IntegerType>(v.getType()) && "must be called with integer type");
return b.createOrFold<arith::IndexCastOp>(loc, b.getIndexType(), v); return b.createOrFold<arith::IndexCastOp>(loc, b.getIndexType(), v);

View File

@ -9134,6 +9134,9 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" func.func @\"__torch_mlir_shape_fn.aten.cumsum\"(%arg0: !torch.list<int>, %arg1: !torch.int, %arg2: !torch.optional<int>) -> !torch.list<int> {\n" " func.func @\"__torch_mlir_shape_fn.aten.cumsum\"(%arg0: !torch.list<int>, %arg1: !torch.int, %arg2: !torch.optional<int>) -> !torch.list<int> {\n"
" return %arg0 : !torch.list<int>\n" " return %arg0 : !torch.list<int>\n"
" }\n" " }\n"
" func.func @\"__torch_mlir_shape_fn.aten.cumprod\"(%arg0: !torch.list<int>, %arg1: !torch.int, %arg2: !torch.optional<int>) -> !torch.list<int> {\n"
" return %arg0 : !torch.list<int>\n"
" }\n"
" func.func @\"__torch_mlir_shape_fn.aten.rand_like\"(%arg0: !torch.list<int>, %arg1: !torch.optional<int>, %arg2: !torch.optional<int>, %arg3: !torch.optional<Device>, %arg4: !torch.optional<bool>, %arg5: !torch.optional<int>) -> !torch.list<int> {\n" " func.func @\"__torch_mlir_shape_fn.aten.rand_like\"(%arg0: !torch.list<int>, %arg1: !torch.optional<int>, %arg2: !torch.optional<int>, %arg3: !torch.optional<Device>, %arg4: !torch.optional<bool>, %arg5: !torch.optional<int>) -> !torch.list<int> {\n"
" return %arg0 : !torch.list<int>\n" " return %arg0 : !torch.list<int>\n"
" }\n" " }\n"
@ -11844,6 +11847,25 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" }\n" " }\n"
" return %1 : !torch.int\n" " return %1 : !torch.int\n"
" }\n" " }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.cumprod\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.int, %arg2: !torch.optional<int>) -> !torch.int {\n"
" %int4 = torch.constant.int 4\n"
" %none = torch.constant.none\n"
" %0 = torch.aten.__isnot__ %arg2, %none : !torch.optional<int>, !torch.none -> !torch.bool\n"
" %1 = torch.prim.If %0 -> (!torch.int) {\n"
" %2 = torch.prim.unchecked_cast %arg2 : !torch.optional<int> -> !torch.int\n"
" torch.prim.If.yield %2 : !torch.int\n"
" } else {\n"
" %2:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
" %3 = func.call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.is_integer_dtype(%2#1) : (!torch.int) -> !torch.bool\n"
" %4 = torch.prim.If %3 -> (!torch.int) {\n"
" torch.prim.If.yield %int4 : !torch.int\n"
" } else {\n"
" torch.prim.If.yield %2#1 : !torch.int\n"
" }\n"
" torch.prim.If.yield %4 : !torch.int\n"
" }\n"
" return %1 : !torch.int\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.detach\"(%arg0: !torch.tuple<int, int>) -> !torch.int {\n" " func.func @\"__torch_mlir_dtype_fn.aten.detach\"(%arg0: !torch.tuple<int, int>) -> !torch.int {\n"
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n" " %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
" return %0#1 : !torch.int\n" " return %0#1 : !torch.int\n"

View File

@ -79,6 +79,7 @@ TORCHDYNAMO_XFAIL_SET = {
#### General TorchDynamo/PyTorch errors #### General TorchDynamo/PyTorch errors
# torch._dynamo.exc.Unsupported: Tensor.item # torch._dynamo.exc.Unsupported: Tensor.item
"CumsumModule_basic", "CumsumModule_basic",
"CumprodModule_basic",
# TypeError: new_empty(): argument 'size' (position 1) must be tuple of ints, but found element of type NoneType at pos 0 # TypeError: new_empty(): argument 'size' (position 1) must be tuple of ints, but found element of type NoneType at pos 0
# RuntimeError: Failed running call_function aten.convolution_backward(... # RuntimeError: Failed running call_function aten.convolution_backward(...
# https://github.com/pytorch/pytorch/issues/89629 # https://github.com/pytorch/pytorch/issues/89629
@ -432,6 +433,7 @@ FX_IMPORTER_XFAIL_SET = {
"ConvolutionBackwardModule2DStrided_basic", "ConvolutionBackwardModule2DStrided_basic",
"ConvolutionBackwardModule2D_basic", "ConvolutionBackwardModule2D_basic",
"CumsumModule_basic", "CumsumModule_basic",
"CumprodModule_basic",
"DeformConv2D_basic", "DeformConv2D_basic",
"DivFloatModule_basic", "DivFloatModule_basic",
"DivIntModule_basic", "DivIntModule_basic",
@ -667,6 +669,10 @@ FX_IMPORTER_STABLEHLO_XFAIL_SET = {
"ConvolutionBackwardModule2DStrided_basic", "ConvolutionBackwardModule2DStrided_basic",
"ConvolutionBackwardModule2D_basic", "ConvolutionBackwardModule2D_basic",
"CumsumModule_basic", "CumsumModule_basic",
"CumprodModule_basic",
"CumprodInputDtypeInt32Module_basic",
"CumprodStaticModule_basic",
"CumprodStaticNegativeDimModule_basic",
"DeformConv2D_basic", "DeformConv2D_basic",
"DeterminantBatchedModule_F32", "DeterminantBatchedModule_F32",
"DeterminantDynamicModule_F32", "DeterminantDynamicModule_F32",
@ -1077,6 +1083,9 @@ STABLEHLO_PASS_SET = {
"CumsumInputDtypeInt32Module_basic", "CumsumInputDtypeInt32Module_basic",
"CumsumStaticModule_basic", "CumsumStaticModule_basic",
"CumsumStaticNegativeDimModule_basic", "CumsumStaticNegativeDimModule_basic",
"CumprodInputDtypeInt32Module_basic",
"CumprodStaticModule_basic",
"CumprodStaticNegativeDimModule_basic",
"DetachModule_basic", "DetachModule_basic",
"DivFloatModule_basic", "DivFloatModule_basic",
"DivIntModule_basic", "DivIntModule_basic",
@ -3105,6 +3114,10 @@ ONNX_XFAIL_SET = {
"CopyWithDifferentDTypesModule_basic", "CopyWithDifferentDTypesModule_basic",
"CosineSimilarityStaticBroadcastModule_basic", "CosineSimilarityStaticBroadcastModule_basic",
"CumsumInputDtypeInt32Module_basic", "CumsumInputDtypeInt32Module_basic",
"CumprodModule_basic",
"CumprodInputDtypeInt32Module_basic",
"CumprodStaticModule_basic",
"CumprodStaticNegativeDimModule_basic",
"ElementwiseAcosIntModule_basic", "ElementwiseAcosIntModule_basic",
"ElementwiseAsinIntModule_basic", "ElementwiseAsinIntModule_basic",
"ElementwiseAtanTensorIntModule_basic", "ElementwiseAtanTensorIntModule_basic",
@ -3378,6 +3391,10 @@ FX_IMPORTER_TOSA_XFAIL_SET = {
"CumsumModule_basic", "CumsumModule_basic",
"CumsumStaticModule_basic", "CumsumStaticModule_basic",
"CumsumStaticNegativeDimModule_basic", "CumsumStaticNegativeDimModule_basic",
"CumprodModule_basic",
"CumprodInputDtypeInt32Module_basic",
"CumprodStaticModule_basic",
"CumprodStaticNegativeDimModule_basic",
"DeformConv2D_basic", "DeformConv2D_basic",
"DeterminantBatchedModule_F32", "DeterminantBatchedModule_F32",
"DeterminantDynamicModule_F32", "DeterminantDynamicModule_F32",
@ -4110,6 +4127,10 @@ ONNX_TOSA_XFAIL_SET = {
"CumsumModule_basic", "CumsumModule_basic",
"CumsumStaticModule_basic", "CumsumStaticModule_basic",
"CumsumStaticNegativeDimModule_basic", "CumsumStaticNegativeDimModule_basic",
"CumprodModule_basic",
"CumprodInputDtypeInt32Module_basic",
"CumprodStaticModule_basic",
"CumprodStaticNegativeDimModule_basic",
"DeformConv2D_basic", "DeformConv2D_basic",
"DeterminantModule_F32", "DeterminantModule_F32",
"DeterminantBatchedModule_F32", "DeterminantBatchedModule_F32",

View File

@ -1434,6 +1434,9 @@ def atenmultinomial〡shape(self: List[int], num_samples: int, replacement: b
def atencumsum〡shape(self: List[int], dim: int, dtype: Optional[int] = None) -> List[int]: def atencumsum〡shape(self: List[int], dim: int, dtype: Optional[int] = None) -> List[int]:
return self return self
# Shape function for `aten.cumprod`: the result shape is the input shape
# `self`, returned unchanged. `dim` and `dtype` are accepted but not used.
def atencumprod〡shape(self: List[int], dim: int, dtype: Optional[int] = None) -> List[int]:
return self
def atenrand_like〡shape(self: List[int], dtype: Optional[int] = None, layout: Optional[int] = None, device: Optional[device] = None, pin_memory: Optional[bool] = None, memory_format: Optional[int] = None) -> List[int]: def atenrand_like〡shape(self: List[int], dtype: Optional[int] = None, layout: Optional[int] = None, device: Optional[device] = None, pin_memory: Optional[bool] = None, memory_format: Optional[int] = None) -> List[int]:
return self return self
@ -2926,6 +2929,18 @@ def atencumsum〡dtype(self_rank_dtype: Tuple[int, int], dim: int, dtype: Opt
return torch.int64 return torch.int64
return self_dtype return self_dtype
# Dtype function for `aten.cumprod`.
#   * If `dtype` is given (not None), it is returned as the result dtype.
#   * Otherwise, an input dtype accepted by `is_integer_dtype` yields
#     `torch.int64`.
#   * Any other input dtype is returned unchanged.
# `dim` does not affect the result dtype.
@check_dtype_function(
_check_tensors_with_the_same_dtype(num_of_tensors=1, dim=0) +
_check_tensors_with_the_same_dtype(num_of_tensors=1, dim=0, dtype=torch.float32))
def atencumprod〡dtype(self_rank_dtype: Tuple[int, int], dim: int, dtype: Optional[int] = None) -> int:
if dtype is not None:
return dtype
# Only the dtype half of the (rank, dtype) pair is used below.
self_rank, self_dtype = self_rank_dtype
if is_integer_dtype(self_dtype):
return torch.int64
return self_dtype
@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1)) @check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1))
def atendetach〡dtype(self_rank_dtype: Tuple[int, int]) -> int: def atendetach〡dtype(self_rank_dtype: Tuple[int, int]) -> int:
self_rank, self_dtype = self_rank_dtype self_rank, self_dtype = self_rank_dtype

View File

@ -4830,6 +4830,90 @@ def CumsumInputDtypeInt32Module_basic(module, tu: TestUtils):
# ============================================================================== # ==============================================================================
# Test module: `aten.cumprod` on a rank-3 float32 argument annotated with -1
# for every dimension (presumably meaning dynamic sizes; confirm against
# `annotate_args`).
class CumprodModule(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args(
[
None,
([-1, -1, -1], torch.float32, True),
]
)
def forward(self, val):
# The `dim` argument (value 1) is read out of a tensor with `.item()` rather
# than being written as a literal.
ones = torch.ones([1], dtype=torch.int32)
return torch.ops.aten.cumprod(val, ones.item())
# Registers the `CumprodModule` test case; runs `forward` on `tu.rand(2, 7, 4)`.
@register_test_case(module_factory=lambda: CumprodModule())
def CumprodModule_basic(module, tu: TestUtils):
module.forward(tu.rand(2, 7, 4))
# Test module: `aten.cumprod` along literal dim 1 on an argument annotated as
# a [2, 7, 4] float32 tensor.
class CumprodStaticModule(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args(
[
None,
([2, 7, 4], torch.float32, True),
]
)
def forward(self, val):
return torch.ops.aten.cumprod(val, 1)
# Registers the `CumprodStaticModule` test case; runs `forward` on
# `tu.rand(2, 7, 4)`.
@register_test_case(module_factory=lambda: CumprodStaticModule())
def CumprodStaticModule_basic(module, tu: TestUtils):
module.forward(tu.rand(2, 7, 4))
# Test module: `aten.cumprod` with a negative dimension (`dim=-1`) on an
# argument annotated as a [2, 7, 4] float32 tensor.
class CumprodStaticNegativeDimModule(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args(
[
None,
([2, 7, 4], torch.float32, True),
]
)
def forward(self, val):
return torch.ops.aten.cumprod(val, dim=-1)
# Registers the `CumprodStaticNegativeDimModule` test case; runs `forward` on
# `tu.rand(2, 7, 4)`.
@register_test_case(module_factory=lambda: CumprodStaticNegativeDimModule())
def CumprodStaticNegativeDimModule_basic(module, tu: TestUtils):
module.forward(tu.rand(2, 7, 4))
# Test module: `aten.cumprod` along literal dim 1 on an argument annotated as
# a [2, 7, 4] int32 tensor (an integer input dtype other than int64).
class CumprodInputDtypeInt32Module(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args(
[
None,
([2, 7, 4], torch.int32, True),
]
)
def forward(self, val):
return torch.ops.aten.cumprod(val, 1)
# Registers the `CumprodInputDtypeInt32Module` test case; runs `forward` on
# `tu.randint(2, 7, 4)` converted to int32.
@register_test_case(module_factory=lambda: CumprodInputDtypeInt32Module())
def CumprodInputDtypeInt32Module_basic(module, tu: TestUtils):
module.forward(tu.randint(2, 7, 4).to(torch.int32))
# ==============================================================================
class AtenToDeviceModule(torch.nn.Module): class AtenToDeviceModule(torch.nn.Module):
def __init__(self): def __init__(self):
super().__init__() super().__init__()