Improve "pseudo" op terminology.

The term "pseudo" is very vague and was getting confusing (I felt I had
to explain it in every comment referencing it). Instead, rework the
"pseudo" ops to instead be named:

- MLIR Syntax: `torch.valsem.*`
- C++ / ODS: `ValsemVariant*Op`

This makes it clear what the concept is, and avoids confusion with other
things that might be called "pseudo", since these are very specific and
should be 100% consistently named w.r.t. the non-valsem-variant ops that
they correspond to.
pull/671/head
Sean Silva 2022-03-15 23:57:33 +00:00
parent 7ea50a537a
commit 92da4988f0
13 changed files with 79 additions and 75 deletions

View File

@ -944,7 +944,7 @@ def Torch_OverwriteTensorContentsOp : Torch_Op<"overwrite.tensor.contents", [
// The corresponding without underscore variant for `torch.aten.uniform_`
// doesn't exist in the pytorch ops registry. Add it here.
def Torch_PseudoAtenUniformOp: Torch_Op<"pseudo.aten.uniform", [
def Torch_ValsemVariantAtenUniformOp: Torch_Op<"valsem.aten.uniform", [
AllowsTypeRefinement,
HasValueSemantics,
ReadOnly
@ -964,7 +964,7 @@ def Torch_PseudoAtenUniformOp: Torch_Op<"pseudo.aten.uniform", [
// The corresponding without underscore variant for `torch.aten.bernoulli_.float`
// doesn't exist in the pytorch ops registry. Add it here.
def Torch_PseudoAtenBernoulliFloatOp: Torch_Op<"pseudo.aten.bernoulli.float", [
def Torch_ValsemVariantAtenBernoulliFloatOp: Torch_Op<"valsem.aten.bernoulli.float", [
AllowsTypeRefinement,
HasValueSemantics,
ReadOnly
@ -983,7 +983,7 @@ def Torch_PseudoAtenBernoulliFloatOp: Torch_Op<"pseudo.aten.bernoulli.float", [
// The corresponding without underscore variant for `torch.aten.bernoulli_.Tensor`
// doesn't exist in the pytorch ops registry. Add it here.
def Torch_PseudoAtenBernoulliTensorOp: Torch_Op<"pseudo.aten.bernoulli.Tensor", [
def Torch_ValsemVariantAtenBernoulliTensorOp: Torch_Op<"valsem.aten.bernoulli.Tensor", [
AllowsTypeRefinement,
HasValueSemantics,
ReadOnly
@ -1002,7 +1002,7 @@ def Torch_PseudoAtenBernoulliTensorOp: Torch_Op<"pseudo.aten.bernoulli.Tensor",
// The corresponding without underscore variant for `torch.aten.fill_.Scalar`
// doesn't exist in the pytorch ops registry. Add it here.
def Torch_PseudoAtenFillScalarOp: Torch_Op<"pseudo.aten.fill.Scalar", [
def Torch_ValsemVariantAtenFillScalarOp: Torch_Op<"valsem.aten.fill.Scalar", [
AllowsTypeRefinement,
HasValueSemantics,
ReadOnly

View File

@ -250,7 +250,7 @@ def Torch_DeviceType : Torch_Type<"Device", "Device"> {
}
def Torch_GeneratorType : Torch_Type<"Generator", "Generator"> {
let summary = "Torch Generator for producing pseudo-random numbers";
let summary = "Torch Generator for producing valsem-random numbers";
}
def Torch_BoolType : Torch_Type<"Bool", "bool"> {

View File

@ -59,12 +59,12 @@ public:
namespace {
class ConvertPseudoAtenUniformOp
: public OpConversionPattern<PseudoAtenUniformOp> {
class ConvertValsemVariantAtenUniformOp
: public OpConversionPattern<ValsemVariantAtenUniformOp> {
public:
using OpConversionPattern::OpConversionPattern;
LogicalResult
matchAndRewrite(PseudoAtenUniformOp op, OpAdaptor adaptor,
matchAndRewrite(ValsemVariantAtenUniformOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
return failure();
@ -162,6 +162,6 @@ void mlir::torch::torch_to_linalg::populateRandomPatternsAndLegality(
MLIRContext *context = patterns.getContext();
target.addIllegalOp<AtenDropoutOp>();
patterns.add<ConvertAtenDropoutOp>(typeConverter, context);
target.addIllegalOp<PseudoAtenUniformOp>();
patterns.add<ConvertPseudoAtenUniformOp>(typeConverter, context);
target.addIllegalOp<ValsemVariantAtenUniformOp>();
patterns.add<ConvertValsemVariantAtenUniformOp>(typeConverter, context);
}

View File

@ -80,12 +80,12 @@ public:
} // namespace
namespace {
class ConvertPseudoAtenFillScalarOp
: public OpConversionPattern<PseudoAtenFillScalarOp> {
class ConvertValsemVariantAtenFillScalarOp
: public OpConversionPattern<ValsemVariantAtenFillScalarOp> {
public:
using OpConversionPattern::OpConversionPattern;
LogicalResult
matchAndRewrite(PseudoAtenFillScalarOp op, OpAdaptor adaptor,
matchAndRewrite(ValsemVariantAtenFillScalarOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
return failure();
@ -318,8 +318,8 @@ void mlir::torch::torch_to_linalg::
MLIRContext *context = patterns.getContext();
target.addIllegalOp<AtenConstantPadNdOp>();
patterns.add<ConvertAtenConstantPadNdOp>(typeConverter, context);
target.addIllegalOp<PseudoAtenFillScalarOp>();
patterns.add<ConvertPseudoAtenFillScalarOp>(typeConverter, context);
target.addIllegalOp<ValsemVariantAtenFillScalarOp>();
patterns.add<ConvertValsemVariantAtenFillScalarOp>(typeConverter, context);
target.addIllegalOp<AtenZerosOp, AtenOnesOp>();
patterns.add<ConvertConstantTensorAllocOp<AtenZerosOp, 0>>(typeConverter,
context);

View File

@ -36,7 +36,7 @@ static Value createElementwiseLinalgGeneric(
// what happens for a single result dimension. This loop not structured that
// way because it is hard to create the affine maps for each operand unless
// we structure the loop to iterate over tensor operands as the outer loop
// instead of inner loop. This pseudocode gives better intuition:
// instead of inner loop. This valsemcode gives better intuition:
// ```
// for each result dimension:
// for each tensor operand:

View File

@ -127,8 +127,8 @@ static Value createInitTensor(PatternRewriter &rewriter, Location loc,
Value emptyTensor = rewriter.create<AtenEmptyMemoryFormatOp>(
loc, tensorType, sizeList, /*dtype=*/noneVal, /*layout=*/noneVal,
/*device=*/noneVal, /*pin_memory=*/noneVal, /*memory_format=*/noneVal);
return rewriter.create<PseudoAtenFillScalarOp>(loc, resultType, emptyTensor,
scalar);
return rewriter.create<ValsemVariantAtenFillScalarOp>(loc, resultType,
emptyTensor, scalar);
}
// Helper to create a rank 0 tensor filled with the given `scalar`. `scalar`
@ -951,7 +951,7 @@ public:
rewriter.create<ConstantFloatOp>(loc, rewriter.getF64FloatAttr(0.0));
Value ub =
rewriter.create<ConstantFloatOp>(loc, rewriter.getF64FloatAttr(1.0));
rewriter.replaceOpWithNewOp<PseudoAtenUniformOp>(
rewriter.replaceOpWithNewOp<ValsemVariantAtenUniformOp>(
op, op.getType(), input, lb, ub, /*generator=*/none);
return success();
}
@ -1028,11 +1028,11 @@ public:
// aten.bernoulli.float(x, p) = (rand_like(float(x)) < tensor(p)).cast(type(x)).
// Since the input x can be an integer tensor, it's important to cast it to
// float type before passing it to the `aten.rand_like` op.
class DecomposePseudoAtenBernoulliFloatOp
: public OpRewritePattern<PseudoAtenBernoulliFloatOp> {
class DecomposeValsemVariantAtenBernoulliFloatOp
: public OpRewritePattern<ValsemVariantAtenBernoulliFloatOp> {
public:
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(PseudoAtenBernoulliFloatOp op,
LogicalResult matchAndRewrite(ValsemVariantAtenBernoulliFloatOp op,
PatternRewriter &rewriter) const override {
Location loc = op.getLoc();
Value input = op.self();
@ -1060,11 +1060,11 @@ public:
// aten.bernoulli.Tensor(x, p) = (rand_like(float(x)) < p).cast(type(x)).
// Since the input x can be an integer tensor, it's important to cast it to
// float type before passing it to the `aten.rand_like` op.
class DecomposePseudoAtenBernoulliTensorOp
: public OpRewritePattern<PseudoAtenBernoulliTensorOp> {
class DecomposeValsemVariantAtenBernoulliTensorOp
: public OpRewritePattern<ValsemVariantAtenBernoulliTensorOp> {
public:
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(PseudoAtenBernoulliTensorOp op,
LogicalResult matchAndRewrite(ValsemVariantAtenBernoulliTensorOp op,
PatternRewriter &rewriter) const override {
Location loc = op.getLoc();
Value input = op.self();
@ -1208,7 +1208,7 @@ class DecomposeConstantTensorAllocLikeOp : public OpRewritePattern<OpTy> {
Value constVal = rewriter.create<Torch::ConstantIntOp>(
loc, rewriter.getI64IntegerAttr(fillVal));
// Initialize the allocated memory block with `fillVal`.
rewriter.replaceOpWithNewOp<PseudoAtenFillScalarOp>(
rewriter.replaceOpWithNewOp<ValsemVariantAtenFillScalarOp>(
op, initTensor.getType(), initTensor, constVal);
return success();
}
@ -1388,7 +1388,7 @@ public:
Value emptyTensor = rewriter.create<AtenEmptyMemoryFormatOp>(
loc, op.getType(), op.size(), op.dtype(), op.layout(), op.device(),
op.pin_memory(), /*memory_format=*/noneVal);
rewriter.replaceOpWithNewOp<PseudoAtenFillScalarOp>(
rewriter.replaceOpWithNewOp<ValsemVariantAtenFillScalarOp>(
op, op.getType(), emptyTensor, op.fill_value());
return success();
}
@ -1405,7 +1405,7 @@ public:
Value emptyTensor = rewriter.create<AtenEmptyLikeOp>(
op.getLoc(), op.getType(), op.self(), op.dtype(), op.layout(),
op.device(), op.pin_memory(), op.memory_format());
rewriter.replaceOpWithNewOp<PseudoAtenFillScalarOp>(
rewriter.replaceOpWithNewOp<ValsemVariantAtenFillScalarOp>(
op, op.getType(), emptyTensor, op.fill_value());
return success();
}
@ -1491,10 +1491,10 @@ class DecomposeComplexOpsPass
target.addIllegalOp<Aten_UnsafeViewOp>();
patterns.add<DecomposeAtenBernoulliOp>(context);
target.addIllegalOp<AtenBernoulliOp>();
patterns.add<DecomposePseudoAtenBernoulliFloatOp>(context);
target.addIllegalOp<PseudoAtenBernoulliFloatOp>();
patterns.add<DecomposePseudoAtenBernoulliTensorOp>(context);
target.addIllegalOp<PseudoAtenBernoulliTensorOp>();
patterns.add<DecomposeValsemVariantAtenBernoulliFloatOp>(context);
target.addIllegalOp<ValsemVariantAtenBernoulliFloatOp>();
patterns.add<DecomposeValsemVariantAtenBernoulliTensorOp>(context);
target.addIllegalOp<ValsemVariantAtenBernoulliTensorOp>();
patterns.add<DecomposeAtenRandLikeOp>(context);
target.addIllegalOp<AtenRandLikeOp>();
patterns.add<DecomposeAtenHardsigmoidOp>(context);

View File

@ -147,17 +147,17 @@ public:
Location loc = op->getLoc();
Operation *newOp;
if (isa<AtenUniform_Op>(op)) {
newOp = rewriter.create<PseudoAtenUniformOp>(loc, op->getResultTypes(),
op->getOperands());
newOp = rewriter.create<ValsemVariantAtenUniformOp>(
loc, op->getResultTypes(), op->getOperands());
} else if (isa<AtenBernoulli_FloatOp>(op)) {
newOp = rewriter.create<PseudoAtenBernoulliFloatOp>(
newOp = rewriter.create<ValsemVariantAtenBernoulliFloatOp>(
loc, op->getResultTypes(), op->getOperands());
} else if (isa<AtenBernoulli_TensorOp>(op)) {
newOp = rewriter.create<PseudoAtenBernoulliTensorOp>(
newOp = rewriter.create<ValsemVariantAtenBernoulliTensorOp>(
loc, op->getResultTypes(), op->getOperands());
} else if (isa<AtenFill_ScalarOp>(op)) {
newOp = rewriter.create<PseudoAtenFillScalarOp>(loc, op->getResultTypes(),
op->getOperands());
newOp = rewriter.create<ValsemVariantAtenFillScalarOp>(
loc, op->getResultTypes(), op->getOperands());
} else {
return failure();
}

View File

@ -505,10 +505,10 @@ ChangeResult TypeAnalyzer::visitOperation(
AtenClampOp, AtenLogOp, AtenNegOp, AtenSqrtOp, AtenFloorOp,
AtenLog2Op, Aten_SoftmaxBackwardDataOp, AtenRsqrtOp, AtenDropoutOp,
AtenTanhBackwardOp, Aten_LogSoftmaxBackwardDataOp, AtenAddIntOp,
AtenAbsOp, AtenThresholdOp, AtenSquareOp, PseudoAtenUniformOp,
AtenAbsOp, AtenThresholdOp, AtenSquareOp, ValsemVariantAtenUniformOp,
AtenBernoulliOp, AtenBernoulli_FloatOp, AtenBernoulli_TensorOp,
PseudoAtenBernoulliFloatOp, PseudoAtenBernoulliTensorOp,
PseudoAtenFillScalarOp, AtenHardsigmoidOp, AtenCloneOp,
ValsemVariantAtenBernoulliFloatOp, ValsemVariantAtenBernoulliTensorOp,
ValsemVariantAtenFillScalarOp, AtenHardsigmoidOp, AtenCloneOp,
AtenHardswishOp, AtenErfOp, AtenSiluOp, AtenHardtanhOp,
AtenMaskedSelectOp, AtenMaxPool2dOp, AtenAdaptiveAvgPool2dOp,
AtenFlattenUsingIntsOp, AtenSqueezeOp, AtenSqueezeDimOp,

View File

@ -232,11 +232,12 @@ class ReifyShapeCalculationsPass
module.walk([&](Operation *op) {
Location loc = op->getLoc();
auto name = op->getName().stripDialect();
// For pseudo-ops (ops that are mechanically consistent with existing
// torch conventions, but simply not present, such as a missing in-place
// or out-of-place variant), remove the pseudo prefix.
if (name.startswith("pseudo."))
name = name.drop_front(strlen("pseudo."));
// For value-semantic variant ops, i.e. valsem-ops (ops that are
// mechanically consistent with existing torch conventions of in-place vs.
// out-of-place (value-semantic) variants), remove the prefix when
// looking them up in the shape library.
if (name.startswith("valsem."))
name = name.drop_front(strlen("valsem."));
auto shapeFunctionName = ("__torch_mlir_shape_fn." + Twine(name)).str();
auto shapeFunction =
shapeLibrary->lookupSymbol<FuncOp>(shapeFunctionName);

View File

@ -263,11 +263,14 @@ def check_shape_function(invocations: List[Invocation]):
def not_present_in_registry(f):
"""Decorator for shape functions not present in the shape registry.
This can happen for "pseudo" ops that we have in Torch-MLIR, such as
torch.aten.fill.Scalar, which are consistent with PyTorch conventions (e.g.
being the value-semantic correspondent of torch.aten.fill_.Scalar), but
that for whatever reason are not present in PyTorch. Such ops are useful
This can happen for "valsem" ops that we have in Torch-MLIR, such as
torch.valsem.aten.fill.Scalar, which are consistent with PyTorch conventions
(e.g. being the value-semantic correspondent of torch.aten.fill_.Scalar),
but that for whatever reason are not present in PyTorch. Such ops are useful
to keep certain passes within Torch-MLIR as consistent as possible.
For such ops, in the shape library generator, we treat them as if they
were registered torch ops (so we don't put "valsem" on them), to keep the
generator consistent.
To check if this decorator has been applied, use
`hasattr(f, "_not_present_in_registry")`.

View File

@ -388,7 +388,7 @@ func @torch.aten._log_softmax(%arg0: !torch.vtensor<[?,?,?],f32> loc(unknown)) -
// CHECK: %[[NONE_2:.*]] = torch.constant.none
// CHECK: %[[FLOAT0:.*]] = torch.constant.float 0.000000e+00
// CHECK: %[[FLOAT1:.*]] = torch.constant.float 1.000000e+00
// CHECK: %[[UNF:.*]] = torch.pseudo.aten.uniform %[[CON2FLOAT]], %[[FLOAT0]], %[[FLOAT1]], %[[NONE_2]] : !torch.vtensor<[?,?,?],f64>, !torch.float, !torch.float, !torch.none -> !torch.vtensor<[?,?,?],f64>
// CHECK: %[[UNF:.*]] = torch.valsem.aten.uniform %[[CON2FLOAT]], %[[FLOAT0]], %[[FLOAT1]], %[[NONE_2]] : !torch.vtensor<[?,?,?],f64>, !torch.float, !torch.float, !torch.none -> !torch.vtensor<[?,?,?],f64>
// CHECK: %[[CMP:.*]] = torch.aten.lt.Tensor %[[UNF]], %[[INP]] : !torch.vtensor<[?,?,?],f64>, !torch.vtensor<[?,?,?],f64> -> !torch.vtensor<[?,?,?],i1>
// CHECK: %[[INT7_2:.*]] = torch.constant.int 7
// CHECK: %[[FALSE_2:.*]] = torch.constant.bool false
@ -405,7 +405,7 @@ func @torch.aten.bernoulli(%arg0: !torch.vtensor<[?,?,?],f64>) -> !torch.vtensor
}
// -----
// CHECK-LABEL: func @torch.pseudo.aten.bernoulli.float
// CHECK-LABEL: func @torch.valsem.aten.bernoulli.float
// CHECK-SAME: (%[[INP:.*]]: !torch.vtensor<[?,?,?],f64>) -> !torch.vtensor {
// CHECK: %[[NONE:.*]] = torch.constant.none
// CHECK: %[[PROB:.*]] = torch.constant.float 4.000000e-01
@ -419,7 +419,7 @@ func @torch.aten.bernoulli(%arg0: !torch.vtensor<[?,?,?],f64>) -> !torch.vtensor
// CHECK: %[[NONE_2:.*]] = torch.constant.none
// CHECK: %[[FLOAT0:.*]] = torch.constant.float 0.000000e+00
// CHECK: %[[FLOAT1:.*]] = torch.constant.float 1.000000e+00
// CHECK: %[[UNF:.*]] = torch.pseudo.aten.uniform %[[CON2FLOAT]], %[[FLOAT0]], %[[FLOAT1]], %[[NONE_2]] : !torch.vtensor<[?,?,?],f64>, !torch.float, !torch.float, !torch.none -> !torch.vtensor<[?,?,?],f64>
// CHECK: %[[UNF:.*]] = torch.valsem.aten.uniform %[[CON2FLOAT]], %[[FLOAT0]], %[[FLOAT1]], %[[NONE_2]] : !torch.vtensor<[?,?,?],f64>, !torch.float, !torch.float, !torch.none -> !torch.vtensor<[?,?,?],f64>
// CHECK: %[[CMP:.*]] = torch.aten.lt.Tensor %[[UNF]], %[[PROB_TENSOR]] : !torch.vtensor<[?,?,?],f64>, !torch.vtensor<[],f64> -> !torch.vtensor<[?,?,?],i1>
// CHECK: %[[INT7_2:.*]] = torch.constant.int 7
// CHECK: %[[FALSE_2:.*]] = torch.constant.bool false
@ -428,16 +428,16 @@ func @torch.aten.bernoulli(%arg0: !torch.vtensor<[?,?,?],f64>) -> !torch.vtensor
// CHECK-SAME: !torch.vtensor<[?,?,?],i1>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[?,?,?],f64>
// CHECK: %[[CAST:.*]] = torch.tensor_static_info_cast %[[TODTYPE]] : !torch.vtensor<[?,?,?],f64> to !torch.vtensor
// CHECK: return %[[CAST]] : !torch.vtensor
func @torch.pseudo.aten.bernoulli.float(%arg0: !torch.vtensor<[?,?,?],f64>) -> !torch.vtensor {
func @torch.valsem.aten.bernoulli.float(%arg0: !torch.vtensor<[?,?,?],f64>) -> !torch.vtensor {
%none = torch.constant.none
%prob = torch.constant.float 4.000000e-01
%0 = torch.pseudo.aten.bernoulli.float %arg0, %prob, %none : !torch.vtensor<[?,?,?],f64>, !torch.float, !torch.none -> !torch.vtensor<[?,?,?],f64>
%0 = torch.valsem.aten.bernoulli.float %arg0, %prob, %none : !torch.vtensor<[?,?,?],f64>, !torch.float, !torch.none -> !torch.vtensor<[?,?,?],f64>
%1 = torch.tensor_static_info_cast %0 : !torch.vtensor<[?,?,?],f64> to !torch.vtensor
return %1 : !torch.vtensor
}
// -----
// CHECK-LABEL: func @torch.pseudo.aten.bernoulli.Tensor(
// CHECK-LABEL: func @torch.valsem.aten.bernoulli.Tensor(
// CHECK-SAME: %[[INP:.*]]: !torch.vtensor<[?,?,?],f64>,
// CHECK-SAME: %[[PROB:.*]]: !torch.vtensor<[?,?,?],f64>) -> !torch.vtensor {
// CHECK: %[[NONE:.*]] = torch.constant.none
@ -450,7 +450,7 @@ func @torch.pseudo.aten.bernoulli.float(%arg0: !torch.vtensor<[?,?,?],f64>) -> !
// CHECK: %[[NONE_2:.*]] = torch.constant.none
// CHECK: %[[FLOAT0:.*]] = torch.constant.float 0.000000e+00
// CHECK: %[[FLOAT1:.*]] = torch.constant.float 1.000000e+00
// CHECK: %[[UNF:.*]] = torch.pseudo.aten.uniform %[[CON2FLOAT]], %[[FLOAT0]], %[[FLOAT1]], %[[NONE_2]] : !torch.vtensor<[?,?,?],f64>, !torch.float, !torch.float, !torch.none -> !torch.vtensor<[?,?,?],f64>
// CHECK: %[[UNF:.*]] = torch.valsem.aten.uniform %[[CON2FLOAT]], %[[FLOAT0]], %[[FLOAT1]], %[[NONE_2]] : !torch.vtensor<[?,?,?],f64>, !torch.float, !torch.float, !torch.none -> !torch.vtensor<[?,?,?],f64>
// CHECK: %[[CMP:.*]] = torch.aten.lt.Tensor %[[UNF]], %[[PROB]] : !torch.vtensor<[?,?,?],f64>, !torch.vtensor<[?,?,?],f64> -> !torch.vtensor<[?,?,?],i1>
// CHECK: %[[INT7_2:.*]] = torch.constant.int 7
// CHECK: %[[FALSE_2:.*]] = torch.constant.bool false
@ -459,9 +459,9 @@ func @torch.pseudo.aten.bernoulli.float(%arg0: !torch.vtensor<[?,?,?],f64>) -> !
// CHECK-SAME: !torch.vtensor<[?,?,?],i1>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[?,?,?],f64>
// CHECK: %[[CAST:.*]] = torch.tensor_static_info_cast %[[TODTYPE]] : !torch.vtensor<[?,?,?],f64> to !torch.vtensor
// CHECK: return %[[CAST]] : !torch.vtensor
func @torch.pseudo.aten.bernoulli.Tensor(%arg0: !torch.vtensor<[?,?,?],f64>, %arg1: !torch.vtensor<[?,?,?],f64>) -> !torch.vtensor {
func @torch.valsem.aten.bernoulli.Tensor(%arg0: !torch.vtensor<[?,?,?],f64>, %arg1: !torch.vtensor<[?,?,?],f64>) -> !torch.vtensor {
%none = torch.constant.none
%0 = torch.pseudo.aten.bernoulli.Tensor %arg0, %arg1, %none : !torch.vtensor<[?,?,?],f64>, !torch.vtensor<[?,?,?],f64>, !torch.none -> !torch.vtensor<[?,?,?],f64>
%0 = torch.valsem.aten.bernoulli.Tensor %arg0, %arg1, %none : !torch.vtensor<[?,?,?],f64>, !torch.vtensor<[?,?,?],f64>, !torch.none -> !torch.vtensor<[?,?,?],f64>
%1 = torch.tensor_static_info_cast %0 : !torch.vtensor<[?,?,?],f64> to !torch.vtensor
return %1 : !torch.vtensor
}
@ -496,13 +496,13 @@ func @torch.aten.select.int(%arg0: !torch.vtensor<[?,?],si64>) -> !torch.vtensor
// CHECK: %[[NONE:.*]] = torch.constant.none
// CHECK: %[[EMPTY:.*]] = torch.aten.empty.memory_format %[[SIZES]], %[[NONE]], %[[NONE]], %[[NONE]], %[[NONE]], %[[NONE]] :
// CHECK-SAME: !torch.list<int>, !torch.none, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[],f32>
// CHECK: %[[CST1_TENSOR:.*]] = torch.pseudo.aten.fill.Scalar %[[EMPTY]], %[[CST1]] : !torch.vtensor<[],f32>, !torch.int -> !torch.vtensor<[],f32>
// CHECK: %[[CST1_TENSOR:.*]] = torch.valsem.aten.fill.Scalar %[[EMPTY]], %[[CST1]] : !torch.vtensor<[],f32>, !torch.int -> !torch.vtensor<[],f32>
// CHECK: %[[MIN:.*]] = torch.aten.minimum %[[CST1_TENSOR]], %[[DIV]] : !torch.vtensor<[],f32>, !torch.vtensor<[?,?],f32> -> !torch.vtensor<[?,?],f32>
// CHECK: %[[SIZES:.*]] = torch.prim.ListConstruct : () -> !torch.list<int>
// CHECK: %[[NONE_1:.*]] = torch.constant.none
// CHECK: %[[EMPTY_1:.*]] = torch.aten.empty.memory_format %[[SIZES]], %[[NONE_1]], %[[NONE_1]], %[[NONE_1]], %[[NONE_1]], %[[NONE_1]] :
// CHECK-SAME: !torch.list<int>, !torch.none, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[],f32>
// CHECK: %[[CST0_TENSOR:.*]] = torch.pseudo.aten.fill.Scalar %[[EMPTY_1]], %[[CST0]] : !torch.vtensor<[],f32>, !torch.int -> !torch.vtensor<[],f32>
// CHECK: %[[CST0_TENSOR:.*]] = torch.valsem.aten.fill.Scalar %[[EMPTY_1]], %[[CST0]] : !torch.vtensor<[],f32>, !torch.int -> !torch.vtensor<[],f32>
// CHECK: %[[RET:.*]] = torch.aten.maximum %[[CST0_TENSOR]], %[[MIN]] : !torch.vtensor<[],f32>, !torch.vtensor<[?,?],f32> -> !torch.vtensor<[?,?],f32>
// CHECK: return %[[RET]] : !torch.vtensor<[?,?],f32>
// CHECK: }
@ -524,7 +524,7 @@ func @torch.aten.hardsigmoid(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor
// CHECK: %[[NONE:.*]] = torch.constant.none
// CHECK: %[[MEM:.*]] = torch.aten.empty.memory_format %[[LIST]], %[[NONE]], %[[NONE]], %[[NONE]], %[[NONE]], %[[NONE]] :
// CHECK-SAME: !torch.list<int>, !torch.none, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[],f32>
// CHECK: %[[FILL:.*]] = torch.pseudo.aten.fill.Scalar %[[MEM]], %[[INT6_]] : !torch.vtensor<[],f32>, !torch.int -> !torch.vtensor<[],f32>
// CHECK: %[[FILL:.*]] = torch.valsem.aten.fill.Scalar %[[MEM]], %[[INT6_]] : !torch.vtensor<[],f32>, !torch.int -> !torch.vtensor<[],f32>
// CHECK: %[[MIN:.*]] = torch.aten.minimum %[[RELU]], %[[FILL]] : !torch.vtensor<[?,?],f32>, !torch.vtensor<[],f32> -> !torch.vtensor<[?,?],f32>
// CHECK: %[[DIV:.*]] = torch.aten.div.Scalar %[[MIN]], %[[INT6]] : !torch.vtensor<[?,?],f32>, !torch.int -> !torch.vtensor<[?,?],f32>
// CHECK: %[[MUL:.*]] = torch.aten.mul.Tensor %[[DIV]], %[[INP]] : !torch.vtensor<[?,?],f32>, !torch.vtensor<[?,?],f32> -> !torch.vtensor<[?,?],f32>
@ -543,13 +543,13 @@ func @torch.aten.hardswish(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[
// CHECK: %[[NONE:.*]] = torch.constant.none
// CHECK: %[[EMPTY:.*]] = torch.aten.empty.memory_format %[[SIZES]], %[[NONE]], %[[NONE]], %[[NONE]], %[[NONE]], %[[NONE]] :
// CHECK-SAME: !torch.list<int>, !torch.none, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[],f32>
// CHECK: %[[MIN_TENSOR:.*]] = torch.pseudo.aten.fill.Scalar %[[EMPTY]], %[[MIN_VAL]] : !torch.vtensor<[],f32>, !torch.float -> !torch.vtensor<[],f32>
// CHECK: %[[MIN_TENSOR:.*]] = torch.valsem.aten.fill.Scalar %[[EMPTY]], %[[MIN_VAL]] : !torch.vtensor<[],f32>, !torch.float -> !torch.vtensor<[],f32>
// CHECK: %[[MIN:.*]] = torch.aten.maximum %[[INPUT]], %[[MIN_TENSOR]] : !torch.vtensor<[?],f32>, !torch.vtensor<[],f32> -> !torch.vtensor<[?],f32>
// CHECK: %[[SIZES:.*]] = torch.prim.ListConstruct : () -> !torch.list<int>
// CHECK: %[[NONE:.*]] = torch.constant.none
// CHECK: %[[VAL_10:.*]] = torch.aten.empty.memory_format %[[SIZES]], %[[NONE]], %[[NONE]], %[[NONE]], %[[NONE]], %[[NONE]] :
// CHECK-SAME: !torch.list<int>, !torch.none, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[],f32>
// CHECK: %[[MAX_TENSOR:.*]] = torch.pseudo.aten.fill.Scalar %[[VAL_10]], %[[MAX_VAL]] : !torch.vtensor<[],f32>, !torch.float -> !torch.vtensor<[],f32>
// CHECK: %[[MAX_TENSOR:.*]] = torch.valsem.aten.fill.Scalar %[[VAL_10]], %[[MAX_VAL]] : !torch.vtensor<[],f32>, !torch.float -> !torch.vtensor<[],f32>
// CHECK: %[[RET:.*]] = torch.aten.minimum %[[MAX_TENSOR]], %[[MIN]] : !torch.vtensor<[],f32>, !torch.vtensor<[?],f32> -> !torch.vtensor<[?],f32>
// CHECK: return %[[RET]] : !torch.vtensor<[?],f32>
func @torch.aten.hardtanh(%arg0: !torch.vtensor<[?],f32>, %min: !torch.float, %max: !torch.float) -> !torch.vtensor<[?],f32> {
@ -616,7 +616,7 @@ func @torch.aten.silu(%arg0: !torch.vtensor<[?,?],f32> loc(unknown)) -> !torch.v
// CHECK: %[[SIZE:.*]] = torch.prim.ListConstruct %[[INT2]], %[[INT3]] : (!torch.int, !torch.int) -> !torch.list<int>
// CHECK: %[[MEM_FORMAT:.*]] = torch.constant.none
// CHECK: %[[EMPTY:.*]] = torch.aten.empty.memory_format %[[SIZE]], %[[NONE]], %[[NONE]], %[[NONE]], %[[NONE]], %[[MEM_FORMAT]] : !torch.list<int>, !torch.none, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[2,3],f32>
// CHECK: %[[RES:.*]] = torch.pseudo.aten.fill.Scalar %[[EMPTY]], %[[FLOAT5]] : !torch.vtensor<[2,3],f32>, !torch.float -> !torch.vtensor<[2,3],f32>
// CHECK: %[[RES:.*]] = torch.valsem.aten.fill.Scalar %[[EMPTY]], %[[FLOAT5]] : !torch.vtensor<[2,3],f32>, !torch.float -> !torch.vtensor<[2,3],f32>
// CHECK: return %[[RES]] : !torch.vtensor<[2,3],f32>
func @torch.aten.full() -> !torch.vtensor<[2,3],f32> {
%float5.000000e00 = torch.constant.float 5.000000e+00
@ -639,7 +639,7 @@ func @torch.aten.full() -> !torch.vtensor<[2,3],f32> {
// CHECK: %[[DIM1:.*]] = torch.aten.size.int %[[INP]], %[[INT1]] : !torch.vtensor<[?,?],f32>, !torch.int -> !torch.int
// CHECK: %[[SIZE:.*]] = torch.prim.ListConstruct %[[DIM0]], %[[DIM1]] : (!torch.int, !torch.int) -> !torch.list<int>
// CHECK: %[[EMPTY:.*]] = torch.aten.empty.memory_format %[[SIZE]], %[[NONE]], %[[NONE]], %[[NONE]], %[[NONE]], %[[NONE]] : !torch.list<int>, !torch.none, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[?,?],f32>
// CHECK: %[[RES:.*]] = torch.pseudo.aten.fill.Scalar %[[EMPTY]], %[[INT5]] : !torch.vtensor<[?,?],f32>, !torch.int -> !torch.vtensor<[?,?],f32>
// CHECK: %[[RES:.*]] = torch.valsem.aten.fill.Scalar %[[EMPTY]], %[[INT5]] : !torch.vtensor<[?,?],f32>, !torch.int -> !torch.vtensor<[?,?],f32>
// CHECK: return %[[RES]] : !torch.vtensor<[?,?],f32>
func @torch.aten.full_like(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
%int5 = torch.constant.int 5

View File

@ -134,7 +134,7 @@ func @convert_to_value_semantic_tensors_optional_list(%self: !torch.tensor<[5],f
// CHECK-SAME: %[[T:.*]]: !torch.tensor, %[[MIN:.*]]: !torch.float, %[[MAX:.*]]: !torch.float,
// CHECK-SAME: %[[GENERATOR:.*]]: !torch.none) -> !torch.tensor {
// CHECK: %[[T_VTENSOR:.*]] = torch.copy.to_vtensor %[[T]] : !torch.vtensor
// CHECK: %[[VRET:.*]] = torch.pseudo.aten.uniform %[[T_VTENSOR]], %[[MIN]], %[[MAX]], %[[GENERATOR]] :
// CHECK: %[[VRET:.*]] = torch.valsem.aten.uniform %[[T_VTENSOR]], %[[MIN]], %[[MAX]], %[[GENERATOR]] :
// CHECK-SAME: !torch.vtensor, !torch.float, !torch.float, !torch.none -> !torch.vtensor
// CHECK: %[[RET:.*]] = torch.copy.to_tensor %[[VRET]] : !torch.tensor
// CHECK: %[[COPY_VTENSOR:.*]] = torch.copy.to_vtensor %[[RET]] : !torch.vtensor
@ -150,7 +150,7 @@ func @torch.aten.uniform_(%t: !torch.tensor, %min: !torch.float, %max: !torch.fl
// CHECK: %[[GENERATOR:.*]] = torch.constant.none
// CHECK: %[[P:.*]] = torch.constant.float 5.000000e-01
// CHECK: %[[T_VTENSOR:.*]] = torch.copy.to_vtensor %[[T]] : !torch.vtensor
// CHECK: %[[VRET:.*]] = torch.pseudo.aten.bernoulli.float %[[T_VTENSOR]], %[[P]], %[[GENERATOR]] : !torch.vtensor, !torch.float, !torch.none -> !torch.vtensor
// CHECK: %[[VRET:.*]] = torch.valsem.aten.bernoulli.float %[[T_VTENSOR]], %[[P]], %[[GENERATOR]] : !torch.vtensor, !torch.float, !torch.none -> !torch.vtensor
// CHECK: %[[RET:.*]] = torch.copy.to_tensor %[[VRET]] : !torch.tensor
// CHECK: %[[COPY_VTENSOR:.*]] = torch.copy.to_vtensor %[[RET]] : !torch.vtensor
// CHECK: torch.overwrite.tensor.contents %[[COPY_VTENSOR]] overwrites %[[T]] : !torch.vtensor, !torch.tensor
@ -166,7 +166,7 @@ func @torch.aten.bernoulli_.float(%t: !torch.tensor) -> !torch.tensor {
// CHECK-SAME: %[[T:.*]]: !torch.tensor) -> !torch.tensor {
// CHECK: %[[VALUE:.*]] = torch.constant.int 1
// CHECK: %[[T_VTENSOR:.*]] = torch.copy.to_vtensor %[[T]] : !torch.vtensor
// CHECK: %[[VRET:.*]] = torch.pseudo.aten.fill.Scalar %[[T_VTENSOR]], %[[VALUE]] : !torch.vtensor, !torch.int -> !torch.vtensor
// CHECK: %[[VRET:.*]] = torch.valsem.aten.fill.Scalar %[[T_VTENSOR]], %[[VALUE]] : !torch.vtensor, !torch.int -> !torch.vtensor
// CHECK: %[[RET:.*]] = torch.copy.to_tensor %[[VRET]] : !torch.tensor
// CHECK: %[[COPY_VTENSOR:.*]] = torch.copy.to_vtensor %[[RET]] : !torch.vtensor
// CHECK: torch.overwrite.tensor.contents %[[COPY_VTENSOR]] overwrites %[[T]] : !torch.vtensor, !torch.tensor

View File

@ -24,11 +24,11 @@ func @basic(%arg0: !torch.vtensor) -> !torch.vtensor {
// CHECK: module {
// CHECK: func private @__torch_mlir_shape_fn.aten.fill.Scalar(
// CHECK-LABEL: func @pseudo_ops(
// CHECK-LABEL: func @valsem_ops(
// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor,
// CHECK-SAME: %[[ARG1:.*]]: !torch.int) -> !torch.vtensor {
// CHECK: %[[RESULT:.*]] = torch.shape.calculate {
// CHECK: %[[VALUE:.*]] = torch.pseudo.aten.fill.Scalar %[[ARG0]], %[[ARG1]] : !torch.vtensor, !torch.int -> !torch.vtensor
// CHECK: %[[VALUE:.*]] = torch.valsem.aten.fill.Scalar %[[ARG0]], %[[ARG1]] : !torch.vtensor, !torch.int -> !torch.vtensor
// CHECK: torch.shape.calculate.yield %[[VALUE]] : !torch.vtensor
// CHECK: } shapes {
// CHECK: %[[SHAPE:.*]] = torch.aten.size %[[ARG0]] : !torch.vtensor -> !torch.list<int>
@ -36,8 +36,8 @@ func @basic(%arg0: !torch.vtensor) -> !torch.vtensor {
// CHECK: torch.shape.calculate.yield.shapes %[[RESULT_SHAPE]] : !torch.list<int>
// CHECK: } : !torch.vtensor
// CHECK: return %[[RESULT:.*]] : !torch.vtensor
func @pseudo_ops(%arg0: !torch.vtensor, %arg1: !torch.int) -> !torch.vtensor {
%0 = torch.pseudo.aten.fill.Scalar %arg0, %arg1 : !torch.vtensor, !torch.int -> !torch.vtensor
func @valsem_ops(%arg0: !torch.vtensor, %arg1: !torch.int) -> !torch.vtensor {
%0 = torch.valsem.aten.fill.Scalar %arg0, %arg1 : !torch.vtensor, !torch.int -> !torch.vtensor
return %0 : !torch.vtensor
}
@ -52,7 +52,7 @@ func @pseudo_ops(%arg0: !torch.vtensor, %arg1: !torch.int) -> !torch.vtensor {
// CHECK-SAME: %[[ARG1:.*]]: !torch.float) -> !torch.vtensor {
// CHECK: %[[NONE:.*]] = torch.constant.none
// CHECK: %[[RESULT:.*]] = torch.shape.calculate {
// CHECK: %[[UNIFORM:.*]] = torch.pseudo.aten.uniform %[[ARG0]], %[[ARG1]], %[[ARG1]], %[[NONE]] : !torch.vtensor, !torch.float, !torch.float, !torch.none -> !torch.vtensor
// CHECK: %[[UNIFORM:.*]] = torch.valsem.aten.uniform %[[ARG0]], %[[ARG1]], %[[ARG1]], %[[NONE]] : !torch.vtensor, !torch.float, !torch.float, !torch.none -> !torch.vtensor
// CHECK: torch.shape.calculate.yield %[[UNIFORM]] : !torch.vtensor
// CHECK: } shapes {
// CHECK: %[[ARG0_SHAPE:.*]] = torch.aten.size %[[ARG0]] : !torch.vtensor -> !torch.list<int>
@ -63,7 +63,7 @@ func @pseudo_ops(%arg0: !torch.vtensor, %arg1: !torch.int) -> !torch.vtensor {
// CHECK: return %[[RESULT:.*]] : !torch.vtensor
func @adjust_shape_function_arg$torch.any(%arg0: !torch.vtensor, %arg1: !torch.float) -> !torch.vtensor {
%none = torch.constant.none
%0 = torch.pseudo.aten.uniform %arg0, %arg1, %arg1, %none : !torch.vtensor, !torch.float, !torch.float, !torch.none -> !torch.vtensor
%0 = torch.valsem.aten.uniform %arg0, %arg1, %arg1, %none : !torch.vtensor, !torch.float, !torch.float, !torch.none -> !torch.vtensor
return %0 : !torch.vtensor
}