Improve "pseudo" op terminology.

The term "pseudo" is very vague and was getting confusing (I felt I had
to explain it in every comment referencing it). Instead, rework the
"pseudo" ops to instead be named:

- MLIR Syntax: `torch.valsem.*`
- C++ / ODS: `ValsemVariant*Op`

This makes it clear what the concept is, and avoids confusion with other
things that might be called "pseudo", since these are very specific and
should be 100% consistently named w.r.t. the non-valsem-variant ops that
they correspond to.
pull/671/head
Sean Silva 2022-03-15 23:57:33 +00:00
parent 7ea50a537a
commit 92da4988f0
13 changed files with 79 additions and 75 deletions

View File

@ -944,7 +944,7 @@ def Torch_OverwriteTensorContentsOp : Torch_Op<"overwrite.tensor.contents", [
// The corresponding without underscore variant for `torch.aten.uniform_` // The corresponding without underscore variant for `torch.aten.uniform_`
// doesn't exist in the pytorch ops registry. Add it here. // doesn't exist in the pytorch ops registry. Add it here.
def Torch_PseudoAtenUniformOp: Torch_Op<"pseudo.aten.uniform", [ def Torch_ValsemVariantAtenUniformOp: Torch_Op<"valsem.aten.uniform", [
AllowsTypeRefinement, AllowsTypeRefinement,
HasValueSemantics, HasValueSemantics,
ReadOnly ReadOnly
@ -964,7 +964,7 @@ def Torch_PseudoAtenUniformOp: Torch_Op<"pseudo.aten.uniform", [
// The corresponding without underscore variant for `torch.aten.bernoulli_.float` // The corresponding without underscore variant for `torch.aten.bernoulli_.float`
// doesn't exist in the pytorch ops registry. Add it here. // doesn't exist in the pytorch ops registry. Add it here.
def Torch_PseudoAtenBernoulliFloatOp: Torch_Op<"pseudo.aten.bernoulli.float", [ def Torch_ValsemVariantAtenBernoulliFloatOp: Torch_Op<"valsem.aten.bernoulli.float", [
AllowsTypeRefinement, AllowsTypeRefinement,
HasValueSemantics, HasValueSemantics,
ReadOnly ReadOnly
@ -983,7 +983,7 @@ def Torch_PseudoAtenBernoulliFloatOp: Torch_Op<"pseudo.aten.bernoulli.float", [
// The corresponding without underscore variant for `torch.aten.bernoulli_.Tensor` // The corresponding without underscore variant for `torch.aten.bernoulli_.Tensor`
// doesn't exist in the pytorch ops registry. Add it here. // doesn't exist in the pytorch ops registry. Add it here.
def Torch_PseudoAtenBernoulliTensorOp: Torch_Op<"pseudo.aten.bernoulli.Tensor", [ def Torch_ValsemVariantAtenBernoulliTensorOp: Torch_Op<"valsem.aten.bernoulli.Tensor", [
AllowsTypeRefinement, AllowsTypeRefinement,
HasValueSemantics, HasValueSemantics,
ReadOnly ReadOnly
@ -1002,7 +1002,7 @@ def Torch_PseudoAtenBernoulliTensorOp: Torch_Op<"pseudo.aten.bernoulli.Tensor",
// The corresponding without underscore variant for `torch.aten.fill_.Scalar` // The corresponding without underscore variant for `torch.aten.fill_.Scalar`
// doesn't exist in the pytorch ops registry. Add it here. // doesn't exist in the pytorch ops registry. Add it here.
def Torch_PseudoAtenFillScalarOp: Torch_Op<"pseudo.aten.fill.Scalar", [ def Torch_ValsemVariantAtenFillScalarOp: Torch_Op<"valsem.aten.fill.Scalar", [
AllowsTypeRefinement, AllowsTypeRefinement,
HasValueSemantics, HasValueSemantics,
ReadOnly ReadOnly

View File

@ -250,7 +250,7 @@ def Torch_DeviceType : Torch_Type<"Device", "Device"> {
} }
def Torch_GeneratorType : Torch_Type<"Generator", "Generator"> { def Torch_GeneratorType : Torch_Type<"Generator", "Generator"> {
let summary = "Torch Generator for producing pseudo-random numbers"; let summary = "Torch Generator for producing valsem-random numbers";
} }
def Torch_BoolType : Torch_Type<"Bool", "bool"> { def Torch_BoolType : Torch_Type<"Bool", "bool"> {

View File

@ -59,12 +59,12 @@ public:
namespace { namespace {
class ConvertPseudoAtenUniformOp class ConvertValsemVariantAtenUniformOp
: public OpConversionPattern<PseudoAtenUniformOp> { : public OpConversionPattern<ValsemVariantAtenUniformOp> {
public: public:
using OpConversionPattern::OpConversionPattern; using OpConversionPattern::OpConversionPattern;
LogicalResult LogicalResult
matchAndRewrite(PseudoAtenUniformOp op, OpAdaptor adaptor, matchAndRewrite(ValsemVariantAtenUniformOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override { ConversionPatternRewriter &rewriter) const override {
if (failed(verifyLinalgCompatibleTypes(op, rewriter))) if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
return failure(); return failure();
@ -162,6 +162,6 @@ void mlir::torch::torch_to_linalg::populateRandomPatternsAndLegality(
MLIRContext *context = patterns.getContext(); MLIRContext *context = patterns.getContext();
target.addIllegalOp<AtenDropoutOp>(); target.addIllegalOp<AtenDropoutOp>();
patterns.add<ConvertAtenDropoutOp>(typeConverter, context); patterns.add<ConvertAtenDropoutOp>(typeConverter, context);
target.addIllegalOp<PseudoAtenUniformOp>(); target.addIllegalOp<ValsemVariantAtenUniformOp>();
patterns.add<ConvertPseudoAtenUniformOp>(typeConverter, context); patterns.add<ConvertValsemVariantAtenUniformOp>(typeConverter, context);
} }

View File

@ -80,12 +80,12 @@ public:
} // namespace } // namespace
namespace { namespace {
class ConvertPseudoAtenFillScalarOp class ConvertValsemVariantAtenFillScalarOp
: public OpConversionPattern<PseudoAtenFillScalarOp> { : public OpConversionPattern<ValsemVariantAtenFillScalarOp> {
public: public:
using OpConversionPattern::OpConversionPattern; using OpConversionPattern::OpConversionPattern;
LogicalResult LogicalResult
matchAndRewrite(PseudoAtenFillScalarOp op, OpAdaptor adaptor, matchAndRewrite(ValsemVariantAtenFillScalarOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override { ConversionPatternRewriter &rewriter) const override {
if (failed(verifyLinalgCompatibleTypes(op, rewriter))) if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
return failure(); return failure();
@ -318,8 +318,8 @@ void mlir::torch::torch_to_linalg::
MLIRContext *context = patterns.getContext(); MLIRContext *context = patterns.getContext();
target.addIllegalOp<AtenConstantPadNdOp>(); target.addIllegalOp<AtenConstantPadNdOp>();
patterns.add<ConvertAtenConstantPadNdOp>(typeConverter, context); patterns.add<ConvertAtenConstantPadNdOp>(typeConverter, context);
target.addIllegalOp<PseudoAtenFillScalarOp>(); target.addIllegalOp<ValsemVariantAtenFillScalarOp>();
patterns.add<ConvertPseudoAtenFillScalarOp>(typeConverter, context); patterns.add<ConvertValsemVariantAtenFillScalarOp>(typeConverter, context);
target.addIllegalOp<AtenZerosOp, AtenOnesOp>(); target.addIllegalOp<AtenZerosOp, AtenOnesOp>();
patterns.add<ConvertConstantTensorAllocOp<AtenZerosOp, 0>>(typeConverter, patterns.add<ConvertConstantTensorAllocOp<AtenZerosOp, 0>>(typeConverter,
context); context);

View File

@ -36,7 +36,7 @@ static Value createElementwiseLinalgGeneric(
// what happens for a single result dimension. This loop not structured that // what happens for a single result dimension. This loop not structured that
// way because it is hard to create the affine maps for each operand unless // way because it is hard to create the affine maps for each operand unless
// we structure the loop to iterate over tensor operands as the outer loop // we structure the loop to iterate over tensor operands as the outer loop
// instead of inner loop. This pseudocode gives better intuition: // instead of inner loop. This valsemcode gives better intuition:
// ``` // ```
// for each result dimension: // for each result dimension:
// for each tensor operand: // for each tensor operand:

View File

@ -127,8 +127,8 @@ static Value createInitTensor(PatternRewriter &rewriter, Location loc,
Value emptyTensor = rewriter.create<AtenEmptyMemoryFormatOp>( Value emptyTensor = rewriter.create<AtenEmptyMemoryFormatOp>(
loc, tensorType, sizeList, /*dtype=*/noneVal, /*layout=*/noneVal, loc, tensorType, sizeList, /*dtype=*/noneVal, /*layout=*/noneVal,
/*device=*/noneVal, /*pin_memory=*/noneVal, /*memory_format=*/noneVal); /*device=*/noneVal, /*pin_memory=*/noneVal, /*memory_format=*/noneVal);
return rewriter.create<PseudoAtenFillScalarOp>(loc, resultType, emptyTensor, return rewriter.create<ValsemVariantAtenFillScalarOp>(loc, resultType,
scalar); emptyTensor, scalar);
} }
// Helper to create a rank 0 tensor filled with the given `scalar`. `scalar` // Helper to create a rank 0 tensor filled with the given `scalar`. `scalar`
@ -951,7 +951,7 @@ public:
rewriter.create<ConstantFloatOp>(loc, rewriter.getF64FloatAttr(0.0)); rewriter.create<ConstantFloatOp>(loc, rewriter.getF64FloatAttr(0.0));
Value ub = Value ub =
rewriter.create<ConstantFloatOp>(loc, rewriter.getF64FloatAttr(1.0)); rewriter.create<ConstantFloatOp>(loc, rewriter.getF64FloatAttr(1.0));
rewriter.replaceOpWithNewOp<PseudoAtenUniformOp>( rewriter.replaceOpWithNewOp<ValsemVariantAtenUniformOp>(
op, op.getType(), input, lb, ub, /*generator=*/none); op, op.getType(), input, lb, ub, /*generator=*/none);
return success(); return success();
} }
@ -1028,11 +1028,11 @@ public:
// aten.bernoulli.float(x, p) = (rand_like(float(x)) < tensor(p)).cast(type(x)). // aten.bernoulli.float(x, p) = (rand_like(float(x)) < tensor(p)).cast(type(x)).
// Since the input x can be an integer tensor, it's important to cast it to // Since the input x can be an integer tensor, it's important to cast it to
// float type before passing it to the `aten.rand_like` op. // float type before passing it to the `aten.rand_like` op.
class DecomposePseudoAtenBernoulliFloatOp class DecomposeValsemVariantAtenBernoulliFloatOp
: public OpRewritePattern<PseudoAtenBernoulliFloatOp> { : public OpRewritePattern<ValsemVariantAtenBernoulliFloatOp> {
public: public:
using OpRewritePattern::OpRewritePattern; using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(PseudoAtenBernoulliFloatOp op, LogicalResult matchAndRewrite(ValsemVariantAtenBernoulliFloatOp op,
PatternRewriter &rewriter) const override { PatternRewriter &rewriter) const override {
Location loc = op.getLoc(); Location loc = op.getLoc();
Value input = op.self(); Value input = op.self();
@ -1060,11 +1060,11 @@ public:
// aten.bernoulli.Tensor(x, p) = (rand_like(float(x)) < p).cast(type(x)). // aten.bernoulli.Tensor(x, p) = (rand_like(float(x)) < p).cast(type(x)).
// Since the input x can be an integer tensor, it's important to cast it to // Since the input x can be an integer tensor, it's important to cast it to
// float type before passing it to the `aten.rand_like` op. // float type before passing it to the `aten.rand_like` op.
class DecomposePseudoAtenBernoulliTensorOp class DecomposeValsemVariantAtenBernoulliTensorOp
: public OpRewritePattern<PseudoAtenBernoulliTensorOp> { : public OpRewritePattern<ValsemVariantAtenBernoulliTensorOp> {
public: public:
using OpRewritePattern::OpRewritePattern; using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(PseudoAtenBernoulliTensorOp op, LogicalResult matchAndRewrite(ValsemVariantAtenBernoulliTensorOp op,
PatternRewriter &rewriter) const override { PatternRewriter &rewriter) const override {
Location loc = op.getLoc(); Location loc = op.getLoc();
Value input = op.self(); Value input = op.self();
@ -1208,7 +1208,7 @@ class DecomposeConstantTensorAllocLikeOp : public OpRewritePattern<OpTy> {
Value constVal = rewriter.create<Torch::ConstantIntOp>( Value constVal = rewriter.create<Torch::ConstantIntOp>(
loc, rewriter.getI64IntegerAttr(fillVal)); loc, rewriter.getI64IntegerAttr(fillVal));
// Initialize the allocated memory block with `fillVal`. // Initialize the allocated memory block with `fillVal`.
rewriter.replaceOpWithNewOp<PseudoAtenFillScalarOp>( rewriter.replaceOpWithNewOp<ValsemVariantAtenFillScalarOp>(
op, initTensor.getType(), initTensor, constVal); op, initTensor.getType(), initTensor, constVal);
return success(); return success();
} }
@ -1388,7 +1388,7 @@ public:
Value emptyTensor = rewriter.create<AtenEmptyMemoryFormatOp>( Value emptyTensor = rewriter.create<AtenEmptyMemoryFormatOp>(
loc, op.getType(), op.size(), op.dtype(), op.layout(), op.device(), loc, op.getType(), op.size(), op.dtype(), op.layout(), op.device(),
op.pin_memory(), /*memory_format=*/noneVal); op.pin_memory(), /*memory_format=*/noneVal);
rewriter.replaceOpWithNewOp<PseudoAtenFillScalarOp>( rewriter.replaceOpWithNewOp<ValsemVariantAtenFillScalarOp>(
op, op.getType(), emptyTensor, op.fill_value()); op, op.getType(), emptyTensor, op.fill_value());
return success(); return success();
} }
@ -1405,7 +1405,7 @@ public:
Value emptyTensor = rewriter.create<AtenEmptyLikeOp>( Value emptyTensor = rewriter.create<AtenEmptyLikeOp>(
op.getLoc(), op.getType(), op.self(), op.dtype(), op.layout(), op.getLoc(), op.getType(), op.self(), op.dtype(), op.layout(),
op.device(), op.pin_memory(), op.memory_format()); op.device(), op.pin_memory(), op.memory_format());
rewriter.replaceOpWithNewOp<PseudoAtenFillScalarOp>( rewriter.replaceOpWithNewOp<ValsemVariantAtenFillScalarOp>(
op, op.getType(), emptyTensor, op.fill_value()); op, op.getType(), emptyTensor, op.fill_value());
return success(); return success();
} }
@ -1491,10 +1491,10 @@ class DecomposeComplexOpsPass
target.addIllegalOp<Aten_UnsafeViewOp>(); target.addIllegalOp<Aten_UnsafeViewOp>();
patterns.add<DecomposeAtenBernoulliOp>(context); patterns.add<DecomposeAtenBernoulliOp>(context);
target.addIllegalOp<AtenBernoulliOp>(); target.addIllegalOp<AtenBernoulliOp>();
patterns.add<DecomposePseudoAtenBernoulliFloatOp>(context); patterns.add<DecomposeValsemVariantAtenBernoulliFloatOp>(context);
target.addIllegalOp<PseudoAtenBernoulliFloatOp>(); target.addIllegalOp<ValsemVariantAtenBernoulliFloatOp>();
patterns.add<DecomposePseudoAtenBernoulliTensorOp>(context); patterns.add<DecomposeValsemVariantAtenBernoulliTensorOp>(context);
target.addIllegalOp<PseudoAtenBernoulliTensorOp>(); target.addIllegalOp<ValsemVariantAtenBernoulliTensorOp>();
patterns.add<DecomposeAtenRandLikeOp>(context); patterns.add<DecomposeAtenRandLikeOp>(context);
target.addIllegalOp<AtenRandLikeOp>(); target.addIllegalOp<AtenRandLikeOp>();
patterns.add<DecomposeAtenHardsigmoidOp>(context); patterns.add<DecomposeAtenHardsigmoidOp>(context);

View File

@ -147,17 +147,17 @@ public:
Location loc = op->getLoc(); Location loc = op->getLoc();
Operation *newOp; Operation *newOp;
if (isa<AtenUniform_Op>(op)) { if (isa<AtenUniform_Op>(op)) {
newOp = rewriter.create<PseudoAtenUniformOp>(loc, op->getResultTypes(), newOp = rewriter.create<ValsemVariantAtenUniformOp>(
op->getOperands()); loc, op->getResultTypes(), op->getOperands());
} else if (isa<AtenBernoulli_FloatOp>(op)) { } else if (isa<AtenBernoulli_FloatOp>(op)) {
newOp = rewriter.create<PseudoAtenBernoulliFloatOp>( newOp = rewriter.create<ValsemVariantAtenBernoulliFloatOp>(
loc, op->getResultTypes(), op->getOperands()); loc, op->getResultTypes(), op->getOperands());
} else if (isa<AtenBernoulli_TensorOp>(op)) { } else if (isa<AtenBernoulli_TensorOp>(op)) {
newOp = rewriter.create<PseudoAtenBernoulliTensorOp>( newOp = rewriter.create<ValsemVariantAtenBernoulliTensorOp>(
loc, op->getResultTypes(), op->getOperands()); loc, op->getResultTypes(), op->getOperands());
} else if (isa<AtenFill_ScalarOp>(op)) { } else if (isa<AtenFill_ScalarOp>(op)) {
newOp = rewriter.create<PseudoAtenFillScalarOp>(loc, op->getResultTypes(), newOp = rewriter.create<ValsemVariantAtenFillScalarOp>(
op->getOperands()); loc, op->getResultTypes(), op->getOperands());
} else { } else {
return failure(); return failure();
} }

View File

@ -505,10 +505,10 @@ ChangeResult TypeAnalyzer::visitOperation(
AtenClampOp, AtenLogOp, AtenNegOp, AtenSqrtOp, AtenFloorOp, AtenClampOp, AtenLogOp, AtenNegOp, AtenSqrtOp, AtenFloorOp,
AtenLog2Op, Aten_SoftmaxBackwardDataOp, AtenRsqrtOp, AtenDropoutOp, AtenLog2Op, Aten_SoftmaxBackwardDataOp, AtenRsqrtOp, AtenDropoutOp,
AtenTanhBackwardOp, Aten_LogSoftmaxBackwardDataOp, AtenAddIntOp, AtenTanhBackwardOp, Aten_LogSoftmaxBackwardDataOp, AtenAddIntOp,
AtenAbsOp, AtenThresholdOp, AtenSquareOp, PseudoAtenUniformOp, AtenAbsOp, AtenThresholdOp, AtenSquareOp, ValsemVariantAtenUniformOp,
AtenBernoulliOp, AtenBernoulli_FloatOp, AtenBernoulli_TensorOp, AtenBernoulliOp, AtenBernoulli_FloatOp, AtenBernoulli_TensorOp,
PseudoAtenBernoulliFloatOp, PseudoAtenBernoulliTensorOp, ValsemVariantAtenBernoulliFloatOp, ValsemVariantAtenBernoulliTensorOp,
PseudoAtenFillScalarOp, AtenHardsigmoidOp, AtenCloneOp, ValsemVariantAtenFillScalarOp, AtenHardsigmoidOp, AtenCloneOp,
AtenHardswishOp, AtenErfOp, AtenSiluOp, AtenHardtanhOp, AtenHardswishOp, AtenErfOp, AtenSiluOp, AtenHardtanhOp,
AtenMaskedSelectOp, AtenMaxPool2dOp, AtenAdaptiveAvgPool2dOp, AtenMaskedSelectOp, AtenMaxPool2dOp, AtenAdaptiveAvgPool2dOp,
AtenFlattenUsingIntsOp, AtenSqueezeOp, AtenSqueezeDimOp, AtenFlattenUsingIntsOp, AtenSqueezeOp, AtenSqueezeDimOp,

View File

@ -232,11 +232,12 @@ class ReifyShapeCalculationsPass
module.walk([&](Operation *op) { module.walk([&](Operation *op) {
Location loc = op->getLoc(); Location loc = op->getLoc();
auto name = op->getName().stripDialect(); auto name = op->getName().stripDialect();
// For pseudo-ops (ops that are mechanically consistent with existing // For value-semantic variant ops, i.e. valsem-ops (ops that are
// torch conventions, but simply not present, such as a missing in-place // mechanically consistent with existing torch conventions of in-place vs.
// or out-of-place variant), remove the pseudo prefix. // out-of-place (value-semantic) variants), remove the prefix when
if (name.startswith("pseudo.")) // looking them up in the shape library.
name = name.drop_front(strlen("pseudo.")); if (name.startswith("valsem."))
name = name.drop_front(strlen("valsem."));
auto shapeFunctionName = ("__torch_mlir_shape_fn." + Twine(name)).str(); auto shapeFunctionName = ("__torch_mlir_shape_fn." + Twine(name)).str();
auto shapeFunction = auto shapeFunction =
shapeLibrary->lookupSymbol<FuncOp>(shapeFunctionName); shapeLibrary->lookupSymbol<FuncOp>(shapeFunctionName);

View File

@ -263,11 +263,14 @@ def check_shape_function(invocations: List[Invocation]):
def not_present_in_registry(f): def not_present_in_registry(f):
"""Decorator for shape functions not present in the shape registry. """Decorator for shape functions not present in the shape registry.
This can happen for "pseudo" ops that we have in Torch-MLIR, such as This can happen for "valsem" ops that we have in Torch-MLIR, such as
torch.aten.fill.Scalar, which are consistent with PyTorch conventions (e.g. torch.valsem.aten.fill.Scalar, which are consistent with PyTorch conventions
being the value-semantic correspondent of torch.aten.fill_.Scalar), but (e.g. being the value-semantic correspondent of torch.aten.fill_.Scalar),
that for whatever reason are not present in PyTorch. Such ops are useful but that for whatever reason are not present in PyTorch. Such ops are useful
to keep certain passes within Torch-MLIR as consistent as possible. to keep certain passes within Torch-MLIR as consistent as possible.
For such ops, in the shape library generator, we treat them as if they
were registered torch ops (so we don't put "valsem" on them), to keep the
generator consistent.
To check if this decorator has been applied, use To check if this decorator has been applied, use
`hasattr(f, "_not_present_in_registry")`. `hasattr(f, "_not_present_in_registry")`.

View File

@ -388,7 +388,7 @@ func @torch.aten._log_softmax(%arg0: !torch.vtensor<[?,?,?],f32> loc(unknown)) -
// CHECK: %[[NONE_2:.*]] = torch.constant.none // CHECK: %[[NONE_2:.*]] = torch.constant.none
// CHECK: %[[FLOAT0:.*]] = torch.constant.float 0.000000e+00 // CHECK: %[[FLOAT0:.*]] = torch.constant.float 0.000000e+00
// CHECK: %[[FLOAT1:.*]] = torch.constant.float 1.000000e+00 // CHECK: %[[FLOAT1:.*]] = torch.constant.float 1.000000e+00
// CHECK: %[[UNF:.*]] = torch.pseudo.aten.uniform %[[CON2FLOAT]], %[[FLOAT0]], %[[FLOAT1]], %[[NONE_2]] : !torch.vtensor<[?,?,?],f64>, !torch.float, !torch.float, !torch.none -> !torch.vtensor<[?,?,?],f64> // CHECK: %[[UNF:.*]] = torch.valsem.aten.uniform %[[CON2FLOAT]], %[[FLOAT0]], %[[FLOAT1]], %[[NONE_2]] : !torch.vtensor<[?,?,?],f64>, !torch.float, !torch.float, !torch.none -> !torch.vtensor<[?,?,?],f64>
// CHECK: %[[CMP:.*]] = torch.aten.lt.Tensor %[[UNF]], %[[INP]] : !torch.vtensor<[?,?,?],f64>, !torch.vtensor<[?,?,?],f64> -> !torch.vtensor<[?,?,?],i1> // CHECK: %[[CMP:.*]] = torch.aten.lt.Tensor %[[UNF]], %[[INP]] : !torch.vtensor<[?,?,?],f64>, !torch.vtensor<[?,?,?],f64> -> !torch.vtensor<[?,?,?],i1>
// CHECK: %[[INT7_2:.*]] = torch.constant.int 7 // CHECK: %[[INT7_2:.*]] = torch.constant.int 7
// CHECK: %[[FALSE_2:.*]] = torch.constant.bool false // CHECK: %[[FALSE_2:.*]] = torch.constant.bool false
@ -405,7 +405,7 @@ func @torch.aten.bernoulli(%arg0: !torch.vtensor<[?,?,?],f64>) -> !torch.vtensor
} }
// ----- // -----
// CHECK-LABEL: func @torch.pseudo.aten.bernoulli.float // CHECK-LABEL: func @torch.valsem.aten.bernoulli.float
// CHECK-SAME: (%[[INP:.*]]: !torch.vtensor<[?,?,?],f64>) -> !torch.vtensor { // CHECK-SAME: (%[[INP:.*]]: !torch.vtensor<[?,?,?],f64>) -> !torch.vtensor {
// CHECK: %[[NONE:.*]] = torch.constant.none // CHECK: %[[NONE:.*]] = torch.constant.none
// CHECK: %[[PROB:.*]] = torch.constant.float 4.000000e-01 // CHECK: %[[PROB:.*]] = torch.constant.float 4.000000e-01
@ -419,7 +419,7 @@ func @torch.aten.bernoulli(%arg0: !torch.vtensor<[?,?,?],f64>) -> !torch.vtensor
// CHECK: %[[NONE_2:.*]] = torch.constant.none // CHECK: %[[NONE_2:.*]] = torch.constant.none
// CHECK: %[[FLOAT0:.*]] = torch.constant.float 0.000000e+00 // CHECK: %[[FLOAT0:.*]] = torch.constant.float 0.000000e+00
// CHECK: %[[FLOAT1:.*]] = torch.constant.float 1.000000e+00 // CHECK: %[[FLOAT1:.*]] = torch.constant.float 1.000000e+00
// CHECK: %[[UNF:.*]] = torch.pseudo.aten.uniform %[[CON2FLOAT]], %[[FLOAT0]], %[[FLOAT1]], %[[NONE_2]] : !torch.vtensor<[?,?,?],f64>, !torch.float, !torch.float, !torch.none -> !torch.vtensor<[?,?,?],f64> // CHECK: %[[UNF:.*]] = torch.valsem.aten.uniform %[[CON2FLOAT]], %[[FLOAT0]], %[[FLOAT1]], %[[NONE_2]] : !torch.vtensor<[?,?,?],f64>, !torch.float, !torch.float, !torch.none -> !torch.vtensor<[?,?,?],f64>
// CHECK: %[[CMP:.*]] = torch.aten.lt.Tensor %[[UNF]], %[[PROB_TENSOR]] : !torch.vtensor<[?,?,?],f64>, !torch.vtensor<[],f64> -> !torch.vtensor<[?,?,?],i1> // CHECK: %[[CMP:.*]] = torch.aten.lt.Tensor %[[UNF]], %[[PROB_TENSOR]] : !torch.vtensor<[?,?,?],f64>, !torch.vtensor<[],f64> -> !torch.vtensor<[?,?,?],i1>
// CHECK: %[[INT7_2:.*]] = torch.constant.int 7 // CHECK: %[[INT7_2:.*]] = torch.constant.int 7
// CHECK: %[[FALSE_2:.*]] = torch.constant.bool false // CHECK: %[[FALSE_2:.*]] = torch.constant.bool false
@ -428,16 +428,16 @@ func @torch.aten.bernoulli(%arg0: !torch.vtensor<[?,?,?],f64>) -> !torch.vtensor
// CHECK-SAME: !torch.vtensor<[?,?,?],i1>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[?,?,?],f64> // CHECK-SAME: !torch.vtensor<[?,?,?],i1>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[?,?,?],f64>
// CHECK: %[[CAST:.*]] = torch.tensor_static_info_cast %[[TODTYPE]] : !torch.vtensor<[?,?,?],f64> to !torch.vtensor // CHECK: %[[CAST:.*]] = torch.tensor_static_info_cast %[[TODTYPE]] : !torch.vtensor<[?,?,?],f64> to !torch.vtensor
// CHECK: return %[[CAST]] : !torch.vtensor // CHECK: return %[[CAST]] : !torch.vtensor
func @torch.pseudo.aten.bernoulli.float(%arg0: !torch.vtensor<[?,?,?],f64>) -> !torch.vtensor { func @torch.valsem.aten.bernoulli.float(%arg0: !torch.vtensor<[?,?,?],f64>) -> !torch.vtensor {
%none = torch.constant.none %none = torch.constant.none
%prob = torch.constant.float 4.000000e-01 %prob = torch.constant.float 4.000000e-01
%0 = torch.pseudo.aten.bernoulli.float %arg0, %prob, %none : !torch.vtensor<[?,?,?],f64>, !torch.float, !torch.none -> !torch.vtensor<[?,?,?],f64> %0 = torch.valsem.aten.bernoulli.float %arg0, %prob, %none : !torch.vtensor<[?,?,?],f64>, !torch.float, !torch.none -> !torch.vtensor<[?,?,?],f64>
%1 = torch.tensor_static_info_cast %0 : !torch.vtensor<[?,?,?],f64> to !torch.vtensor %1 = torch.tensor_static_info_cast %0 : !torch.vtensor<[?,?,?],f64> to !torch.vtensor
return %1 : !torch.vtensor return %1 : !torch.vtensor
} }
// ----- // -----
// CHECK-LABEL: func @torch.pseudo.aten.bernoulli.Tensor( // CHECK-LABEL: func @torch.valsem.aten.bernoulli.Tensor(
// CHECK-SAME: %[[INP:.*]]: !torch.vtensor<[?,?,?],f64>, // CHECK-SAME: %[[INP:.*]]: !torch.vtensor<[?,?,?],f64>,
// CHECK-SAME: %[[PROB:.*]]: !torch.vtensor<[?,?,?],f64>) -> !torch.vtensor { // CHECK-SAME: %[[PROB:.*]]: !torch.vtensor<[?,?,?],f64>) -> !torch.vtensor {
// CHECK: %[[NONE:.*]] = torch.constant.none // CHECK: %[[NONE:.*]] = torch.constant.none
@ -450,7 +450,7 @@ func @torch.pseudo.aten.bernoulli.float(%arg0: !torch.vtensor<[?,?,?],f64>) -> !
// CHECK: %[[NONE_2:.*]] = torch.constant.none // CHECK: %[[NONE_2:.*]] = torch.constant.none
// CHECK: %[[FLOAT0:.*]] = torch.constant.float 0.000000e+00 // CHECK: %[[FLOAT0:.*]] = torch.constant.float 0.000000e+00
// CHECK: %[[FLOAT1:.*]] = torch.constant.float 1.000000e+00 // CHECK: %[[FLOAT1:.*]] = torch.constant.float 1.000000e+00
// CHECK: %[[UNF:.*]] = torch.pseudo.aten.uniform %[[CON2FLOAT]], %[[FLOAT0]], %[[FLOAT1]], %[[NONE_2]] : !torch.vtensor<[?,?,?],f64>, !torch.float, !torch.float, !torch.none -> !torch.vtensor<[?,?,?],f64> // CHECK: %[[UNF:.*]] = torch.valsem.aten.uniform %[[CON2FLOAT]], %[[FLOAT0]], %[[FLOAT1]], %[[NONE_2]] : !torch.vtensor<[?,?,?],f64>, !torch.float, !torch.float, !torch.none -> !torch.vtensor<[?,?,?],f64>
// CHECK: %[[CMP:.*]] = torch.aten.lt.Tensor %[[UNF]], %[[PROB]] : !torch.vtensor<[?,?,?],f64>, !torch.vtensor<[?,?,?],f64> -> !torch.vtensor<[?,?,?],i1> // CHECK: %[[CMP:.*]] = torch.aten.lt.Tensor %[[UNF]], %[[PROB]] : !torch.vtensor<[?,?,?],f64>, !torch.vtensor<[?,?,?],f64> -> !torch.vtensor<[?,?,?],i1>
// CHECK: %[[INT7_2:.*]] = torch.constant.int 7 // CHECK: %[[INT7_2:.*]] = torch.constant.int 7
// CHECK: %[[FALSE_2:.*]] = torch.constant.bool false // CHECK: %[[FALSE_2:.*]] = torch.constant.bool false
@ -459,9 +459,9 @@ func @torch.pseudo.aten.bernoulli.float(%arg0: !torch.vtensor<[?,?,?],f64>) -> !
// CHECK-SAME: !torch.vtensor<[?,?,?],i1>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[?,?,?],f64> // CHECK-SAME: !torch.vtensor<[?,?,?],i1>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[?,?,?],f64>
// CHECK: %[[CAST:.*]] = torch.tensor_static_info_cast %[[TODTYPE]] : !torch.vtensor<[?,?,?],f64> to !torch.vtensor // CHECK: %[[CAST:.*]] = torch.tensor_static_info_cast %[[TODTYPE]] : !torch.vtensor<[?,?,?],f64> to !torch.vtensor
// CHECK: return %[[CAST]] : !torch.vtensor // CHECK: return %[[CAST]] : !torch.vtensor
func @torch.pseudo.aten.bernoulli.Tensor(%arg0: !torch.vtensor<[?,?,?],f64>, %arg1: !torch.vtensor<[?,?,?],f64>) -> !torch.vtensor { func @torch.valsem.aten.bernoulli.Tensor(%arg0: !torch.vtensor<[?,?,?],f64>, %arg1: !torch.vtensor<[?,?,?],f64>) -> !torch.vtensor {
%none = torch.constant.none %none = torch.constant.none
%0 = torch.pseudo.aten.bernoulli.Tensor %arg0, %arg1, %none : !torch.vtensor<[?,?,?],f64>, !torch.vtensor<[?,?,?],f64>, !torch.none -> !torch.vtensor<[?,?,?],f64> %0 = torch.valsem.aten.bernoulli.Tensor %arg0, %arg1, %none : !torch.vtensor<[?,?,?],f64>, !torch.vtensor<[?,?,?],f64>, !torch.none -> !torch.vtensor<[?,?,?],f64>
%1 = torch.tensor_static_info_cast %0 : !torch.vtensor<[?,?,?],f64> to !torch.vtensor %1 = torch.tensor_static_info_cast %0 : !torch.vtensor<[?,?,?],f64> to !torch.vtensor
return %1 : !torch.vtensor return %1 : !torch.vtensor
} }
@ -496,13 +496,13 @@ func @torch.aten.select.int(%arg0: !torch.vtensor<[?,?],si64>) -> !torch.vtensor
// CHECK: %[[NONE:.*]] = torch.constant.none // CHECK: %[[NONE:.*]] = torch.constant.none
// CHECK: %[[EMPTY:.*]] = torch.aten.empty.memory_format %[[SIZES]], %[[NONE]], %[[NONE]], %[[NONE]], %[[NONE]], %[[NONE]] : // CHECK: %[[EMPTY:.*]] = torch.aten.empty.memory_format %[[SIZES]], %[[NONE]], %[[NONE]], %[[NONE]], %[[NONE]], %[[NONE]] :
// CHECK-SAME: !torch.list<int>, !torch.none, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[],f32> // CHECK-SAME: !torch.list<int>, !torch.none, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[],f32>
// CHECK: %[[CST1_TENSOR:.*]] = torch.pseudo.aten.fill.Scalar %[[EMPTY]], %[[CST1]] : !torch.vtensor<[],f32>, !torch.int -> !torch.vtensor<[],f32> // CHECK: %[[CST1_TENSOR:.*]] = torch.valsem.aten.fill.Scalar %[[EMPTY]], %[[CST1]] : !torch.vtensor<[],f32>, !torch.int -> !torch.vtensor<[],f32>
// CHECK: %[[MIN:.*]] = torch.aten.minimum %[[CST1_TENSOR]], %[[DIV]] : !torch.vtensor<[],f32>, !torch.vtensor<[?,?],f32> -> !torch.vtensor<[?,?],f32> // CHECK: %[[MIN:.*]] = torch.aten.minimum %[[CST1_TENSOR]], %[[DIV]] : !torch.vtensor<[],f32>, !torch.vtensor<[?,?],f32> -> !torch.vtensor<[?,?],f32>
// CHECK: %[[SIZES:.*]] = torch.prim.ListConstruct : () -> !torch.list<int> // CHECK: %[[SIZES:.*]] = torch.prim.ListConstruct : () -> !torch.list<int>
// CHECK: %[[NONE_1:.*]] = torch.constant.none // CHECK: %[[NONE_1:.*]] = torch.constant.none
// CHECK: %[[EMPTY_1:.*]] = torch.aten.empty.memory_format %[[SIZES]], %[[NONE_1]], %[[NONE_1]], %[[NONE_1]], %[[NONE_1]], %[[NONE_1]] : // CHECK: %[[EMPTY_1:.*]] = torch.aten.empty.memory_format %[[SIZES]], %[[NONE_1]], %[[NONE_1]], %[[NONE_1]], %[[NONE_1]], %[[NONE_1]] :
// CHECK-SAME: !torch.list<int>, !torch.none, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[],f32> // CHECK-SAME: !torch.list<int>, !torch.none, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[],f32>
// CHECK: %[[CST0_TENSOR:.*]] = torch.pseudo.aten.fill.Scalar %[[EMPTY_1]], %[[CST0]] : !torch.vtensor<[],f32>, !torch.int -> !torch.vtensor<[],f32> // CHECK: %[[CST0_TENSOR:.*]] = torch.valsem.aten.fill.Scalar %[[EMPTY_1]], %[[CST0]] : !torch.vtensor<[],f32>, !torch.int -> !torch.vtensor<[],f32>
// CHECK: %[[RET:.*]] = torch.aten.maximum %[[CST0_TENSOR]], %[[MIN]] : !torch.vtensor<[],f32>, !torch.vtensor<[?,?],f32> -> !torch.vtensor<[?,?],f32> // CHECK: %[[RET:.*]] = torch.aten.maximum %[[CST0_TENSOR]], %[[MIN]] : !torch.vtensor<[],f32>, !torch.vtensor<[?,?],f32> -> !torch.vtensor<[?,?],f32>
// CHECK: return %[[RET]] : !torch.vtensor<[?,?],f32> // CHECK: return %[[RET]] : !torch.vtensor<[?,?],f32>
// CHECK: } // CHECK: }
@ -524,7 +524,7 @@ func @torch.aten.hardsigmoid(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor
// CHECK: %[[NONE:.*]] = torch.constant.none // CHECK: %[[NONE:.*]] = torch.constant.none
// CHECK: %[[MEM:.*]] = torch.aten.empty.memory_format %[[LIST]], %[[NONE]], %[[NONE]], %[[NONE]], %[[NONE]], %[[NONE]] : // CHECK: %[[MEM:.*]] = torch.aten.empty.memory_format %[[LIST]], %[[NONE]], %[[NONE]], %[[NONE]], %[[NONE]], %[[NONE]] :
// CHECK-SAME: !torch.list<int>, !torch.none, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[],f32> // CHECK-SAME: !torch.list<int>, !torch.none, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[],f32>
// CHECK: %[[FILL:.*]] = torch.pseudo.aten.fill.Scalar %[[MEM]], %[[INT6_]] : !torch.vtensor<[],f32>, !torch.int -> !torch.vtensor<[],f32> // CHECK: %[[FILL:.*]] = torch.valsem.aten.fill.Scalar %[[MEM]], %[[INT6_]] : !torch.vtensor<[],f32>, !torch.int -> !torch.vtensor<[],f32>
// CHECK: %[[MIN:.*]] = torch.aten.minimum %[[RELU]], %[[FILL]] : !torch.vtensor<[?,?],f32>, !torch.vtensor<[],f32> -> !torch.vtensor<[?,?],f32> // CHECK: %[[MIN:.*]] = torch.aten.minimum %[[RELU]], %[[FILL]] : !torch.vtensor<[?,?],f32>, !torch.vtensor<[],f32> -> !torch.vtensor<[?,?],f32>
// CHECK: %[[DIV:.*]] = torch.aten.div.Scalar %[[MIN]], %[[INT6]] : !torch.vtensor<[?,?],f32>, !torch.int -> !torch.vtensor<[?,?],f32> // CHECK: %[[DIV:.*]] = torch.aten.div.Scalar %[[MIN]], %[[INT6]] : !torch.vtensor<[?,?],f32>, !torch.int -> !torch.vtensor<[?,?],f32>
// CHECK: %[[MUL:.*]] = torch.aten.mul.Tensor %[[DIV]], %[[INP]] : !torch.vtensor<[?,?],f32>, !torch.vtensor<[?,?],f32> -> !torch.vtensor<[?,?],f32> // CHECK: %[[MUL:.*]] = torch.aten.mul.Tensor %[[DIV]], %[[INP]] : !torch.vtensor<[?,?],f32>, !torch.vtensor<[?,?],f32> -> !torch.vtensor<[?,?],f32>
@ -543,13 +543,13 @@ func @torch.aten.hardswish(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[
// CHECK: %[[NONE:.*]] = torch.constant.none // CHECK: %[[NONE:.*]] = torch.constant.none
// CHECK: %[[EMPTY:.*]] = torch.aten.empty.memory_format %[[SIZES]], %[[NONE]], %[[NONE]], %[[NONE]], %[[NONE]], %[[NONE]] : // CHECK: %[[EMPTY:.*]] = torch.aten.empty.memory_format %[[SIZES]], %[[NONE]], %[[NONE]], %[[NONE]], %[[NONE]], %[[NONE]] :
// CHECK-SAME: !torch.list<int>, !torch.none, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[],f32> // CHECK-SAME: !torch.list<int>, !torch.none, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[],f32>
// CHECK: %[[MIN_TENSOR:.*]] = torch.pseudo.aten.fill.Scalar %[[EMPTY]], %[[MIN_VAL]] : !torch.vtensor<[],f32>, !torch.float -> !torch.vtensor<[],f32> // CHECK: %[[MIN_TENSOR:.*]] = torch.valsem.aten.fill.Scalar %[[EMPTY]], %[[MIN_VAL]] : !torch.vtensor<[],f32>, !torch.float -> !torch.vtensor<[],f32>
// CHECK: %[[MIN:.*]] = torch.aten.maximum %[[INPUT]], %[[MIN_TENSOR]] : !torch.vtensor<[?],f32>, !torch.vtensor<[],f32> -> !torch.vtensor<[?],f32> // CHECK: %[[MIN:.*]] = torch.aten.maximum %[[INPUT]], %[[MIN_TENSOR]] : !torch.vtensor<[?],f32>, !torch.vtensor<[],f32> -> !torch.vtensor<[?],f32>
// CHECK: %[[SIZES:.*]] = torch.prim.ListConstruct : () -> !torch.list<int> // CHECK: %[[SIZES:.*]] = torch.prim.ListConstruct : () -> !torch.list<int>
// CHECK: %[[NONE:.*]] = torch.constant.none // CHECK: %[[NONE:.*]] = torch.constant.none
// CHECK: %[[VAL_10:.*]] = torch.aten.empty.memory_format %[[SIZES]], %[[NONE]], %[[NONE]], %[[NONE]], %[[NONE]], %[[NONE]] : // CHECK: %[[VAL_10:.*]] = torch.aten.empty.memory_format %[[SIZES]], %[[NONE]], %[[NONE]], %[[NONE]], %[[NONE]], %[[NONE]] :
// CHECK-SAME: !torch.list<int>, !torch.none, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[],f32> // CHECK-SAME: !torch.list<int>, !torch.none, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[],f32>
// CHECK: %[[MAX_TENSOR:.*]] = torch.pseudo.aten.fill.Scalar %[[VAL_10]], %[[MAX_VAL]] : !torch.vtensor<[],f32>, !torch.float -> !torch.vtensor<[],f32> // CHECK: %[[MAX_TENSOR:.*]] = torch.valsem.aten.fill.Scalar %[[VAL_10]], %[[MAX_VAL]] : !torch.vtensor<[],f32>, !torch.float -> !torch.vtensor<[],f32>
// CHECK: %[[RET:.*]] = torch.aten.minimum %[[MAX_TENSOR]], %[[MIN]] : !torch.vtensor<[],f32>, !torch.vtensor<[?],f32> -> !torch.vtensor<[?],f32> // CHECK: %[[RET:.*]] = torch.aten.minimum %[[MAX_TENSOR]], %[[MIN]] : !torch.vtensor<[],f32>, !torch.vtensor<[?],f32> -> !torch.vtensor<[?],f32>
// CHECK: return %[[RET]] : !torch.vtensor<[?],f32> // CHECK: return %[[RET]] : !torch.vtensor<[?],f32>
func @torch.aten.hardtanh(%arg0: !torch.vtensor<[?],f32>, %min: !torch.float, %max: !torch.float) -> !torch.vtensor<[?],f32> { func @torch.aten.hardtanh(%arg0: !torch.vtensor<[?],f32>, %min: !torch.float, %max: !torch.float) -> !torch.vtensor<[?],f32> {
@ -616,7 +616,7 @@ func @torch.aten.silu(%arg0: !torch.vtensor<[?,?],f32> loc(unknown)) -> !torch.v
// CHECK: %[[SIZE:.*]] = torch.prim.ListConstruct %[[INT2]], %[[INT3]] : (!torch.int, !torch.int) -> !torch.list<int> // CHECK: %[[SIZE:.*]] = torch.prim.ListConstruct %[[INT2]], %[[INT3]] : (!torch.int, !torch.int) -> !torch.list<int>
// CHECK: %[[MEM_FORMAT:.*]] = torch.constant.none // CHECK: %[[MEM_FORMAT:.*]] = torch.constant.none
// CHECK: %[[EMPTY:.*]] = torch.aten.empty.memory_format %[[SIZE]], %[[NONE]], %[[NONE]], %[[NONE]], %[[NONE]], %[[MEM_FORMAT]] : !torch.list<int>, !torch.none, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[2,3],f32> // CHECK: %[[EMPTY:.*]] = torch.aten.empty.memory_format %[[SIZE]], %[[NONE]], %[[NONE]], %[[NONE]], %[[NONE]], %[[MEM_FORMAT]] : !torch.list<int>, !torch.none, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[2,3],f32>
// CHECK: %[[RES:.*]] = torch.pseudo.aten.fill.Scalar %[[EMPTY]], %[[FLOAT5]] : !torch.vtensor<[2,3],f32>, !torch.float -> !torch.vtensor<[2,3],f32> // CHECK: %[[RES:.*]] = torch.valsem.aten.fill.Scalar %[[EMPTY]], %[[FLOAT5]] : !torch.vtensor<[2,3],f32>, !torch.float -> !torch.vtensor<[2,3],f32>
// CHECK: return %[[RES]] : !torch.vtensor<[2,3],f32> // CHECK: return %[[RES]] : !torch.vtensor<[2,3],f32>
func @torch.aten.full() -> !torch.vtensor<[2,3],f32> { func @torch.aten.full() -> !torch.vtensor<[2,3],f32> {
%float5.000000e00 = torch.constant.float 5.000000e+00 %float5.000000e00 = torch.constant.float 5.000000e+00
@ -639,7 +639,7 @@ func @torch.aten.full() -> !torch.vtensor<[2,3],f32> {
// CHECK: %[[DIM1:.*]] = torch.aten.size.int %[[INP]], %[[INT1]] : !torch.vtensor<[?,?],f32>, !torch.int -> !torch.int // CHECK: %[[DIM1:.*]] = torch.aten.size.int %[[INP]], %[[INT1]] : !torch.vtensor<[?,?],f32>, !torch.int -> !torch.int
// CHECK: %[[SIZE:.*]] = torch.prim.ListConstruct %[[DIM0]], %[[DIM1]] : (!torch.int, !torch.int) -> !torch.list<int> // CHECK: %[[SIZE:.*]] = torch.prim.ListConstruct %[[DIM0]], %[[DIM1]] : (!torch.int, !torch.int) -> !torch.list<int>
// CHECK: %[[EMPTY:.*]] = torch.aten.empty.memory_format %[[SIZE]], %[[NONE]], %[[NONE]], %[[NONE]], %[[NONE]], %[[NONE]] : !torch.list<int>, !torch.none, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[?,?],f32> // CHECK: %[[EMPTY:.*]] = torch.aten.empty.memory_format %[[SIZE]], %[[NONE]], %[[NONE]], %[[NONE]], %[[NONE]], %[[NONE]] : !torch.list<int>, !torch.none, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[?,?],f32>
// CHECK: %[[RES:.*]] = torch.pseudo.aten.fill.Scalar %[[EMPTY]], %[[INT5]] : !torch.vtensor<[?,?],f32>, !torch.int -> !torch.vtensor<[?,?],f32> // CHECK: %[[RES:.*]] = torch.valsem.aten.fill.Scalar %[[EMPTY]], %[[INT5]] : !torch.vtensor<[?,?],f32>, !torch.int -> !torch.vtensor<[?,?],f32>
// CHECK: return %[[RES]] : !torch.vtensor<[?,?],f32> // CHECK: return %[[RES]] : !torch.vtensor<[?,?],f32>
func @torch.aten.full_like(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> { func @torch.aten.full_like(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
%int5 = torch.constant.int 5 %int5 = torch.constant.int 5

View File

@ -134,7 +134,7 @@ func @convert_to_value_semantic_tensors_optional_list(%self: !torch.tensor<[5],f
// CHECK-SAME: %[[T:.*]]: !torch.tensor, %[[MIN:.*]]: !torch.float, %[[MAX:.*]]: !torch.float, // CHECK-SAME: %[[T:.*]]: !torch.tensor, %[[MIN:.*]]: !torch.float, %[[MAX:.*]]: !torch.float,
// CHECK-SAME: %[[GENERATOR:.*]]: !torch.none) -> !torch.tensor { // CHECK-SAME: %[[GENERATOR:.*]]: !torch.none) -> !torch.tensor {
// CHECK: %[[T_VTENSOR:.*]] = torch.copy.to_vtensor %[[T]] : !torch.vtensor // CHECK: %[[T_VTENSOR:.*]] = torch.copy.to_vtensor %[[T]] : !torch.vtensor
// CHECK: %[[VRET:.*]] = torch.pseudo.aten.uniform %[[T_VTENSOR]], %[[MIN]], %[[MAX]], %[[GENERATOR]] : // CHECK: %[[VRET:.*]] = torch.valsem.aten.uniform %[[T_VTENSOR]], %[[MIN]], %[[MAX]], %[[GENERATOR]] :
// CHECK-SAME: !torch.vtensor, !torch.float, !torch.float, !torch.none -> !torch.vtensor // CHECK-SAME: !torch.vtensor, !torch.float, !torch.float, !torch.none -> !torch.vtensor
// CHECK: %[[RET:.*]] = torch.copy.to_tensor %[[VRET]] : !torch.tensor // CHECK: %[[RET:.*]] = torch.copy.to_tensor %[[VRET]] : !torch.tensor
// CHECK: %[[COPY_VTENSOR:.*]] = torch.copy.to_vtensor %[[RET]] : !torch.vtensor // CHECK: %[[COPY_VTENSOR:.*]] = torch.copy.to_vtensor %[[RET]] : !torch.vtensor
@ -150,7 +150,7 @@ func @torch.aten.uniform_(%t: !torch.tensor, %min: !torch.float, %max: !torch.fl
// CHECK: %[[GENERATOR:.*]] = torch.constant.none // CHECK: %[[GENERATOR:.*]] = torch.constant.none
// CHECK: %[[P:.*]] = torch.constant.float 5.000000e-01 // CHECK: %[[P:.*]] = torch.constant.float 5.000000e-01
// CHECK: %[[T_VTENSOR:.*]] = torch.copy.to_vtensor %[[T]] : !torch.vtensor // CHECK: %[[T_VTENSOR:.*]] = torch.copy.to_vtensor %[[T]] : !torch.vtensor
// CHECK: %[[VRET:.*]] = torch.pseudo.aten.bernoulli.float %[[T_VTENSOR]], %[[P]], %[[GENERATOR]] : !torch.vtensor, !torch.float, !torch.none -> !torch.vtensor // CHECK: %[[VRET:.*]] = torch.valsem.aten.bernoulli.float %[[T_VTENSOR]], %[[P]], %[[GENERATOR]] : !torch.vtensor, !torch.float, !torch.none -> !torch.vtensor
// CHECK: %[[RET:.*]] = torch.copy.to_tensor %[[VRET]] : !torch.tensor // CHECK: %[[RET:.*]] = torch.copy.to_tensor %[[VRET]] : !torch.tensor
// CHECK: %[[COPY_VTENSOR:.*]] = torch.copy.to_vtensor %[[RET]] : !torch.vtensor // CHECK: %[[COPY_VTENSOR:.*]] = torch.copy.to_vtensor %[[RET]] : !torch.vtensor
// CHECK: torch.overwrite.tensor.contents %[[COPY_VTENSOR]] overwrites %[[T]] : !torch.vtensor, !torch.tensor // CHECK: torch.overwrite.tensor.contents %[[COPY_VTENSOR]] overwrites %[[T]] : !torch.vtensor, !torch.tensor
@ -166,7 +166,7 @@ func @torch.aten.bernoulli_.float(%t: !torch.tensor) -> !torch.tensor {
// CHECK-SAME: %[[T:.*]]: !torch.tensor) -> !torch.tensor { // CHECK-SAME: %[[T:.*]]: !torch.tensor) -> !torch.tensor {
// CHECK: %[[VALUE:.*]] = torch.constant.int 1 // CHECK: %[[VALUE:.*]] = torch.constant.int 1
// CHECK: %[[T_VTENSOR:.*]] = torch.copy.to_vtensor %[[T]] : !torch.vtensor // CHECK: %[[T_VTENSOR:.*]] = torch.copy.to_vtensor %[[T]] : !torch.vtensor
// CHECK: %[[VRET:.*]] = torch.pseudo.aten.fill.Scalar %[[T_VTENSOR]], %[[VALUE]] : !torch.vtensor, !torch.int -> !torch.vtensor // CHECK: %[[VRET:.*]] = torch.valsem.aten.fill.Scalar %[[T_VTENSOR]], %[[VALUE]] : !torch.vtensor, !torch.int -> !torch.vtensor
// CHECK: %[[RET:.*]] = torch.copy.to_tensor %[[VRET]] : !torch.tensor // CHECK: %[[RET:.*]] = torch.copy.to_tensor %[[VRET]] : !torch.tensor
// CHECK: %[[COPY_VTENSOR:.*]] = torch.copy.to_vtensor %[[RET]] : !torch.vtensor // CHECK: %[[COPY_VTENSOR:.*]] = torch.copy.to_vtensor %[[RET]] : !torch.vtensor
// CHECK: torch.overwrite.tensor.contents %[[COPY_VTENSOR]] overwrites %[[T]] : !torch.vtensor, !torch.tensor // CHECK: torch.overwrite.tensor.contents %[[COPY_VTENSOR]] overwrites %[[T]] : !torch.vtensor, !torch.tensor

View File

@ -24,11 +24,11 @@ func @basic(%arg0: !torch.vtensor) -> !torch.vtensor {
// CHECK: module { // CHECK: module {
// CHECK: func private @__torch_mlir_shape_fn.aten.fill.Scalar( // CHECK: func private @__torch_mlir_shape_fn.aten.fill.Scalar(
// CHECK-LABEL: func @pseudo_ops( // CHECK-LABEL: func @valsem_ops(
// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor, // CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor,
// CHECK-SAME: %[[ARG1:.*]]: !torch.int) -> !torch.vtensor { // CHECK-SAME: %[[ARG1:.*]]: !torch.int) -> !torch.vtensor {
// CHECK: %[[RESULT:.*]] = torch.shape.calculate { // CHECK: %[[RESULT:.*]] = torch.shape.calculate {
// CHECK: %[[VALUE:.*]] = torch.pseudo.aten.fill.Scalar %[[ARG0]], %[[ARG1]] : !torch.vtensor, !torch.int -> !torch.vtensor // CHECK: %[[VALUE:.*]] = torch.valsem.aten.fill.Scalar %[[ARG0]], %[[ARG1]] : !torch.vtensor, !torch.int -> !torch.vtensor
// CHECK: torch.shape.calculate.yield %[[VALUE]] : !torch.vtensor // CHECK: torch.shape.calculate.yield %[[VALUE]] : !torch.vtensor
// CHECK: } shapes { // CHECK: } shapes {
// CHECK: %[[SHAPE:.*]] = torch.aten.size %[[ARG0]] : !torch.vtensor -> !torch.list<int> // CHECK: %[[SHAPE:.*]] = torch.aten.size %[[ARG0]] : !torch.vtensor -> !torch.list<int>
@ -36,8 +36,8 @@ func @basic(%arg0: !torch.vtensor) -> !torch.vtensor {
// CHECK: torch.shape.calculate.yield.shapes %[[RESULT_SHAPE]] : !torch.list<int> // CHECK: torch.shape.calculate.yield.shapes %[[RESULT_SHAPE]] : !torch.list<int>
// CHECK: } : !torch.vtensor // CHECK: } : !torch.vtensor
// CHECK: return %[[RESULT:.*]] : !torch.vtensor // CHECK: return %[[RESULT:.*]] : !torch.vtensor
func @pseudo_ops(%arg0: !torch.vtensor, %arg1: !torch.int) -> !torch.vtensor { func @valsem_ops(%arg0: !torch.vtensor, %arg1: !torch.int) -> !torch.vtensor {
%0 = torch.pseudo.aten.fill.Scalar %arg0, %arg1 : !torch.vtensor, !torch.int -> !torch.vtensor %0 = torch.valsem.aten.fill.Scalar %arg0, %arg1 : !torch.vtensor, !torch.int -> !torch.vtensor
return %0 : !torch.vtensor return %0 : !torch.vtensor
} }
@ -52,7 +52,7 @@ func @pseudo_ops(%arg0: !torch.vtensor, %arg1: !torch.int) -> !torch.vtensor {
// CHECK-SAME: %[[ARG1:.*]]: !torch.float) -> !torch.vtensor { // CHECK-SAME: %[[ARG1:.*]]: !torch.float) -> !torch.vtensor {
// CHECK: %[[NONE:.*]] = torch.constant.none // CHECK: %[[NONE:.*]] = torch.constant.none
// CHECK: %[[RESULT:.*]] = torch.shape.calculate { // CHECK: %[[RESULT:.*]] = torch.shape.calculate {
// CHECK: %[[UNIFORM:.*]] = torch.pseudo.aten.uniform %[[ARG0]], %[[ARG1]], %[[ARG1]], %[[NONE]] : !torch.vtensor, !torch.float, !torch.float, !torch.none -> !torch.vtensor // CHECK: %[[UNIFORM:.*]] = torch.valsem.aten.uniform %[[ARG0]], %[[ARG1]], %[[ARG1]], %[[NONE]] : !torch.vtensor, !torch.float, !torch.float, !torch.none -> !torch.vtensor
// CHECK: torch.shape.calculate.yield %[[UNIFORM]] : !torch.vtensor // CHECK: torch.shape.calculate.yield %[[UNIFORM]] : !torch.vtensor
// CHECK: } shapes { // CHECK: } shapes {
// CHECK: %[[ARG0_SHAPE:.*]] = torch.aten.size %[[ARG0]] : !torch.vtensor -> !torch.list<int> // CHECK: %[[ARG0_SHAPE:.*]] = torch.aten.size %[[ARG0]] : !torch.vtensor -> !torch.list<int>
@ -63,7 +63,7 @@ func @pseudo_ops(%arg0: !torch.vtensor, %arg1: !torch.int) -> !torch.vtensor {
// CHECK: return %[[RESULT:.*]] : !torch.vtensor // CHECK: return %[[RESULT:.*]] : !torch.vtensor
func @adjust_shape_function_arg$torch.any(%arg0: !torch.vtensor, %arg1: !torch.float) -> !torch.vtensor { func @adjust_shape_function_arg$torch.any(%arg0: !torch.vtensor, %arg1: !torch.float) -> !torch.vtensor {
%none = torch.constant.none %none = torch.constant.none
%0 = torch.pseudo.aten.uniform %arg0, %arg1, %arg1, %none : !torch.vtensor, !torch.float, !torch.float, !torch.none -> !torch.vtensor %0 = torch.valsem.aten.uniform %arg0, %arg1, %arg1, %none : !torch.vtensor, !torch.float, !torch.float, !torch.none -> !torch.vtensor
return %0 : !torch.vtensor return %0 : !torch.vtensor
} }