From e57d3f977450d5e516b360e236e62a0ab81ef908 Mon Sep 17 00:00:00 2001 From: Gaurav Shukla Date: Fri, 25 Feb 2022 22:05:04 +0530 Subject: [PATCH] [LINALG] Fix `aten.bernoulli` op lowering - This commit adds E2E support for `aten.rand_like` and `aten.bernoulli_.Tensor` ops. - The `aten.bernoulli(x)` was implemented as: `aten.bernoulli(x) = rand_like(x) < 0.5`, assuming 0.5 as default probability, whereas according to the pytorch documentation: https://pytorch.org/docs/stable/generated/torch.bernoulli.html#torch.bernoulli the input x in `aten.bernoulli(x)` is itself a tensor containing probabilities to be used for drawing the binary random number. - So this commit fixes the `aten.bernoulli(x)` implementation as: `aten.bernoulli(x) = rand_like(x) < x`. - It also fixes the case where the input to `aten.bernoulli_.float` is an integer tensor. In this case the input must be casted to float type before passing it as operand to `aten.rand_like` op. `aten.bernoulli_.float(x, p) = rand_like(float(x)) < p`. Signed-Off-by: Gaurav Shukla --- e2e_testing/torchscript/rng.py | 133 +++++++++-- .../Dialect/Torch/IR/GeneratedAtenOps.td | 34 +++ .../torch-mlir/Dialect/Torch/IR/TorchOps.td | 18 ++ .../Torch/Transforms/DecomposeComplexOps.cpp | 226 ++++++++++++------ .../Torch/Transforms/ReduceOpVariants.cpp | 4 + lib/Dialect/Torch/Transforms/RefineTypes.cpp | 10 +- .../jit_ir/build_tools/torch_ods_gen.py | 2 + test/Dialect/Torch/decompose-complex-ops.mlir | 96 ++++++-- 8 files changed, 420 insertions(+), 103 deletions(-) diff --git a/e2e_testing/torchscript/rng.py b/e2e_testing/torchscript/rng.py index 68b280358..ca3133229 100644 --- a/e2e_testing/torchscript/rng.py +++ b/e2e_testing/torchscript/rng.py @@ -4,7 +4,6 @@ from torch_mlir_e2e_test.torchscript.framework import TestUtils from torch_mlir_e2e_test.torchscript.registry import register_test_case from torch_mlir_e2e_test.torchscript.annotations import annotate_args, export - # ============================================================================== class UniformModule(torch.nn.Module): @@ -90,17 +89,73 @@ class BernoulliModule(torch.nn.Module): @export @annotate_args([ None, - ([-1, -1, -1], torch.float32, True), + ([-1, -1, -1], torch.float64, True), + ([-1, -1, -1], torch.float64, True), + ([-1, -1, -1], torch.float64, True), ]) - def forward(self, a): - x = torch.bernoulli(a) - mean = torch.mean(x) - std = torch.std(x) + def forward(self, x, y, z): + a = torch.bernoulli(x) + b = torch.bernoulli(y) + c = torch.bernoulli(z) + mean = torch.cat([ + torch.flatten(torch.mean(a)), + torch.flatten(torch.mean(b)), + torch.flatten(torch.mean(c)) + ]) + std = torch.cat([ + torch.flatten(torch.std(a)), + torch.flatten(torch.std(b)), + torch.flatten(torch.std(c)) + ]) return mean, std + @register_test_case(module_factory=lambda: BernoulliModule()) def BernoulliModule_basic(module, tu: TestUtils): - module.forward(tu.rand(256, 512, 64)) + module.forward( + tu.rand(256, 512, 8).double(), + tu.rand(512, 1024, 4).double(), + tu.rand(512, 256, 4).double()) + +# ============================================================================== + +class BernoulliZerosModule(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([-1, -1], torch.float64, True), + ]) + def forward(self, x): + return torch.bernoulli(x) + + +@register_test_case(module_factory=lambda: BernoulliZerosModule()) +def BernoulliZerosModule_basic(module, tu: TestUtils): + module.forward(torch.zeros(4, 8).double()) + +# 
============================================================================== + +class BernoulliOnesModule(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([-1, -1], torch.float64, True), + ]) + def forward(self, x): + return torch.bernoulli(x) + + +@register_test_case(module_factory=lambda: BernoulliOnesModule()) +def BernoulliOnesModule_basic(module, tu: TestUtils): + module.forward(torch.ones(4, 8).double()) + +# ============================================================================== class BernoulliFloatModule(torch.nn.Module): def __init__(self): @@ -113,25 +168,69 @@ class BernoulliFloatModule(torch.nn.Module): ([-1, -1, -1], torch.float64, True), ([-1, -1, -1], torch.float64, True), ]) - def forward(self, a, b, c): - x = torch.ops.aten.bernoulli_(a, 0.4) - y = torch.ops.aten.bernoulli_(b, 0.7) - z = torch.ops.aten.bernoulli_(c, 0.5) + def forward(self, x, y, z): + a = torch.ops.aten.bernoulli_(x, 0.4) + b = torch.ops.aten.bernoulli_(y, 0.7) + c = torch.ops.aten.bernoulli_(z, 0.5) mean = torch.cat([ - torch.flatten(torch.mean(x)), - torch.flatten(torch.mean(y)), - torch.flatten(torch.mean(z)) + torch.flatten(torch.mean(a)), + torch.flatten(torch.mean(b)), + torch.flatten(torch.mean(c)) ]) std = torch.cat([ - torch.flatten(torch.std(x)), - torch.flatten(torch.std(y)), - torch.flatten(torch.std(z)) + torch.flatten(torch.std(a)), + torch.flatten(torch.std(b)), + torch.flatten(torch.std(c)) ]) return mean, std + @register_test_case(module_factory=lambda: BernoulliFloatModule()) def BernoulliFloatModule_basic(module, tu: TestUtils): module.forward( tu.rand(256, 512, 8).double(), tu.rand(512, 1024, 4).double(), tu.rand(512, 256, 4).double()) + +# ============================================================================== + +class BernoulliTensorModule(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([-1, -1, -1], torch.float64, True), + ([-1, -1, -1], torch.float64, True), + ([-1, -1, -1], torch.float64, True), + ([-1, -1, -1], torch.float64, True), + ([-1, -1, -1], torch.float64, True), + ([-1, -1, -1], torch.float64, True), + ]) + def forward(self, x, px, y, py, z, pz): + a = torch.ops.aten.bernoulli_(x, px) + b = torch.ops.aten.bernoulli_(y, py) + c = torch.ops.aten.bernoulli_(z, pz) + mean = torch.cat([ + torch.flatten(torch.mean(a)), + torch.flatten(torch.mean(b)), + torch.flatten(torch.mean(c)) + ]) + std = torch.cat([ + torch.flatten(torch.std(a)), + torch.flatten(torch.std(b)), + torch.flatten(torch.std(c)) + ]) + return mean, std + + +@register_test_case(module_factory=lambda: BernoulliTensorModule()) +def BernoulliTensorModule_basic(module, tu: TestUtils): + module.forward( + tu.rand(512, 512, 8).double(), + tu.rand(512, 512, 8).double(), + tu.rand(512, 1024, 4).double(), + tu.rand(512, 1024, 4).double(), + tu.rand(512, 256, 4).double(), + tu.rand(512, 256, 4).double()) diff --git a/include/torch-mlir/Dialect/Torch/IR/GeneratedAtenOps.td b/include/torch-mlir/Dialect/Torch/IR/GeneratedAtenOps.td index af808c820..5b03eaa4a 100644 --- a/include/torch-mlir/Dialect/Torch/IR/GeneratedAtenOps.td +++ b/include/torch-mlir/Dialect/Torch/IR/GeneratedAtenOps.td @@ -1502,6 +1502,25 @@ def Torch_AtenUniform_Op : Torch_Op<"aten.uniform_", [ let assemblyFormat = "$self `,` $from `,` $to `,` $generator attr-dict `:` qualified(type($self)) `,` qualified(type($from)) `,` qualified(type($to)) `,` qualified(type($generator)) `->` qualified(type($result))"; } +def Torch_AtenRandLikeOp 
: Torch_Op<"aten.rand_like", [ + AllowsTypeRefinement, + HasValueSemantics + ]> { + let summary = "Generated op for `aten::rand_like : (Tensor, int?, int?, Device?, bool?, int?) -> (Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$self, + TorchOptionalIntType:$dtype, + TorchOptionalIntType:$layout, + TorchOptionalDeviceType:$device, + TorchOptionalBoolType:$pin_memory, + TorchOptionalIntType:$memory_format + ); + let results = (outs + AnyTorchTensorType:$result + ); + let assemblyFormat = "$self `,` $dtype `,` $layout `,` $device `,` $pin_memory `,` $memory_format attr-dict `:` qualified(type($self)) `,` qualified(type($dtype)) `,` qualified(type($layout)) `,` qualified(type($device)) `,` qualified(type($pin_memory)) `,` qualified(type($memory_format)) `->` qualified(type($result))"; +} + def Torch_AtenBernoulliOp : Torch_Op<"aten.bernoulli", [ AllowsTypeRefinement, HasValueSemantics @@ -1532,6 +1551,21 @@ def Torch_AtenBernoulli_FloatOp : Torch_Op<"aten.bernoulli_.float", [ let assemblyFormat = "$self `,` $p `,` $generator attr-dict `:` qualified(type($self)) `,` qualified(type($p)) `,` qualified(type($generator)) `->` qualified(type($result))"; } +def Torch_AtenBernoulli_TensorOp : Torch_Op<"aten.bernoulli_.Tensor", [ + AllowsTypeRefinement + ]> { + let summary = "Generated op for `aten::bernoulli_.Tensor : (Tensor, Tensor, Generator?) -> (Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$self, + AnyTorchTensorType:$p, + TorchOptionalGeneratorType:$generator + ); + let results = (outs + AnyTorchTensorType:$result + ); + let assemblyFormat = "$self `,` $p `,` $generator attr-dict `:` qualified(type($self)) `,` qualified(type($p)) `,` qualified(type($generator)) `->` qualified(type($result))"; +} + def Torch_AtenTriuOp : Torch_Op<"aten.triu", [ AllowsTypeRefinement, HasValueSemantics diff --git a/include/torch-mlir/Dialect/Torch/IR/TorchOps.td b/include/torch-mlir/Dialect/Torch/IR/TorchOps.td index 1ce6d2b86..fa37b55c2 100644 --- a/include/torch-mlir/Dialect/Torch/IR/TorchOps.td +++ b/include/torch-mlir/Dialect/Torch/IR/TorchOps.td @@ -953,6 +953,24 @@ def Torch_PseudoAtenBernoulliFloatOp: Torch_Op<"pseudo.aten.bernoulli.float", [ let assemblyFormat = "$self `,` $p `,` $generator attr-dict `:` type($self) `,` type($p) `,` type($generator) `->` type($result)"; } +// The corresponding without underscore variant for `torch.aten.bernoulli_.Tensor` +// doesn't exist in the pytorch ops registry. Add it here. +def Torch_PseudoAtenBernoulliTensorOp: Torch_Op<"pseudo.aten.bernoulli.Tensor", [ + AllowsTypeRefinement, + HasValueSemantics, + ]> { + let summary = "Generated op for `aten::bernoulli_.Tensor : (Tensor, Tensor, Generator?) -> (Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$self, + AnyTorchTensorType:$p, + TorchOptionalGeneratorType:$generator + ); + let results = (outs + AnyTorchTensorType:$result + ); + let assemblyFormat = "$self `,` $p `,` $generator attr-dict `:` qualified(type($self)) `,` qualified(type($p)) `,` qualified(type($generator)) `->` qualified(type($result))"; +} + // The corresponding without underscore variant for `torch.aten.fill_.Scalar` // doesn't exist in the pytorch ops registry. Add it here. 
def Torch_PseudoAtenFillScalarOp: Torch_Op<"pseudo.aten.fill.Scalar", [ diff --git a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp index b4fa7d607..41cea1210 100644 --- a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp +++ b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp @@ -131,8 +131,8 @@ static Value convertTensorToDtype(PatternRewriter &rewriter, Location loc, Value input, Type dtype) { BaseTensorType origType = input.getType().cast(); Type newType = origType.getWithSizesAndDtype(origType.getSizes(), dtype); - // `convertIntVal` contains the corresponding integer for the dtype which is used - // by the aten.to.dtype op. + // `convertIntVal` contains the corresponding integer for the dtype which is + // used by the aten.to.dtype op. Value convertIntVal = getDtypeIntValueForType(rewriter, loc, dtype); Value falseVal = rewriter.create(loc, false); Value noneVal = rewriter.create(loc); @@ -141,22 +141,21 @@ static Value convertTensorToDtype(PatternRewriter &rewriter, Location loc, return converted; } -// Helper to create a tensor filled with the given scalar. Scalar would be -// converted the to the element type of the given tensor type. +// Helper to create a tensor filled with the given `scalar`. `scalar` would be +// converted to the element type of the given `resultType`. static Value createInitTensor(PatternRewriter &rewriter, Location loc, Type resultType, Value scalar, Value sizeList) { BaseTensorType tensorType = resultType.cast(); Value noneVal = rewriter.create(loc); Value emptyTensor = rewriter.create( loc, tensorType, sizeList, /*dtype=*/noneVal, /*layout=*/noneVal, - /*device=*/noneVal, - /*pin_memory=*/noneVal, /*memory_format=*/noneVal); + /*device=*/noneVal, /*pin_memory=*/noneVal, /*memory_format=*/noneVal); return rewriter.create(loc, resultType, emptyTensor, scalar); } -// Helper to create a rank0 tensor filled with the given scalar. Scalar would be -// converted the to the element type of the given tensor type. +// Helper to create a rank 0 tensor filled with the given `scalar`. `scalar` +// would be converted to the element type of the given `inputType`. static Value createRank0Tensor(PatternRewriter &rewriter, Location loc, BaseTensorType inputType, Value scalar) { SmallVector sizes; @@ -930,66 +929,125 @@ public: }; } // namespace -// Returns a tensor with bernoulli(p) distribution. -// Decompose aten.bernoulli(x, p) to aten.gtTensor(aten.uniform(x), p). -static LogicalResult decomposeBernoulliLikeOp(PatternRewriter &rewriter, - Operation *op, Location loc, - Value input, double p, - Value &result) { - BaseTensorType inputType = input.getType().cast(); - if (!inputType.hasSizes() || !inputType.hasDtype()) { - return rewriter.notifyMatchFailure( - op, "Can't decomposeBernoulliLikeOp without sizes or dtype"); - } - BaseTensorType boolType = - inputType - .getWithSizesAndDtype( - inputType.getSizes(), - IntegerType::get(op->getContext(), 1, IntegerType::Signless)) - .cast(); - Value prob = - rewriter.create(loc, rewriter.getF64FloatAttr(p)); - Value lb = - rewriter.create(loc, rewriter.getF64FloatAttr(0.0)); - Value ub = - rewriter.create(loc, rewriter.getF64FloatAttr(1.0)); - - Value noneVal = rewriter.create(loc); - // Create a uniform random op with low and high set to lb and ub respectively. 
- Value uniformRandom = rewriter.create( - loc, inputType, input, lb, ub, noneVal); - Value gtValue = - rewriter.create(loc, boolType, uniformRandom, prob); - // Since `gtValue` will be a boolean tensor convert it back to the original - // type. - result = convertTensorToDtype(rewriter, loc, gtValue, inputType.getDtype()); - return success(); -} - namespace { -class DecomposeAtenBernoulliOp : public OpRewritePattern { +class DecomposeAtenRandLikeOp : public OpRewritePattern { public: using OpRewritePattern::OpRewritePattern; - LogicalResult matchAndRewrite(AtenBernoulliOp op, + LogicalResult matchAndRewrite(AtenRandLikeOp op, PatternRewriter &rewriter) const override { Location loc = op.getLoc(); - Value self = op.self(); - Value generator = op.generator(); - if (!generator.getType().isa()) + Value input = op.self(); + auto inputType = input.getType().cast(); + if (!inputType.hasDtype() || !inputType.getDtype().isa()) + return rewriter.notifyMatchFailure(op, + "only support floating-point type"); + + // TODO: Add support for layout, pin_memory and memory_format features. + // Only `none` layout is supported. + if (!op.layout().getType().isa()) return rewriter.notifyMatchFailure( - op, "The generator has to ben None because only global default " - "generator is supported"); - Value result; - if (failed(decomposeBernoulliLikeOp(rewriter, op, loc, self, /*p=*/0.5, - result))) - return failure(); - rewriter.replaceOp(op, result); + op, "unimplemented: only default layout is supported"); + + // The pin_memory should be either `none` or constant `False`. + if (!op.pin_memory().getType().isa()) { + bool pinMemory; + if (!matchPattern(op.pin_memory(), m_TorchConstantBool(&pinMemory))) + return rewriter.notifyMatchFailure( + op, "unimplemented: pin_memory must be a constant"); + else if (pinMemory) + return rewriter.notifyMatchFailure( + op, "unimplemented: pin_memory is expected to be false"); + } + + // Only `none` memory_format is supported. + if (!op.memory_format().getType().isa()) + return rewriter.notifyMatchFailure( + op, "unimplemented: only default memory format is supported"); + + // Create a uniform random op with low and high set to 0.0 and 1.0 + // respectively. + Value none = rewriter.create(loc); + Value lb = + rewriter.create(loc, rewriter.getF64FloatAttr(0.0)); + Value ub = + rewriter.create(loc, rewriter.getF64FloatAttr(1.0)); + rewriter.replaceOpWithNewOp( + op, op.getType(), input, lb, ub, /*generator=*/none); return success(); } }; } // namespace namespace { +// Bernoulli(x, p) = (rand_like(float(x)) < p).cast(type(x)). Here, +// 1. p must be a float tensor. +// 2. The shape of p should be broadcastable to the shape of x. +// 3. Bernoulli(x, p) returns a tensor of the same type as that of x. +static LogicalResult decomposeBernoulliLikeOp(PatternRewriter &rewriter, + Operation *op, Location loc, + Value input, Value prob, + Value &output) { + auto inputType = input.getType().cast(); + auto probType = prob.getType().cast(); + // Both the `input` and `prob` must be ranked tensors. + if (!inputType.hasSizes() || !inputType.hasDtype() || !probType.hasSizes() || + !probType.hasDtype()) { + return rewriter.notifyMatchFailure( + op, "can't decompose bernoulli like ops without sizes or dtype"); + } + // The `prob` is expected to be a float type tensor. 
+ if (!probType.getDtype().isa()) { + return rewriter.notifyMatchFailure( + op, "probabilities must be a float type tensor"); + } + + // Since the `aten.rand_like` op expects float-type operand, create a + // float-type tensor with the same shape as that of the `input`. + Value floatTensor = + convertTensorToDtype(rewriter, loc, input, rewriter.getF64Type()); + Value none = rewriter.create(loc); + Value randomVal = rewriter.create( + loc, floatTensor.getType(), floatTensor, /*dtype=*/none, /*layout=*/none, + /*device=*/none, /*pin_memory=*/none, /*memory_format=*/none); + + // Bernoulli(x, p) = rand_like(float(x)) < p. + auto boolResType = inputType.getWithSizesAndDtype(inputType.getSizes(), + rewriter.getI1Type()); + Value lessThanP = + rewriter.create(loc, boolResType, randomVal, prob); + + // As the `output` is expected to be of the `input` type, convert the boolean + // tensor `lessThanP` to a `input` type tensor. + output = convertTensorToDtype(rewriter, loc, lessThanP, inputType.getDtype()); + return success(); +} + +// aten.bernoulli(x) = rand_like(x) < x. Here, the input x is a tensor +// containing probabilities to be used for drawing the binary random number. +class DecomposeAtenBernoulliOp : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(AtenBernoulliOp op, + PatternRewriter &rewriter) const override { + Location loc = op.getLoc(); + Value input = op.self(); + if (!op.generator().getType().isa()) + return rewriter.notifyMatchFailure( + op, "The generator has to ben None because only global default " + "generator is supported"); + Value output; + if (failed( + decomposeBernoulliLikeOp(rewriter, op, loc, input, input, output))) + return rewriter.notifyMatchFailure( + op, "decomposeBernoulliLikeOp failed to decompose the op"); + rewriter.replaceOp(op, output); + return success(); + } +}; + +// aten.bernoulli.float(x, p) = (rand_like(float(x)) < tensor(p)).cast(type(x)). +// Since the input x can be an integer tensor, it's important to cast it to +// float type before passing it to the `aten.rand_like` op. class DecomposePseudoAtenBernoulliFloatOp : public OpRewritePattern { public: @@ -997,20 +1055,50 @@ public: LogicalResult matchAndRewrite(PseudoAtenBernoulliFloatOp op, PatternRewriter &rewriter) const override { Location loc = op.getLoc(); - Value self = op.self(); - Value generator = op.generator(); - double p; - if (!matchPattern(op.p(), m_TorchConstantFloat(&p))) - return rewriter.notifyMatchFailure(op, "p should be constant float"); - - if (!generator.getType().isa()) + Value input = op.self(); + Value p = op.p(); + if (!op.generator().getType().isa()) return rewriter.notifyMatchFailure( op, "The generator has to ben None because only global default " "generator is supported"); - Value result; - if (failed(decomposeBernoulliLikeOp(rewriter, op, loc, self, p, result))) - return failure(); - rewriter.replaceOp(op, result); + + auto inputType = input.getType().cast(); + SmallVector empty; + Type tensorType = inputType.getWithSizesAndDtype(llvm::makeArrayRef(empty), + rewriter.getF64Type()); + Value prob = rewriter.create(loc, tensorType, p); + Value output; + if (failed( + decomposeBernoulliLikeOp(rewriter, op, loc, input, prob, output))) + return rewriter.notifyMatchFailure( + op, "decomposeBernoulliLikeOp failed to decompose the op"); + rewriter.replaceOp(op, output); + return success(); + } +}; + +// aten.bernoulli.Tensor(x, p) = (rand_like(float(x)) < p).cast(type(x)). 
+// Since the input x can be an integer tensor, it's important to cast it to +// float type before passing it to the `aten.rand_like` op. +class DecomposePseudoAtenBernoulliTensorOp + : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(PseudoAtenBernoulliTensorOp op, + PatternRewriter &rewriter) const override { + Location loc = op.getLoc(); + Value input = op.self(); + Value prob = op.p(); + if (!op.generator().getType().isa()) + return rewriter.notifyMatchFailure( + op, "The generator has to ben None because only global default " + "generator is supported"); + Value output; + if (failed( + decomposeBernoulliLikeOp(rewriter, op, loc, input, prob, output))) + return rewriter.notifyMatchFailure( + op, "decomposeBernoulliLikeOp failed to decompose the op"); + rewriter.replaceOp(op, output); return success(); } }; @@ -1425,6 +1513,10 @@ class DecomposeComplexOpsPass target.addIllegalOp(); patterns.add(context); target.addIllegalOp(); + patterns.add(context); + target.addIllegalOp(); + patterns.add(context); + target.addIllegalOp(); patterns.add(context); target.addIllegalOp(); patterns.add(context); diff --git a/lib/Dialect/Torch/Transforms/ReduceOpVariants.cpp b/lib/Dialect/Torch/Transforms/ReduceOpVariants.cpp index 79fbc7ad0..d3c5cc6ca 100644 --- a/lib/Dialect/Torch/Transforms/ReduceOpVariants.cpp +++ b/lib/Dialect/Torch/Transforms/ReduceOpVariants.cpp @@ -152,6 +152,9 @@ public: } else if (isa(op)) { newOp = rewriter.create( loc, op->getResultTypes(), op->getOperands()); + } else if (isa(op)) { + newOp = rewriter.create( + loc, op->getResultTypes(), op->getOperands()); } else if (isa(op)) { newOp = rewriter.create(loc, op->getResultTypes(), op->getOperands()); @@ -232,6 +235,7 @@ class ReduceOpVariantsPass : public ReduceOpVariantsBase { target.addIllegalOp(); target.addIllegalOp(); target.addIllegalOp(); + target.addIllegalOp(); target.addIllegalOp(); target.markUnknownOpDynamicallyLegal([](Operation *op) { if (op->hasTrait()) { diff --git a/lib/Dialect/Torch/Transforms/RefineTypes.cpp b/lib/Dialect/Torch/Transforms/RefineTypes.cpp index 8995c0fa7..11958813a 100644 --- a/lib/Dialect/Torch/Transforms/RefineTypes.cpp +++ b/lib/Dialect/Torch/Transforms/RefineTypes.cpp @@ -229,10 +229,10 @@ public: AtenLog2Op, Aten_SoftmaxBackwardDataOp, AtenRsqrtOp, AtenDropoutOp, AtenTanhBackwardOp, Aten_LogSoftmaxBackwardDataOp, AtenAddIntOp, AtenAbsOp, AtenThresholdOp, AtenSquareOp, PseudoAtenUniformOp, - AtenCloneOp, AtenBernoulliOp, AtenBernoulli_FloatOp, - PseudoAtenBernoulliFloatOp, PseudoAtenFillScalarOp, - AtenHardsigmoidOp, AtenHardswishOp, AtenSiluOp, AtenHardtanhOp>( - op)) { + AtenBernoulliOp, AtenBernoulli_FloatOp, AtenBernoulli_TensorOp, + PseudoAtenBernoulliFloatOp, PseudoAtenBernoulliTensorOp, + PseudoAtenFillScalarOp, AtenHardsigmoidOp, AtenCloneOp, + AtenHardswishOp, AtenSiluOp, AtenHardtanhOp>(op)) { return getLatticeElement(op->getResult(0)).join(*operands[0]); } @@ -399,6 +399,8 @@ public: return visitConstantTensorNewLikeOp(newZeros, operands); } else if (auto newOnes = dyn_cast(op)) { return visitConstantTensorNewLikeOp(newOnes, operands); + } else if (auto randLike = dyn_cast(op)) { + return visitConstantTensorAllocLikeOp(randLike, operands); } else if (auto toDtype = dyn_cast(op)) { return visitAtenToDtypeOp(toDtype, operands); } else if (auto toOther = dyn_cast(op)) { diff --git a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py 
b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py index 22ec093a3..41f299375 100644 --- a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py +++ b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py @@ -508,8 +508,10 @@ def emit_aten_ops(torch_ir_dir: str, registry: Registry): # underscore variant doesn't exist. emit("aten::fill_.Scalar : (Tensor, Scalar) -> (Tensor)") emit("aten::uniform_ : (Tensor, float, float, Generator?) -> (Tensor)") + emit("aten::rand_like : (Tensor, int?, int?, Device?, bool?, int?) -> (Tensor)") emit("aten::bernoulli : (Tensor, Generator?) -> (Tensor)") emit("aten::bernoulli_.float : (Tensor, float, Generator?) -> (Tensor)") + emit("aten::bernoulli_.Tensor : (Tensor, Tensor, Generator?) -> (Tensor)") emit_with_mutating_variants("aten::triu : (Tensor, int) -> (Tensor)") emit_with_mutating_variants("aten::index_put : (Tensor, Tensor?[], Tensor, bool) -> (Tensor)") diff --git a/test/Dialect/Torch/decompose-complex-ops.mlir b/test/Dialect/Torch/decompose-complex-ops.mlir index 334a22f2d..8588133b8 100644 --- a/test/Dialect/Torch/decompose-complex-ops.mlir +++ b/test/Dialect/Torch/decompose-complex-ops.mlir @@ -377,26 +377,92 @@ func @torch.aten._log_softmax(%arg0: !torch.vtensor<[?,?,?],f32> loc(unknown)) - // ----- // CHECK-LABEL: func @torch.aten.bernoulli -// CHECK-SAME: (%[[INP:.*]]: !torch.vtensor<[?,?,?],f32>) -> !torch.vtensor { +// CHECK-SAME: (%[[INP:.*]]: !torch.vtensor<[?,?,?],f64>) -> !torch.vtensor { // CHECK: %[[NONE:.*]] = torch.constant.none -// CHECK: %[[FLOAT0_5:.*]] = torch.constant.float 5.000000e-01 +// CHECK: %[[INT7:.*]] = torch.constant.int 7 +// CHECK: %[[FALSE:.*]] = torch.constant.bool false +// CHECK: %[[NONE_0:.*]] = torch.constant.none +// CHECK: %[[CON2FLOAT:.*]] = torch.aten.to.dtype %[[INP]], %[[INT7]], %[[FALSE]], %[[FALSE]], %[[NONE_0]] : +// CHECK-SAME: !torch.vtensor<[?,?,?],f64>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[?,?,?],f64> +// CHECK: %[[NONE_1:.*]] = torch.constant.none +// CHECK: %[[NONE_2:.*]] = torch.constant.none // CHECK: %[[FLOAT0:.*]] = torch.constant.float 0.000000e+00 // CHECK: %[[FLOAT1:.*]] = torch.constant.float 1.000000e+00 -// CHECK: %[[NONE0:.*]] = torch.constant.none -// CHECK: %[[UNF:.*]] = torch.pseudo.aten.uniform %[[INP]], %[[FLOAT0]], %[[FLOAT1]], %[[NONE0]] : -// CHECK-SAME: !torch.vtensor<[?,?,?],f32>, !torch.float, !torch.float, !torch.none -> !torch.vtensor<[?,?,?],f32> -// CHECK: %[[GT:.*]] = torch.aten.lt.Scalar %[[UNF]], %[[FLOAT0_5]] : !torch.vtensor<[?,?,?],f32>, !torch.float -> !torch.vtensor<[?,?,?],i1> -// CHECK: %[[INT6:.*]] = torch.constant.int 6 -// CHECK: %[[FALSE:.*]] = torch.constant.bool false -// CHECK: %[[NONE1:.*]] = torch.constant.none -// CHECK: %[[TODTYPE:.*]] = torch.aten.to.dtype %[[GT]], %[[INT6]], %[[FALSE]], %[[FALSE]], %[[NONE1]] : -// CHECK-SAME: !torch.vtensor<[?,?,?],i1>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[?,?,?],f32> -// CHECK: %[[CAST:.*]] = torch.tensor_static_info_cast %[[TODTYPE]] : !torch.vtensor<[?,?,?],f32> to !torch.vtensor +// CHECK: %[[UNF:.*]] = torch.pseudo.aten.uniform %[[CON2FLOAT]], %[[FLOAT0]], %[[FLOAT1]], %[[NONE_2]] : !torch.vtensor<[?,?,?],f64>, !torch.float, !torch.float, !torch.none -> !torch.vtensor<[?,?,?],f64> +// CHECK: %[[CMP:.*]] = torch.aten.lt.Tensor %[[UNF]], %[[INP]] : !torch.vtensor<[?,?,?],f64>, !torch.vtensor<[?,?,?],f64> -> !torch.vtensor<[?,?,?],i1> +// CHECK: %[[INT7_2:.*]] = 
torch.constant.int 7 +// CHECK: %[[FALSE_2:.*]] = torch.constant.bool false +// CHECK: %[[NONE_3:.*]] = torch.constant.none +// CHECK: %[[TODTYPE:.*]] = torch.aten.to.dtype %[[CMP]], %[[INT7_2]], %[[FALSE_2]], %[[FALSE_2]], %[[NONE_3]] : +// CHECK-SAME: !torch.vtensor<[?,?,?],i1>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[?,?,?],f64> +// CHECK: %[[CAST:.*]] = torch.tensor_static_info_cast %[[TODTYPE]] : !torch.vtensor<[?,?,?],f64> to !torch.vtensor // CHECK: return %[[CAST]] : !torch.vtensor -func @torch.aten.bernoulli(%arg0: !torch.vtensor<[?,?,?],f32>) -> !torch.vtensor { +func @torch.aten.bernoulli(%arg0: !torch.vtensor<[?,?,?],f64>) -> !torch.vtensor { %none = torch.constant.none - %0 = torch.aten.bernoulli %arg0, %none : !torch.vtensor<[?,?,?],f32>, !torch.none -> !torch.vtensor<[?,?,?],f32> - %1 = torch.tensor_static_info_cast %0 : !torch.vtensor<[?,?,?],f32> to !torch.vtensor + %0 = torch.aten.bernoulli %arg0, %none : !torch.vtensor<[?,?,?],f64>, !torch.none -> !torch.vtensor<[?,?,?],f64> + %1 = torch.tensor_static_info_cast %0 : !torch.vtensor<[?,?,?],f64> to !torch.vtensor + return %1 : !torch.vtensor +} + +// ----- +// CHECK-LABEL: func @torch.pseudo.aten.bernoulli.float +// CHECK-SAME: (%[[INP:.*]]: !torch.vtensor<[?,?,?],f64>) -> !torch.vtensor { +// CHECK: %[[NONE:.*]] = torch.constant.none +// CHECK: %[[PROB:.*]] = torch.constant.float 4.000000e-01 +// CHECK: %[[PROB_TENSOR:.*]] = torch.prim.NumToTensor.Scalar %[[PROB]] : !torch.float -> !torch.vtensor<[],f64> +// CHECK: %[[INT7:.*]] = torch.constant.int 7 +// CHECK: %[[FALSE:.*]] = torch.constant.bool false +// CHECK: %[[NONE_0:.*]] = torch.constant.none +// CHECK: %[[CON2FLOAT:.*]] = torch.aten.to.dtype %[[INP]], %[[INT7]], %[[FALSE]], %[[FALSE]], %[[NONE_0]] : +// CHECK-SAME: !torch.vtensor<[?,?,?],f64>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[?,?,?],f64> +// CHECK: %[[NONE_1:.*]] = torch.constant.none +// CHECK: %[[NONE_2:.*]] = torch.constant.none +// CHECK: %[[FLOAT0:.*]] = torch.constant.float 0.000000e+00 +// CHECK: %[[FLOAT1:.*]] = torch.constant.float 1.000000e+00 +// CHECK: %[[UNF:.*]] = torch.pseudo.aten.uniform %[[CON2FLOAT]], %[[FLOAT0]], %[[FLOAT1]], %[[NONE_2]] : !torch.vtensor<[?,?,?],f64>, !torch.float, !torch.float, !torch.none -> !torch.vtensor<[?,?,?],f64> +// CHECK: %[[CMP:.*]] = torch.aten.lt.Tensor %[[UNF]], %[[PROB_TENSOR]] : !torch.vtensor<[?,?,?],f64>, !torch.vtensor<[],f64> -> !torch.vtensor<[?,?,?],i1> +// CHECK: %[[INT7_2:.*]] = torch.constant.int 7 +// CHECK: %[[FALSE_2:.*]] = torch.constant.bool false +// CHECK: %[[NONE_3:.*]] = torch.constant.none +// CHECK: %[[TODTYPE:.*]] = torch.aten.to.dtype %[[CMP]], %[[INT7_2]], %[[FALSE_2]], %[[FALSE_2]], %[[NONE_3]] : +// CHECK-SAME: !torch.vtensor<[?,?,?],i1>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[?,?,?],f64> +// CHECK: %[[CAST:.*]] = torch.tensor_static_info_cast %[[TODTYPE]] : !torch.vtensor<[?,?,?],f64> to !torch.vtensor +// CHECK: return %[[CAST]] : !torch.vtensor +func @torch.pseudo.aten.bernoulli.float(%arg0: !torch.vtensor<[?,?,?],f64>) -> !torch.vtensor { + %none = torch.constant.none + %prob = torch.constant.float 4.000000e-01 + %0 = torch.pseudo.aten.bernoulli.float %arg0, %prob, %none : !torch.vtensor<[?,?,?],f64>, !torch.float, !torch.none -> !torch.vtensor<[?,?,?],f64> + %1 = torch.tensor_static_info_cast %0 : !torch.vtensor<[?,?,?],f64> to !torch.vtensor + return %1 : !torch.vtensor +} + +// ----- +// CHECK-LABEL: func 
@torch.pseudo.aten.bernoulli.Tensor( +// CHECK-SAME: %[[INP:.*]]: !torch.vtensor<[?,?,?],f64>, +// CHECK-SAME: %[[PROB:.*]]: !torch.vtensor<[?,?,?],f64>) -> !torch.vtensor { +// CHECK: %[[NONE:.*]] = torch.constant.none +// CHECK: %[[INT7:.*]] = torch.constant.int 7 +// CHECK: %[[FALSE:.*]] = torch.constant.bool false +// CHECK: %[[NONE_0:.*]] = torch.constant.none +// CHECK: %[[CON2FLOAT:.*]] = torch.aten.to.dtype %[[INP]], %[[INT7]], %[[FALSE]], %[[FALSE]], %[[NONE_0]] : +// CHECK-SAME: !torch.vtensor<[?,?,?],f64>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[?,?,?],f64> +// CHECK: %[[NONE_1:.*]] = torch.constant.none +// CHECK: %[[NONE_2:.*]] = torch.constant.none +// CHECK: %[[FLOAT0:.*]] = torch.constant.float 0.000000e+00 +// CHECK: %[[FLOAT1:.*]] = torch.constant.float 1.000000e+00 +// CHECK: %[[UNF:.*]] = torch.pseudo.aten.uniform %[[CON2FLOAT]], %[[FLOAT0]], %[[FLOAT1]], %[[NONE_2]] : !torch.vtensor<[?,?,?],f64>, !torch.float, !torch.float, !torch.none -> !torch.vtensor<[?,?,?],f64> +// CHECK: %[[CMP:.*]] = torch.aten.lt.Tensor %[[UNF]], %[[PROB]] : !torch.vtensor<[?,?,?],f64>, !torch.vtensor<[?,?,?],f64> -> !torch.vtensor<[?,?,?],i1> +// CHECK: %[[INT7_2:.*]] = torch.constant.int 7 +// CHECK: %[[FALSE_2:.*]] = torch.constant.bool false +// CHECK: %[[NONE_3:.*]] = torch.constant.none +// CHECK: %[[TODTYPE:.*]] = torch.aten.to.dtype %[[CMP]], %[[INT7_2]], %[[FALSE_2]], %[[FALSE_2]], %[[NONE_3]] : +// CHECK-SAME: !torch.vtensor<[?,?,?],i1>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[?,?,?],f64> +// CHECK: %[[CAST:.*]] = torch.tensor_static_info_cast %[[TODTYPE]] : !torch.vtensor<[?,?,?],f64> to !torch.vtensor +// CHECK: return %[[CAST]] : !torch.vtensor +func @torch.pseudo.aten.bernoulli.Tensor(%arg0: !torch.vtensor<[?,?,?],f64>, %arg1: !torch.vtensor<[?,?,?],f64>) -> !torch.vtensor { + %none = torch.constant.none + %0 = torch.pseudo.aten.bernoulli.Tensor %arg0, %arg1, %none : !torch.vtensor<[?,?,?],f64>, !torch.vtensor<[?,?,?],f64>, !torch.none -> !torch.vtensor<[?,?,?],f64> + %1 = torch.tensor_static_info_cast %0 : !torch.vtensor<[?,?,?],f64> to !torch.vtensor return %1 : !torch.vtensor }
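
Editor's note: for readers following the rewrite patterns above, all three bernoulli variants reduce to the same shape: draw a uniform sample, compare it against a probability, and cast the boolean result back to the input dtype. The eager-mode PyTorch sketch below mirrors that decomposition for reference only; it is not part of this patch, the helper names are made up for illustration, and it assumes the global default generator (the only case the patterns accept).

import torch

# Illustrative eager-mode mirror of the decompositions in this patch
# (sketch only, not part of the test suite; helper names are hypothetical).

def bernoulli_decomposed(x: torch.Tensor) -> torch.Tensor:
    # aten.bernoulli(x) = rand_like(x) < x, where x itself holds the
    # per-element probabilities.
    return (torch.rand_like(x.to(torch.float64)) < x).to(x.dtype)

def bernoulli_float_decomposed(x: torch.Tensor, p: float) -> torch.Tensor:
    # aten.bernoulli_.float(x, p) = rand_like(float(x)) < p; casting to
    # float first keeps integer inputs valid operands for rand_like.
    return (torch.rand_like(x.to(torch.float64)) < p).to(x.dtype)

def bernoulli_tensor_decomposed(x: torch.Tensor, p: torch.Tensor) -> torch.Tensor:
    # aten.bernoulli_.Tensor(x, p) = rand_like(float(x)) < p, with p a
    # float tensor broadcastable to the shape of x.
    return (torch.rand_like(x.to(torch.float64)) < p).to(x.dtype)

if __name__ == "__main__":
    probs = torch.full((256, 512), 0.3, dtype=torch.float64)
    sample = bernoulli_decomposed(probs)
    print(sample.mean().item())  # expected to be close to 0.3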
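
Editor's note: the new e2e tests validate these lowerings statistically rather than element-wise, which is why they compare mean and std of large samples. For Bernoulli(p), the mean is p and the standard deviation is sqrt(p * (1 - p)), so on tensors of this size the sample moments are a stable signal. A minimal stand-alone version of that idea follows; the tolerances and sizes are illustrative, and the real harness compares against eager-mode PyTorch via TestUtils instead of closed-form values.

import torch

# Stand-alone sketch of the statistical check behind the RNG e2e tests
# (assumed values and tolerances; not the actual test harness).
p = 0.4
x = torch.zeros(512, 1024, dtype=torch.float64)
sample = torch.ops.aten.bernoulli_(x, p)  # same overload BernoulliFloatModule exercises
expected_std = (p * (1 - p)) ** 0.5
assert abs(sample.mean().item() - p) < 1e-2
assert abs(sample.std().item() - expected_std) < 1e-2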