mirror of https://github.com/llvm/torch-mlir
Improve "pseudo" op terminology.
The term "pseudo" is very vague and was getting confusing (I felt I had to explain it in every comment referencing it). Instead, rework the "pseudo" ops to instead be named: - MLIR Syntax: `torch.valsem.*` - C++ / ODS: `ValsemVariant*Op` This makes it clear what the concept is, and avoids confusion with other things that might be called "pseudo", since these are very specific and should be 100% consistently named w.r.t. the non-valsem-variant ops that they correspond to.pull/671/head
parent
7ea50a537a
commit
92da4988f0
|
@ -944,7 +944,7 @@ def Torch_OverwriteTensorContentsOp : Torch_Op<"overwrite.tensor.contents", [
|
|||
|
||||
// The corresponding without underscore variant for `torch.aten.uniform_`
|
||||
// doesn't exist in the pytorch ops registry. Add it here.
|
||||
def Torch_PseudoAtenUniformOp: Torch_Op<"pseudo.aten.uniform", [
|
||||
def Torch_ValsemVariantAtenUniformOp: Torch_Op<"valsem.aten.uniform", [
|
||||
AllowsTypeRefinement,
|
||||
HasValueSemantics,
|
||||
ReadOnly
|
||||
|
@ -964,7 +964,7 @@ def Torch_PseudoAtenUniformOp: Torch_Op<"pseudo.aten.uniform", [
|
|||
|
||||
// The corresponding without underscore variant for `torch.aten.bernoulli_.float`
|
||||
// doesn't exist in the pytorch ops registry. Add it here.
|
||||
def Torch_PseudoAtenBernoulliFloatOp: Torch_Op<"pseudo.aten.bernoulli.float", [
|
||||
def Torch_ValsemVariantAtenBernoulliFloatOp: Torch_Op<"valsem.aten.bernoulli.float", [
|
||||
AllowsTypeRefinement,
|
||||
HasValueSemantics,
|
||||
ReadOnly
|
||||
|
@ -983,7 +983,7 @@ def Torch_PseudoAtenBernoulliFloatOp: Torch_Op<"pseudo.aten.bernoulli.float", [
|
|||
|
||||
// The corresponding without underscore variant for `torch.aten.bernoulli_.Tensor`
|
||||
// doesn't exist in the pytorch ops registry. Add it here.
|
||||
def Torch_PseudoAtenBernoulliTensorOp: Torch_Op<"pseudo.aten.bernoulli.Tensor", [
|
||||
def Torch_ValsemVariantAtenBernoulliTensorOp: Torch_Op<"valsem.aten.bernoulli.Tensor", [
|
||||
AllowsTypeRefinement,
|
||||
HasValueSemantics,
|
||||
ReadOnly
|
||||
|
@ -1002,7 +1002,7 @@ def Torch_PseudoAtenBernoulliTensorOp: Torch_Op<"pseudo.aten.bernoulli.Tensor",
|
|||
|
||||
// The corresponding without underscore variant for `torch.aten.fill_.Scalar`
|
||||
// doesn't exist in the pytorch ops registry. Add it here.
|
||||
def Torch_PseudoAtenFillScalarOp: Torch_Op<"pseudo.aten.fill.Scalar", [
|
||||
def Torch_ValsemVariantAtenFillScalarOp: Torch_Op<"valsem.aten.fill.Scalar", [
|
||||
AllowsTypeRefinement,
|
||||
HasValueSemantics,
|
||||
ReadOnly
|
||||
|
|
|
@ -250,7 +250,7 @@ def Torch_DeviceType : Torch_Type<"Device", "Device"> {
|
|||
}
|
||||
|
||||
def Torch_GeneratorType : Torch_Type<"Generator", "Generator"> {
|
||||
let summary = "Torch Generator for producing pseudo-random numbers";
|
||||
let summary = "Torch Generator for producing valsem-random numbers";
|
||||
}
|
||||
|
||||
def Torch_BoolType : Torch_Type<"Bool", "bool"> {
|
||||
|
|
|
@ -59,12 +59,12 @@ public:
|
|||
|
||||
|
||||
namespace {
|
||||
class ConvertPseudoAtenUniformOp
|
||||
: public OpConversionPattern<PseudoAtenUniformOp> {
|
||||
class ConvertValsemVariantAtenUniformOp
|
||||
: public OpConversionPattern<ValsemVariantAtenUniformOp> {
|
||||
public:
|
||||
using OpConversionPattern::OpConversionPattern;
|
||||
LogicalResult
|
||||
matchAndRewrite(PseudoAtenUniformOp op, OpAdaptor adaptor,
|
||||
matchAndRewrite(ValsemVariantAtenUniformOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
|
||||
return failure();
|
||||
|
@ -162,6 +162,6 @@ void mlir::torch::torch_to_linalg::populateRandomPatternsAndLegality(
|
|||
MLIRContext *context = patterns.getContext();
|
||||
target.addIllegalOp<AtenDropoutOp>();
|
||||
patterns.add<ConvertAtenDropoutOp>(typeConverter, context);
|
||||
target.addIllegalOp<PseudoAtenUniformOp>();
|
||||
patterns.add<ConvertPseudoAtenUniformOp>(typeConverter, context);
|
||||
target.addIllegalOp<ValsemVariantAtenUniformOp>();
|
||||
patterns.add<ConvertValsemVariantAtenUniformOp>(typeConverter, context);
|
||||
}
|
||||
|
|
|
@ -80,12 +80,12 @@ public:
|
|||
} // namespace
|
||||
|
||||
namespace {
|
||||
class ConvertPseudoAtenFillScalarOp
|
||||
: public OpConversionPattern<PseudoAtenFillScalarOp> {
|
||||
class ConvertValsemVariantAtenFillScalarOp
|
||||
: public OpConversionPattern<ValsemVariantAtenFillScalarOp> {
|
||||
public:
|
||||
using OpConversionPattern::OpConversionPattern;
|
||||
LogicalResult
|
||||
matchAndRewrite(PseudoAtenFillScalarOp op, OpAdaptor adaptor,
|
||||
matchAndRewrite(ValsemVariantAtenFillScalarOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
|
||||
return failure();
|
||||
|
@ -318,8 +318,8 @@ void mlir::torch::torch_to_linalg::
|
|||
MLIRContext *context = patterns.getContext();
|
||||
target.addIllegalOp<AtenConstantPadNdOp>();
|
||||
patterns.add<ConvertAtenConstantPadNdOp>(typeConverter, context);
|
||||
target.addIllegalOp<PseudoAtenFillScalarOp>();
|
||||
patterns.add<ConvertPseudoAtenFillScalarOp>(typeConverter, context);
|
||||
target.addIllegalOp<ValsemVariantAtenFillScalarOp>();
|
||||
patterns.add<ConvertValsemVariantAtenFillScalarOp>(typeConverter, context);
|
||||
target.addIllegalOp<AtenZerosOp, AtenOnesOp>();
|
||||
patterns.add<ConvertConstantTensorAllocOp<AtenZerosOp, 0>>(typeConverter,
|
||||
context);
|
||||
|
|
|
@ -36,7 +36,7 @@ static Value createElementwiseLinalgGeneric(
|
|||
// what happens for a single result dimension. This loop not structured that
|
||||
// way because it is hard to create the affine maps for each operand unless
|
||||
// we structure the loop to iterate over tensor operands as the outer loop
|
||||
// instead of inner loop. This pseudocode gives better intuition:
|
||||
// instead of inner loop. This valsemcode gives better intuition:
|
||||
// ```
|
||||
// for each result dimension:
|
||||
// for each tensor operand:
|
||||
|
|
|
@ -127,8 +127,8 @@ static Value createInitTensor(PatternRewriter &rewriter, Location loc,
|
|||
Value emptyTensor = rewriter.create<AtenEmptyMemoryFormatOp>(
|
||||
loc, tensorType, sizeList, /*dtype=*/noneVal, /*layout=*/noneVal,
|
||||
/*device=*/noneVal, /*pin_memory=*/noneVal, /*memory_format=*/noneVal);
|
||||
return rewriter.create<PseudoAtenFillScalarOp>(loc, resultType, emptyTensor,
|
||||
scalar);
|
||||
return rewriter.create<ValsemVariantAtenFillScalarOp>(loc, resultType,
|
||||
emptyTensor, scalar);
|
||||
}
|
||||
|
||||
// Helper to create a rank 0 tensor filled with the given `scalar`. `scalar`
|
||||
|
@ -951,7 +951,7 @@ public:
|
|||
rewriter.create<ConstantFloatOp>(loc, rewriter.getF64FloatAttr(0.0));
|
||||
Value ub =
|
||||
rewriter.create<ConstantFloatOp>(loc, rewriter.getF64FloatAttr(1.0));
|
||||
rewriter.replaceOpWithNewOp<PseudoAtenUniformOp>(
|
||||
rewriter.replaceOpWithNewOp<ValsemVariantAtenUniformOp>(
|
||||
op, op.getType(), input, lb, ub, /*generator=*/none);
|
||||
return success();
|
||||
}
|
||||
|
@ -1028,11 +1028,11 @@ public:
|
|||
// aten.bernoulli.float(x, p) = (rand_like(float(x)) < tensor(p)).cast(type(x)).
|
||||
// Since the input x can be an integer tensor, it's important to cast it to
|
||||
// float type before passing it to the `aten.rand_like` op.
|
||||
class DecomposePseudoAtenBernoulliFloatOp
|
||||
: public OpRewritePattern<PseudoAtenBernoulliFloatOp> {
|
||||
class DecomposeValsemVariantAtenBernoulliFloatOp
|
||||
: public OpRewritePattern<ValsemVariantAtenBernoulliFloatOp> {
|
||||
public:
|
||||
using OpRewritePattern::OpRewritePattern;
|
||||
LogicalResult matchAndRewrite(PseudoAtenBernoulliFloatOp op,
|
||||
LogicalResult matchAndRewrite(ValsemVariantAtenBernoulliFloatOp op,
|
||||
PatternRewriter &rewriter) const override {
|
||||
Location loc = op.getLoc();
|
||||
Value input = op.self();
|
||||
|
@ -1060,11 +1060,11 @@ public:
|
|||
// aten.bernoulli.Tensor(x, p) = (rand_like(float(x)) < p).cast(type(x)).
|
||||
// Since the input x can be an integer tensor, it's important to cast it to
|
||||
// float type before passing it to the `aten.rand_like` op.
|
||||
class DecomposePseudoAtenBernoulliTensorOp
|
||||
: public OpRewritePattern<PseudoAtenBernoulliTensorOp> {
|
||||
class DecomposeValsemVariantAtenBernoulliTensorOp
|
||||
: public OpRewritePattern<ValsemVariantAtenBernoulliTensorOp> {
|
||||
public:
|
||||
using OpRewritePattern::OpRewritePattern;
|
||||
LogicalResult matchAndRewrite(PseudoAtenBernoulliTensorOp op,
|
||||
LogicalResult matchAndRewrite(ValsemVariantAtenBernoulliTensorOp op,
|
||||
PatternRewriter &rewriter) const override {
|
||||
Location loc = op.getLoc();
|
||||
Value input = op.self();
|
||||
|
@ -1208,7 +1208,7 @@ class DecomposeConstantTensorAllocLikeOp : public OpRewritePattern<OpTy> {
|
|||
Value constVal = rewriter.create<Torch::ConstantIntOp>(
|
||||
loc, rewriter.getI64IntegerAttr(fillVal));
|
||||
// Initialize the allocated memory block with `fillVal`.
|
||||
rewriter.replaceOpWithNewOp<PseudoAtenFillScalarOp>(
|
||||
rewriter.replaceOpWithNewOp<ValsemVariantAtenFillScalarOp>(
|
||||
op, initTensor.getType(), initTensor, constVal);
|
||||
return success();
|
||||
}
|
||||
|
@ -1388,7 +1388,7 @@ public:
|
|||
Value emptyTensor = rewriter.create<AtenEmptyMemoryFormatOp>(
|
||||
loc, op.getType(), op.size(), op.dtype(), op.layout(), op.device(),
|
||||
op.pin_memory(), /*memory_format=*/noneVal);
|
||||
rewriter.replaceOpWithNewOp<PseudoAtenFillScalarOp>(
|
||||
rewriter.replaceOpWithNewOp<ValsemVariantAtenFillScalarOp>(
|
||||
op, op.getType(), emptyTensor, op.fill_value());
|
||||
return success();
|
||||
}
|
||||
|
@ -1405,7 +1405,7 @@ public:
|
|||
Value emptyTensor = rewriter.create<AtenEmptyLikeOp>(
|
||||
op.getLoc(), op.getType(), op.self(), op.dtype(), op.layout(),
|
||||
op.device(), op.pin_memory(), op.memory_format());
|
||||
rewriter.replaceOpWithNewOp<PseudoAtenFillScalarOp>(
|
||||
rewriter.replaceOpWithNewOp<ValsemVariantAtenFillScalarOp>(
|
||||
op, op.getType(), emptyTensor, op.fill_value());
|
||||
return success();
|
||||
}
|
||||
|
@ -1491,10 +1491,10 @@ class DecomposeComplexOpsPass
|
|||
target.addIllegalOp<Aten_UnsafeViewOp>();
|
||||
patterns.add<DecomposeAtenBernoulliOp>(context);
|
||||
target.addIllegalOp<AtenBernoulliOp>();
|
||||
patterns.add<DecomposePseudoAtenBernoulliFloatOp>(context);
|
||||
target.addIllegalOp<PseudoAtenBernoulliFloatOp>();
|
||||
patterns.add<DecomposePseudoAtenBernoulliTensorOp>(context);
|
||||
target.addIllegalOp<PseudoAtenBernoulliTensorOp>();
|
||||
patterns.add<DecomposeValsemVariantAtenBernoulliFloatOp>(context);
|
||||
target.addIllegalOp<ValsemVariantAtenBernoulliFloatOp>();
|
||||
patterns.add<DecomposeValsemVariantAtenBernoulliTensorOp>(context);
|
||||
target.addIllegalOp<ValsemVariantAtenBernoulliTensorOp>();
|
||||
patterns.add<DecomposeAtenRandLikeOp>(context);
|
||||
target.addIllegalOp<AtenRandLikeOp>();
|
||||
patterns.add<DecomposeAtenHardsigmoidOp>(context);
|
||||
|
|
|
@ -147,17 +147,17 @@ public:
|
|||
Location loc = op->getLoc();
|
||||
Operation *newOp;
|
||||
if (isa<AtenUniform_Op>(op)) {
|
||||
newOp = rewriter.create<PseudoAtenUniformOp>(loc, op->getResultTypes(),
|
||||
op->getOperands());
|
||||
newOp = rewriter.create<ValsemVariantAtenUniformOp>(
|
||||
loc, op->getResultTypes(), op->getOperands());
|
||||
} else if (isa<AtenBernoulli_FloatOp>(op)) {
|
||||
newOp = rewriter.create<PseudoAtenBernoulliFloatOp>(
|
||||
newOp = rewriter.create<ValsemVariantAtenBernoulliFloatOp>(
|
||||
loc, op->getResultTypes(), op->getOperands());
|
||||
} else if (isa<AtenBernoulli_TensorOp>(op)) {
|
||||
newOp = rewriter.create<PseudoAtenBernoulliTensorOp>(
|
||||
newOp = rewriter.create<ValsemVariantAtenBernoulliTensorOp>(
|
||||
loc, op->getResultTypes(), op->getOperands());
|
||||
} else if (isa<AtenFill_ScalarOp>(op)) {
|
||||
newOp = rewriter.create<PseudoAtenFillScalarOp>(loc, op->getResultTypes(),
|
||||
op->getOperands());
|
||||
newOp = rewriter.create<ValsemVariantAtenFillScalarOp>(
|
||||
loc, op->getResultTypes(), op->getOperands());
|
||||
} else {
|
||||
return failure();
|
||||
}
|
||||
|
|
|
@ -505,10 +505,10 @@ ChangeResult TypeAnalyzer::visitOperation(
|
|||
AtenClampOp, AtenLogOp, AtenNegOp, AtenSqrtOp, AtenFloorOp,
|
||||
AtenLog2Op, Aten_SoftmaxBackwardDataOp, AtenRsqrtOp, AtenDropoutOp,
|
||||
AtenTanhBackwardOp, Aten_LogSoftmaxBackwardDataOp, AtenAddIntOp,
|
||||
AtenAbsOp, AtenThresholdOp, AtenSquareOp, PseudoAtenUniformOp,
|
||||
AtenAbsOp, AtenThresholdOp, AtenSquareOp, ValsemVariantAtenUniformOp,
|
||||
AtenBernoulliOp, AtenBernoulli_FloatOp, AtenBernoulli_TensorOp,
|
||||
PseudoAtenBernoulliFloatOp, PseudoAtenBernoulliTensorOp,
|
||||
PseudoAtenFillScalarOp, AtenHardsigmoidOp, AtenCloneOp,
|
||||
ValsemVariantAtenBernoulliFloatOp, ValsemVariantAtenBernoulliTensorOp,
|
||||
ValsemVariantAtenFillScalarOp, AtenHardsigmoidOp, AtenCloneOp,
|
||||
AtenHardswishOp, AtenErfOp, AtenSiluOp, AtenHardtanhOp,
|
||||
AtenMaskedSelectOp, AtenMaxPool2dOp, AtenAdaptiveAvgPool2dOp,
|
||||
AtenFlattenUsingIntsOp, AtenSqueezeOp, AtenSqueezeDimOp,
|
||||
|
|
|
@ -232,11 +232,12 @@ class ReifyShapeCalculationsPass
|
|||
module.walk([&](Operation *op) {
|
||||
Location loc = op->getLoc();
|
||||
auto name = op->getName().stripDialect();
|
||||
// For pseudo-ops (ops that are mechanically consistent with existing
|
||||
// torch conventions, but simply not present, such as a missing in-place
|
||||
// or out-of-place variant), remove the pseudo prefix.
|
||||
if (name.startswith("pseudo."))
|
||||
name = name.drop_front(strlen("pseudo."));
|
||||
// For value-semantic variant ops, i.e. valsem-ops (ops that are
|
||||
// mechanically consistent with existing torch conventions of in-place vs.
|
||||
// out-of-place (value-semantic) variants), remove the prefix when
|
||||
// looking them up in the shape library.
|
||||
if (name.startswith("valsem."))
|
||||
name = name.drop_front(strlen("valsem."));
|
||||
auto shapeFunctionName = ("__torch_mlir_shape_fn." + Twine(name)).str();
|
||||
auto shapeFunction =
|
||||
shapeLibrary->lookupSymbol<FuncOp>(shapeFunctionName);
|
||||
|
|
|
@ -263,11 +263,14 @@ def check_shape_function(invocations: List[Invocation]):
|
|||
def not_present_in_registry(f):
|
||||
"""Decorator for shape functions not present in the shape registry.
|
||||
|
||||
This can happen for "pseudo" ops that we have in Torch-MLIR, such as
|
||||
torch.aten.fill.Scalar, which are consistent with PyTorch conventions (e.g.
|
||||
being the value-semantic correspondent of torch.aten.fill_.Scalar), but
|
||||
that for whatever reason are not present in PyTorch. Such ops are useful
|
||||
This can happen for "valsem" ops that we have in Torch-MLIR, such as
|
||||
torch.valsem.aten.fill.Scalar, which are consistent with PyTorch conventions
|
||||
(e.g. being the value-semantic correspondent of torch.aten.fill_.Scalar),
|
||||
but that for whatever reason are not present in PyTorch. Such ops are useful
|
||||
to keep certain passes within Torch-MLIR as consistent as possible.
|
||||
For such ops, in the shape library generator, we treat them as if they
|
||||
were registered torch ops (so we don't put "valsem" on them), to keep the
|
||||
generator consistent.
|
||||
|
||||
To check if this decorator has been applied, use
|
||||
`hasattr(f, "_not_present_in_registry")`.
|
||||
|
|
|
@ -388,7 +388,7 @@ func @torch.aten._log_softmax(%arg0: !torch.vtensor<[?,?,?],f32> loc(unknown)) -
|
|||
// CHECK: %[[NONE_2:.*]] = torch.constant.none
|
||||
// CHECK: %[[FLOAT0:.*]] = torch.constant.float 0.000000e+00
|
||||
// CHECK: %[[FLOAT1:.*]] = torch.constant.float 1.000000e+00
|
||||
// CHECK: %[[UNF:.*]] = torch.pseudo.aten.uniform %[[CON2FLOAT]], %[[FLOAT0]], %[[FLOAT1]], %[[NONE_2]] : !torch.vtensor<[?,?,?],f64>, !torch.float, !torch.float, !torch.none -> !torch.vtensor<[?,?,?],f64>
|
||||
// CHECK: %[[UNF:.*]] = torch.valsem.aten.uniform %[[CON2FLOAT]], %[[FLOAT0]], %[[FLOAT1]], %[[NONE_2]] : !torch.vtensor<[?,?,?],f64>, !torch.float, !torch.float, !torch.none -> !torch.vtensor<[?,?,?],f64>
|
||||
// CHECK: %[[CMP:.*]] = torch.aten.lt.Tensor %[[UNF]], %[[INP]] : !torch.vtensor<[?,?,?],f64>, !torch.vtensor<[?,?,?],f64> -> !torch.vtensor<[?,?,?],i1>
|
||||
// CHECK: %[[INT7_2:.*]] = torch.constant.int 7
|
||||
// CHECK: %[[FALSE_2:.*]] = torch.constant.bool false
|
||||
|
@ -405,7 +405,7 @@ func @torch.aten.bernoulli(%arg0: !torch.vtensor<[?,?,?],f64>) -> !torch.vtensor
|
|||
}
|
||||
|
||||
// -----
|
||||
// CHECK-LABEL: func @torch.pseudo.aten.bernoulli.float
|
||||
// CHECK-LABEL: func @torch.valsem.aten.bernoulli.float
|
||||
// CHECK-SAME: (%[[INP:.*]]: !torch.vtensor<[?,?,?],f64>) -> !torch.vtensor {
|
||||
// CHECK: %[[NONE:.*]] = torch.constant.none
|
||||
// CHECK: %[[PROB:.*]] = torch.constant.float 4.000000e-01
|
||||
|
@ -419,7 +419,7 @@ func @torch.aten.bernoulli(%arg0: !torch.vtensor<[?,?,?],f64>) -> !torch.vtensor
|
|||
// CHECK: %[[NONE_2:.*]] = torch.constant.none
|
||||
// CHECK: %[[FLOAT0:.*]] = torch.constant.float 0.000000e+00
|
||||
// CHECK: %[[FLOAT1:.*]] = torch.constant.float 1.000000e+00
|
||||
// CHECK: %[[UNF:.*]] = torch.pseudo.aten.uniform %[[CON2FLOAT]], %[[FLOAT0]], %[[FLOAT1]], %[[NONE_2]] : !torch.vtensor<[?,?,?],f64>, !torch.float, !torch.float, !torch.none -> !torch.vtensor<[?,?,?],f64>
|
||||
// CHECK: %[[UNF:.*]] = torch.valsem.aten.uniform %[[CON2FLOAT]], %[[FLOAT0]], %[[FLOAT1]], %[[NONE_2]] : !torch.vtensor<[?,?,?],f64>, !torch.float, !torch.float, !torch.none -> !torch.vtensor<[?,?,?],f64>
|
||||
// CHECK: %[[CMP:.*]] = torch.aten.lt.Tensor %[[UNF]], %[[PROB_TENSOR]] : !torch.vtensor<[?,?,?],f64>, !torch.vtensor<[],f64> -> !torch.vtensor<[?,?,?],i1>
|
||||
// CHECK: %[[INT7_2:.*]] = torch.constant.int 7
|
||||
// CHECK: %[[FALSE_2:.*]] = torch.constant.bool false
|
||||
|
@ -428,16 +428,16 @@ func @torch.aten.bernoulli(%arg0: !torch.vtensor<[?,?,?],f64>) -> !torch.vtensor
|
|||
// CHECK-SAME: !torch.vtensor<[?,?,?],i1>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[?,?,?],f64>
|
||||
// CHECK: %[[CAST:.*]] = torch.tensor_static_info_cast %[[TODTYPE]] : !torch.vtensor<[?,?,?],f64> to !torch.vtensor
|
||||
// CHECK: return %[[CAST]] : !torch.vtensor
|
||||
func @torch.pseudo.aten.bernoulli.float(%arg0: !torch.vtensor<[?,?,?],f64>) -> !torch.vtensor {
|
||||
func @torch.valsem.aten.bernoulli.float(%arg0: !torch.vtensor<[?,?,?],f64>) -> !torch.vtensor {
|
||||
%none = torch.constant.none
|
||||
%prob = torch.constant.float 4.000000e-01
|
||||
%0 = torch.pseudo.aten.bernoulli.float %arg0, %prob, %none : !torch.vtensor<[?,?,?],f64>, !torch.float, !torch.none -> !torch.vtensor<[?,?,?],f64>
|
||||
%0 = torch.valsem.aten.bernoulli.float %arg0, %prob, %none : !torch.vtensor<[?,?,?],f64>, !torch.float, !torch.none -> !torch.vtensor<[?,?,?],f64>
|
||||
%1 = torch.tensor_static_info_cast %0 : !torch.vtensor<[?,?,?],f64> to !torch.vtensor
|
||||
return %1 : !torch.vtensor
|
||||
}
|
||||
|
||||
// -----
|
||||
// CHECK-LABEL: func @torch.pseudo.aten.bernoulli.Tensor(
|
||||
// CHECK-LABEL: func @torch.valsem.aten.bernoulli.Tensor(
|
||||
// CHECK-SAME: %[[INP:.*]]: !torch.vtensor<[?,?,?],f64>,
|
||||
// CHECK-SAME: %[[PROB:.*]]: !torch.vtensor<[?,?,?],f64>) -> !torch.vtensor {
|
||||
// CHECK: %[[NONE:.*]] = torch.constant.none
|
||||
|
@ -450,7 +450,7 @@ func @torch.pseudo.aten.bernoulli.float(%arg0: !torch.vtensor<[?,?,?],f64>) -> !
|
|||
// CHECK: %[[NONE_2:.*]] = torch.constant.none
|
||||
// CHECK: %[[FLOAT0:.*]] = torch.constant.float 0.000000e+00
|
||||
// CHECK: %[[FLOAT1:.*]] = torch.constant.float 1.000000e+00
|
||||
// CHECK: %[[UNF:.*]] = torch.pseudo.aten.uniform %[[CON2FLOAT]], %[[FLOAT0]], %[[FLOAT1]], %[[NONE_2]] : !torch.vtensor<[?,?,?],f64>, !torch.float, !torch.float, !torch.none -> !torch.vtensor<[?,?,?],f64>
|
||||
// CHECK: %[[UNF:.*]] = torch.valsem.aten.uniform %[[CON2FLOAT]], %[[FLOAT0]], %[[FLOAT1]], %[[NONE_2]] : !torch.vtensor<[?,?,?],f64>, !torch.float, !torch.float, !torch.none -> !torch.vtensor<[?,?,?],f64>
|
||||
// CHECK: %[[CMP:.*]] = torch.aten.lt.Tensor %[[UNF]], %[[PROB]] : !torch.vtensor<[?,?,?],f64>, !torch.vtensor<[?,?,?],f64> -> !torch.vtensor<[?,?,?],i1>
|
||||
// CHECK: %[[INT7_2:.*]] = torch.constant.int 7
|
||||
// CHECK: %[[FALSE_2:.*]] = torch.constant.bool false
|
||||
|
@ -459,9 +459,9 @@ func @torch.pseudo.aten.bernoulli.float(%arg0: !torch.vtensor<[?,?,?],f64>) -> !
|
|||
// CHECK-SAME: !torch.vtensor<[?,?,?],i1>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[?,?,?],f64>
|
||||
// CHECK: %[[CAST:.*]] = torch.tensor_static_info_cast %[[TODTYPE]] : !torch.vtensor<[?,?,?],f64> to !torch.vtensor
|
||||
// CHECK: return %[[CAST]] : !torch.vtensor
|
||||
func @torch.pseudo.aten.bernoulli.Tensor(%arg0: !torch.vtensor<[?,?,?],f64>, %arg1: !torch.vtensor<[?,?,?],f64>) -> !torch.vtensor {
|
||||
func @torch.valsem.aten.bernoulli.Tensor(%arg0: !torch.vtensor<[?,?,?],f64>, %arg1: !torch.vtensor<[?,?,?],f64>) -> !torch.vtensor {
|
||||
%none = torch.constant.none
|
||||
%0 = torch.pseudo.aten.bernoulli.Tensor %arg0, %arg1, %none : !torch.vtensor<[?,?,?],f64>, !torch.vtensor<[?,?,?],f64>, !torch.none -> !torch.vtensor<[?,?,?],f64>
|
||||
%0 = torch.valsem.aten.bernoulli.Tensor %arg0, %arg1, %none : !torch.vtensor<[?,?,?],f64>, !torch.vtensor<[?,?,?],f64>, !torch.none -> !torch.vtensor<[?,?,?],f64>
|
||||
%1 = torch.tensor_static_info_cast %0 : !torch.vtensor<[?,?,?],f64> to !torch.vtensor
|
||||
return %1 : !torch.vtensor
|
||||
}
|
||||
|
@ -496,13 +496,13 @@ func @torch.aten.select.int(%arg0: !torch.vtensor<[?,?],si64>) -> !torch.vtensor
|
|||
// CHECK: %[[NONE:.*]] = torch.constant.none
|
||||
// CHECK: %[[EMPTY:.*]] = torch.aten.empty.memory_format %[[SIZES]], %[[NONE]], %[[NONE]], %[[NONE]], %[[NONE]], %[[NONE]] :
|
||||
// CHECK-SAME: !torch.list<int>, !torch.none, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[],f32>
|
||||
// CHECK: %[[CST1_TENSOR:.*]] = torch.pseudo.aten.fill.Scalar %[[EMPTY]], %[[CST1]] : !torch.vtensor<[],f32>, !torch.int -> !torch.vtensor<[],f32>
|
||||
// CHECK: %[[CST1_TENSOR:.*]] = torch.valsem.aten.fill.Scalar %[[EMPTY]], %[[CST1]] : !torch.vtensor<[],f32>, !torch.int -> !torch.vtensor<[],f32>
|
||||
// CHECK: %[[MIN:.*]] = torch.aten.minimum %[[CST1_TENSOR]], %[[DIV]] : !torch.vtensor<[],f32>, !torch.vtensor<[?,?],f32> -> !torch.vtensor<[?,?],f32>
|
||||
// CHECK: %[[SIZES:.*]] = torch.prim.ListConstruct : () -> !torch.list<int>
|
||||
// CHECK: %[[NONE_1:.*]] = torch.constant.none
|
||||
// CHECK: %[[EMPTY_1:.*]] = torch.aten.empty.memory_format %[[SIZES]], %[[NONE_1]], %[[NONE_1]], %[[NONE_1]], %[[NONE_1]], %[[NONE_1]] :
|
||||
// CHECK-SAME: !torch.list<int>, !torch.none, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[],f32>
|
||||
// CHECK: %[[CST0_TENSOR:.*]] = torch.pseudo.aten.fill.Scalar %[[EMPTY_1]], %[[CST0]] : !torch.vtensor<[],f32>, !torch.int -> !torch.vtensor<[],f32>
|
||||
// CHECK: %[[CST0_TENSOR:.*]] = torch.valsem.aten.fill.Scalar %[[EMPTY_1]], %[[CST0]] : !torch.vtensor<[],f32>, !torch.int -> !torch.vtensor<[],f32>
|
||||
// CHECK: %[[RET:.*]] = torch.aten.maximum %[[CST0_TENSOR]], %[[MIN]] : !torch.vtensor<[],f32>, !torch.vtensor<[?,?],f32> -> !torch.vtensor<[?,?],f32>
|
||||
// CHECK: return %[[RET]] : !torch.vtensor<[?,?],f32>
|
||||
// CHECK: }
|
||||
|
@ -524,7 +524,7 @@ func @torch.aten.hardsigmoid(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor
|
|||
// CHECK: %[[NONE:.*]] = torch.constant.none
|
||||
// CHECK: %[[MEM:.*]] = torch.aten.empty.memory_format %[[LIST]], %[[NONE]], %[[NONE]], %[[NONE]], %[[NONE]], %[[NONE]] :
|
||||
// CHECK-SAME: !torch.list<int>, !torch.none, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[],f32>
|
||||
// CHECK: %[[FILL:.*]] = torch.pseudo.aten.fill.Scalar %[[MEM]], %[[INT6_]] : !torch.vtensor<[],f32>, !torch.int -> !torch.vtensor<[],f32>
|
||||
// CHECK: %[[FILL:.*]] = torch.valsem.aten.fill.Scalar %[[MEM]], %[[INT6_]] : !torch.vtensor<[],f32>, !torch.int -> !torch.vtensor<[],f32>
|
||||
// CHECK: %[[MIN:.*]] = torch.aten.minimum %[[RELU]], %[[FILL]] : !torch.vtensor<[?,?],f32>, !torch.vtensor<[],f32> -> !torch.vtensor<[?,?],f32>
|
||||
// CHECK: %[[DIV:.*]] = torch.aten.div.Scalar %[[MIN]], %[[INT6]] : !torch.vtensor<[?,?],f32>, !torch.int -> !torch.vtensor<[?,?],f32>
|
||||
// CHECK: %[[MUL:.*]] = torch.aten.mul.Tensor %[[DIV]], %[[INP]] : !torch.vtensor<[?,?],f32>, !torch.vtensor<[?,?],f32> -> !torch.vtensor<[?,?],f32>
|
||||
|
@ -543,13 +543,13 @@ func @torch.aten.hardswish(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[
|
|||
// CHECK: %[[NONE:.*]] = torch.constant.none
|
||||
// CHECK: %[[EMPTY:.*]] = torch.aten.empty.memory_format %[[SIZES]], %[[NONE]], %[[NONE]], %[[NONE]], %[[NONE]], %[[NONE]] :
|
||||
// CHECK-SAME: !torch.list<int>, !torch.none, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[],f32>
|
||||
// CHECK: %[[MIN_TENSOR:.*]] = torch.pseudo.aten.fill.Scalar %[[EMPTY]], %[[MIN_VAL]] : !torch.vtensor<[],f32>, !torch.float -> !torch.vtensor<[],f32>
|
||||
// CHECK: %[[MIN_TENSOR:.*]] = torch.valsem.aten.fill.Scalar %[[EMPTY]], %[[MIN_VAL]] : !torch.vtensor<[],f32>, !torch.float -> !torch.vtensor<[],f32>
|
||||
// CHECK: %[[MIN:.*]] = torch.aten.maximum %[[INPUT]], %[[MIN_TENSOR]] : !torch.vtensor<[?],f32>, !torch.vtensor<[],f32> -> !torch.vtensor<[?],f32>
|
||||
// CHECK: %[[SIZES:.*]] = torch.prim.ListConstruct : () -> !torch.list<int>
|
||||
// CHECK: %[[NONE:.*]] = torch.constant.none
|
||||
// CHECK: %[[VAL_10:.*]] = torch.aten.empty.memory_format %[[SIZES]], %[[NONE]], %[[NONE]], %[[NONE]], %[[NONE]], %[[NONE]] :
|
||||
// CHECK-SAME: !torch.list<int>, !torch.none, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[],f32>
|
||||
// CHECK: %[[MAX_TENSOR:.*]] = torch.pseudo.aten.fill.Scalar %[[VAL_10]], %[[MAX_VAL]] : !torch.vtensor<[],f32>, !torch.float -> !torch.vtensor<[],f32>
|
||||
// CHECK: %[[MAX_TENSOR:.*]] = torch.valsem.aten.fill.Scalar %[[VAL_10]], %[[MAX_VAL]] : !torch.vtensor<[],f32>, !torch.float -> !torch.vtensor<[],f32>
|
||||
// CHECK: %[[RET:.*]] = torch.aten.minimum %[[MAX_TENSOR]], %[[MIN]] : !torch.vtensor<[],f32>, !torch.vtensor<[?],f32> -> !torch.vtensor<[?],f32>
|
||||
// CHECK: return %[[RET]] : !torch.vtensor<[?],f32>
|
||||
func @torch.aten.hardtanh(%arg0: !torch.vtensor<[?],f32>, %min: !torch.float, %max: !torch.float) -> !torch.vtensor<[?],f32> {
|
||||
|
@ -616,7 +616,7 @@ func @torch.aten.silu(%arg0: !torch.vtensor<[?,?],f32> loc(unknown)) -> !torch.v
|
|||
// CHECK: %[[SIZE:.*]] = torch.prim.ListConstruct %[[INT2]], %[[INT3]] : (!torch.int, !torch.int) -> !torch.list<int>
|
||||
// CHECK: %[[MEM_FORMAT:.*]] = torch.constant.none
|
||||
// CHECK: %[[EMPTY:.*]] = torch.aten.empty.memory_format %[[SIZE]], %[[NONE]], %[[NONE]], %[[NONE]], %[[NONE]], %[[MEM_FORMAT]] : !torch.list<int>, !torch.none, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[2,3],f32>
|
||||
// CHECK: %[[RES:.*]] = torch.pseudo.aten.fill.Scalar %[[EMPTY]], %[[FLOAT5]] : !torch.vtensor<[2,3],f32>, !torch.float -> !torch.vtensor<[2,3],f32>
|
||||
// CHECK: %[[RES:.*]] = torch.valsem.aten.fill.Scalar %[[EMPTY]], %[[FLOAT5]] : !torch.vtensor<[2,3],f32>, !torch.float -> !torch.vtensor<[2,3],f32>
|
||||
// CHECK: return %[[RES]] : !torch.vtensor<[2,3],f32>
|
||||
func @torch.aten.full() -> !torch.vtensor<[2,3],f32> {
|
||||
%float5.000000e00 = torch.constant.float 5.000000e+00
|
||||
|
@ -639,7 +639,7 @@ func @torch.aten.full() -> !torch.vtensor<[2,3],f32> {
|
|||
// CHECK: %[[DIM1:.*]] = torch.aten.size.int %[[INP]], %[[INT1]] : !torch.vtensor<[?,?],f32>, !torch.int -> !torch.int
|
||||
// CHECK: %[[SIZE:.*]] = torch.prim.ListConstruct %[[DIM0]], %[[DIM1]] : (!torch.int, !torch.int) -> !torch.list<int>
|
||||
// CHECK: %[[EMPTY:.*]] = torch.aten.empty.memory_format %[[SIZE]], %[[NONE]], %[[NONE]], %[[NONE]], %[[NONE]], %[[NONE]] : !torch.list<int>, !torch.none, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[?,?],f32>
|
||||
// CHECK: %[[RES:.*]] = torch.pseudo.aten.fill.Scalar %[[EMPTY]], %[[INT5]] : !torch.vtensor<[?,?],f32>, !torch.int -> !torch.vtensor<[?,?],f32>
|
||||
// CHECK: %[[RES:.*]] = torch.valsem.aten.fill.Scalar %[[EMPTY]], %[[INT5]] : !torch.vtensor<[?,?],f32>, !torch.int -> !torch.vtensor<[?,?],f32>
|
||||
// CHECK: return %[[RES]] : !torch.vtensor<[?,?],f32>
|
||||
func @torch.aten.full_like(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
|
||||
%int5 = torch.constant.int 5
|
||||
|
|
|
@ -134,7 +134,7 @@ func @convert_to_value_semantic_tensors_optional_list(%self: !torch.tensor<[5],f
|
|||
// CHECK-SAME: %[[T:.*]]: !torch.tensor, %[[MIN:.*]]: !torch.float, %[[MAX:.*]]: !torch.float,
|
||||
// CHECK-SAME: %[[GENERATOR:.*]]: !torch.none) -> !torch.tensor {
|
||||
// CHECK: %[[T_VTENSOR:.*]] = torch.copy.to_vtensor %[[T]] : !torch.vtensor
|
||||
// CHECK: %[[VRET:.*]] = torch.pseudo.aten.uniform %[[T_VTENSOR]], %[[MIN]], %[[MAX]], %[[GENERATOR]] :
|
||||
// CHECK: %[[VRET:.*]] = torch.valsem.aten.uniform %[[T_VTENSOR]], %[[MIN]], %[[MAX]], %[[GENERATOR]] :
|
||||
// CHECK-SAME: !torch.vtensor, !torch.float, !torch.float, !torch.none -> !torch.vtensor
|
||||
// CHECK: %[[RET:.*]] = torch.copy.to_tensor %[[VRET]] : !torch.tensor
|
||||
// CHECK: %[[COPY_VTENSOR:.*]] = torch.copy.to_vtensor %[[RET]] : !torch.vtensor
|
||||
|
@ -150,7 +150,7 @@ func @torch.aten.uniform_(%t: !torch.tensor, %min: !torch.float, %max: !torch.fl
|
|||
// CHECK: %[[GENERATOR:.*]] = torch.constant.none
|
||||
// CHECK: %[[P:.*]] = torch.constant.float 5.000000e-01
|
||||
// CHECK: %[[T_VTENSOR:.*]] = torch.copy.to_vtensor %[[T]] : !torch.vtensor
|
||||
// CHECK: %[[VRET:.*]] = torch.pseudo.aten.bernoulli.float %[[T_VTENSOR]], %[[P]], %[[GENERATOR]] : !torch.vtensor, !torch.float, !torch.none -> !torch.vtensor
|
||||
// CHECK: %[[VRET:.*]] = torch.valsem.aten.bernoulli.float %[[T_VTENSOR]], %[[P]], %[[GENERATOR]] : !torch.vtensor, !torch.float, !torch.none -> !torch.vtensor
|
||||
// CHECK: %[[RET:.*]] = torch.copy.to_tensor %[[VRET]] : !torch.tensor
|
||||
// CHECK: %[[COPY_VTENSOR:.*]] = torch.copy.to_vtensor %[[RET]] : !torch.vtensor
|
||||
// CHECK: torch.overwrite.tensor.contents %[[COPY_VTENSOR]] overwrites %[[T]] : !torch.vtensor, !torch.tensor
|
||||
|
@ -166,7 +166,7 @@ func @torch.aten.bernoulli_.float(%t: !torch.tensor) -> !torch.tensor {
|
|||
// CHECK-SAME: %[[T:.*]]: !torch.tensor) -> !torch.tensor {
|
||||
// CHECK: %[[VALUE:.*]] = torch.constant.int 1
|
||||
// CHECK: %[[T_VTENSOR:.*]] = torch.copy.to_vtensor %[[T]] : !torch.vtensor
|
||||
// CHECK: %[[VRET:.*]] = torch.pseudo.aten.fill.Scalar %[[T_VTENSOR]], %[[VALUE]] : !torch.vtensor, !torch.int -> !torch.vtensor
|
||||
// CHECK: %[[VRET:.*]] = torch.valsem.aten.fill.Scalar %[[T_VTENSOR]], %[[VALUE]] : !torch.vtensor, !torch.int -> !torch.vtensor
|
||||
// CHECK: %[[RET:.*]] = torch.copy.to_tensor %[[VRET]] : !torch.tensor
|
||||
// CHECK: %[[COPY_VTENSOR:.*]] = torch.copy.to_vtensor %[[RET]] : !torch.vtensor
|
||||
// CHECK: torch.overwrite.tensor.contents %[[COPY_VTENSOR]] overwrites %[[T]] : !torch.vtensor, !torch.tensor
|
||||
|
|
|
@ -24,11 +24,11 @@ func @basic(%arg0: !torch.vtensor) -> !torch.vtensor {
|
|||
// CHECK: module {
|
||||
// CHECK: func private @__torch_mlir_shape_fn.aten.fill.Scalar(
|
||||
|
||||
// CHECK-LABEL: func @pseudo_ops(
|
||||
// CHECK-LABEL: func @valsem_ops(
|
||||
// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor,
|
||||
// CHECK-SAME: %[[ARG1:.*]]: !torch.int) -> !torch.vtensor {
|
||||
// CHECK: %[[RESULT:.*]] = torch.shape.calculate {
|
||||
// CHECK: %[[VALUE:.*]] = torch.pseudo.aten.fill.Scalar %[[ARG0]], %[[ARG1]] : !torch.vtensor, !torch.int -> !torch.vtensor
|
||||
// CHECK: %[[VALUE:.*]] = torch.valsem.aten.fill.Scalar %[[ARG0]], %[[ARG1]] : !torch.vtensor, !torch.int -> !torch.vtensor
|
||||
// CHECK: torch.shape.calculate.yield %[[VALUE]] : !torch.vtensor
|
||||
// CHECK: } shapes {
|
||||
// CHECK: %[[SHAPE:.*]] = torch.aten.size %[[ARG0]] : !torch.vtensor -> !torch.list<int>
|
||||
|
@ -36,8 +36,8 @@ func @basic(%arg0: !torch.vtensor) -> !torch.vtensor {
|
|||
// CHECK: torch.shape.calculate.yield.shapes %[[RESULT_SHAPE]] : !torch.list<int>
|
||||
// CHECK: } : !torch.vtensor
|
||||
// CHECK: return %[[RESULT:.*]] : !torch.vtensor
|
||||
func @pseudo_ops(%arg0: !torch.vtensor, %arg1: !torch.int) -> !torch.vtensor {
|
||||
%0 = torch.pseudo.aten.fill.Scalar %arg0, %arg1 : !torch.vtensor, !torch.int -> !torch.vtensor
|
||||
func @valsem_ops(%arg0: !torch.vtensor, %arg1: !torch.int) -> !torch.vtensor {
|
||||
%0 = torch.valsem.aten.fill.Scalar %arg0, %arg1 : !torch.vtensor, !torch.int -> !torch.vtensor
|
||||
return %0 : !torch.vtensor
|
||||
}
|
||||
|
||||
|
@ -52,7 +52,7 @@ func @pseudo_ops(%arg0: !torch.vtensor, %arg1: !torch.int) -> !torch.vtensor {
|
|||
// CHECK-SAME: %[[ARG1:.*]]: !torch.float) -> !torch.vtensor {
|
||||
// CHECK: %[[NONE:.*]] = torch.constant.none
|
||||
// CHECK: %[[RESULT:.*]] = torch.shape.calculate {
|
||||
// CHECK: %[[UNIFORM:.*]] = torch.pseudo.aten.uniform %[[ARG0]], %[[ARG1]], %[[ARG1]], %[[NONE]] : !torch.vtensor, !torch.float, !torch.float, !torch.none -> !torch.vtensor
|
||||
// CHECK: %[[UNIFORM:.*]] = torch.valsem.aten.uniform %[[ARG0]], %[[ARG1]], %[[ARG1]], %[[NONE]] : !torch.vtensor, !torch.float, !torch.float, !torch.none -> !torch.vtensor
|
||||
// CHECK: torch.shape.calculate.yield %[[UNIFORM]] : !torch.vtensor
|
||||
// CHECK: } shapes {
|
||||
// CHECK: %[[ARG0_SHAPE:.*]] = torch.aten.size %[[ARG0]] : !torch.vtensor -> !torch.list<int>
|
||||
|
@ -63,7 +63,7 @@ func @pseudo_ops(%arg0: !torch.vtensor, %arg1: !torch.int) -> !torch.vtensor {
|
|||
// CHECK: return %[[RESULT:.*]] : !torch.vtensor
|
||||
func @adjust_shape_function_arg$torch.any(%arg0: !torch.vtensor, %arg1: !torch.float) -> !torch.vtensor {
|
||||
%none = torch.constant.none
|
||||
%0 = torch.pseudo.aten.uniform %arg0, %arg1, %arg1, %none : !torch.vtensor, !torch.float, !torch.float, !torch.none -> !torch.vtensor
|
||||
%0 = torch.valsem.aten.uniform %arg0, %arg1, %arg1, %none : !torch.vtensor, !torch.float, !torch.float, !torch.none -> !torch.vtensor
|
||||
return %0 : !torch.vtensor
|
||||
}
|
||||
|
||||
|
|
Loading…
Reference in New Issue