mirror of https://github.com/llvm/torch-mlir
Update external/llvm-project
- Add `qualified` to ods because of https://reviews.llvm.org/D113873 and https://reviews.llvm.org/D116905 - Needed to revert https://github.com/llvm/torch-mlir/pull/520 as it was based on an old torch version. https://github.com/llvm/torch-mlir/pull/527 will bring this back with a better design. - Change ConvertAtenCatOp to use more accurate tensor shape info and as much static info as possible to pass `tensor.insert_slice` verification code added by https://reviews.llvm.org/D114715 - Other minor fixespull/526/head
parent
40efd2cb8e
commit
3745f54489
|
@ -1 +1 @@
|
|||
Subproject commit 966b72098363d44adf2882b9c34fcdbe344ff913
|
||||
Subproject commit 63f0c00d38ee7879239975a6743d4e6c7847b725
|
File diff suppressed because it is too large
Load Diff
|
@ -27,7 +27,7 @@ def Torch_PrimLayoutOp : Torch_Op<"prim.layout", [
|
|||
let results = (outs
|
||||
Torch_IntType:$result
|
||||
);
|
||||
let assemblyFormat = "$a attr-dict `:` type($a) `->` type($result)";
|
||||
let assemblyFormat = "$a attr-dict `:` qualified(type($a)) `->` qualified(type($result))";
|
||||
}
|
||||
|
||||
def Torch_PrimTupleIndexOp : Torch_Op<"prim.TupleIndex", [
|
||||
|
@ -42,7 +42,7 @@ def Torch_PrimTupleIndexOp : Torch_Op<"prim.TupleIndex", [
|
|||
let results = (outs
|
||||
AnyTorchType:$result
|
||||
);
|
||||
let assemblyFormat = "$tup `,` $i attr-dict `:` type($tup) `,` type($i) `->` type($result)";
|
||||
let assemblyFormat = "$tup `,` $i attr-dict `:` qualified(type($tup)) `,` qualified(type($i)) `->` qualified(type($result))";
|
||||
let hasCanonicalizer = 1;
|
||||
}
|
||||
|
||||
|
@ -57,7 +57,7 @@ def Torch_PrimDeviceOp : Torch_Op<"prim.device", [
|
|||
let results = (outs
|
||||
Torch_DeviceType:$result
|
||||
);
|
||||
let assemblyFormat = "$a attr-dict `:` type($a) `->` type($result)";
|
||||
let assemblyFormat = "$a attr-dict `:` qualified(type($a)) `->` qualified(type($result))";
|
||||
}
|
||||
|
||||
def Torch_PrimDtypeOp : Torch_Op<"prim.dtype", [
|
||||
|
@ -71,7 +71,7 @@ def Torch_PrimDtypeOp : Torch_Op<"prim.dtype", [
|
|||
let results = (outs
|
||||
Torch_IntType:$result
|
||||
);
|
||||
let assemblyFormat = "$a attr-dict `:` type($a) `->` type($result)";
|
||||
let assemblyFormat = "$a attr-dict `:` qualified(type($a)) `->` qualified(type($result))";
|
||||
let hasFolder = 1;
|
||||
}
|
||||
|
||||
|
@ -85,7 +85,7 @@ def Torch_PrimTupleUnpackOp : Torch_Op<"prim.TupleUnpack", [
|
|||
let results = (outs
|
||||
Variadic<AnyTorchType>:$results
|
||||
);
|
||||
let assemblyFormat = "$tup attr-dict `:` type($tup) `->` type($results)";
|
||||
let assemblyFormat = "$tup attr-dict `:` qualified(type($tup)) `->` qualified(type($results))";
|
||||
let hasCanonicalizer = 1;
|
||||
}
|
||||
|
||||
|
@ -100,7 +100,7 @@ def Torch_PrimNumToTensorScalarOp : Torch_Op<"prim.NumToTensor.Scalar", [
|
|||
let results = (outs
|
||||
AnyTorchTensorType:$result
|
||||
);
|
||||
let assemblyFormat = "$a attr-dict `:` type($a) `->` type($result)";
|
||||
let assemblyFormat = "$a attr-dict `:` qualified(type($a)) `->` qualified(type($result))";
|
||||
}
|
||||
|
||||
def Torch_PrimMinSelfIntOp : Torch_Op<"prim.min.self_int", [
|
||||
|
@ -114,7 +114,7 @@ def Torch_PrimMinSelfIntOp : Torch_Op<"prim.min.self_int", [
|
|||
let results = (outs
|
||||
Torch_IntType:$result
|
||||
);
|
||||
let assemblyFormat = "$self attr-dict `:` type($self) `->` type($result)";
|
||||
let assemblyFormat = "$self attr-dict `:` qualified(type($self)) `->` qualified(type($result))";
|
||||
}
|
||||
|
||||
def Torch_PrimMinIntOp : Torch_Op<"prim.min.int", [
|
||||
|
@ -129,7 +129,7 @@ def Torch_PrimMinIntOp : Torch_Op<"prim.min.int", [
|
|||
let results = (outs
|
||||
Torch_IntType:$result
|
||||
);
|
||||
let assemblyFormat = "$a `,` $b attr-dict `:` type($a) `,` type($b) `->` type($result)";
|
||||
let assemblyFormat = "$a `,` $b attr-dict `:` qualified(type($a)) `,` qualified(type($b)) `->` qualified(type($result))";
|
||||
}
|
||||
|
||||
def Torch_PrimMaxSelfIntOp : Torch_Op<"prim.max.self_int", [
|
||||
|
@ -143,7 +143,7 @@ def Torch_PrimMaxSelfIntOp : Torch_Op<"prim.max.self_int", [
|
|||
let results = (outs
|
||||
Torch_IntType:$result
|
||||
);
|
||||
let assemblyFormat = "$self attr-dict `:` type($self) `->` type($result)";
|
||||
let assemblyFormat = "$self attr-dict `:` qualified(type($self)) `->` qualified(type($result))";
|
||||
}
|
||||
|
||||
def Torch_PrimMaxIntOp : Torch_Op<"prim.max.int", [
|
||||
|
@ -158,7 +158,7 @@ def Torch_PrimMaxIntOp : Torch_Op<"prim.max.int", [
|
|||
let results = (outs
|
||||
Torch_IntType:$result
|
||||
);
|
||||
let assemblyFormat = "$a `,` $b attr-dict `:` type($a) `,` type($b) `->` type($result)";
|
||||
let assemblyFormat = "$a `,` $b attr-dict `:` qualified(type($a)) `,` qualified(type($b)) `->` qualified(type($result))";
|
||||
}
|
||||
|
||||
def Torch_PrimRaiseExceptionOp : Torch_Op<"prim.RaiseException", [
|
||||
|
@ -171,7 +171,7 @@ def Torch_PrimRaiseExceptionOp : Torch_Op<"prim.RaiseException", [
|
|||
);
|
||||
let results = (outs
|
||||
);
|
||||
let assemblyFormat = "$msg attr-dict `:` type($msg)";
|
||||
let assemblyFormat = "$msg attr-dict `:` qualified(type($msg))";
|
||||
}
|
||||
|
||||
def Torch_PrimUninitializedOp : Torch_Op<"prim.Uninitialized", [
|
||||
|
@ -184,7 +184,7 @@ def Torch_PrimUninitializedOp : Torch_Op<"prim.Uninitialized", [
|
|||
let results = (outs
|
||||
AnyTorchType:$result
|
||||
);
|
||||
let assemblyFormat = " attr-dict `:` type($result)";
|
||||
let assemblyFormat = " attr-dict `:` qualified(type($result))";
|
||||
let hasCanonicalizer = 1;
|
||||
}
|
||||
|
||||
|
@ -199,7 +199,7 @@ def Torch_PrimUncheckedCastOp : Torch_Op<"prim.unchecked_cast", [
|
|||
let results = (outs
|
||||
AnyTorchType:$result
|
||||
);
|
||||
let assemblyFormat = "$x attr-dict `:` type($x) `->` type($result)";
|
||||
let assemblyFormat = "$x attr-dict `:` qualified(type($x)) `->` qualified(type($result))";
|
||||
}
|
||||
|
||||
def Torch_PrimPrintOp : Torch_Op<"prim.Print", [
|
||||
|
@ -211,7 +211,7 @@ def Torch_PrimPrintOp : Torch_Op<"prim.Print", [
|
|||
);
|
||||
let results = (outs
|
||||
);
|
||||
let assemblyFormat = "`(` $operands `)` attr-dict `:` type($operands)";
|
||||
let assemblyFormat = "`(` $operands `)` attr-dict `:` qualified(type($operands))";
|
||||
}
|
||||
|
||||
def Torch_PrimTolistOp : Torch_Op<"prim.tolist", [
|
||||
|
@ -224,6 +224,6 @@ def Torch_PrimTolistOp : Torch_Op<"prim.tolist", [
|
|||
let results = (outs
|
||||
Variadic<AnyTorchType>:$results
|
||||
);
|
||||
let assemblyFormat = "`(` $operands `)` attr-dict `:` type($operands) `->` type($results)";
|
||||
let assemblyFormat = "`(` $operands `)` attr-dict `:` qualified(type($operands)) `->` qualified(type($results))";
|
||||
}
|
||||
|
||||
|
|
|
@ -31,6 +31,6 @@ def Torch_QuantizedLinearOp : Torch_Op<"quantized.linear", [
|
|||
let results = (outs
|
||||
AnyTorchTensorType:$Y
|
||||
);
|
||||
let assemblyFormat = "$X `,` $W_prepack `,` $Y_scale_i `,` $Y_zero_point_i attr-dict `:` type($X) `,` type($W_prepack) `,` type($Y_scale_i) `,` type($Y_zero_point_i) `->` type($Y)";
|
||||
let assemblyFormat = "$X `,` $W_prepack `,` $Y_scale_i `,` $Y_zero_point_i attr-dict `:` qualified(type($X)) `,` qualified(type($W_prepack)) `,` qualified(type($Y_scale_i)) `,` qualified(type($Y_zero_point_i)) `->` qualified(type($Y))";
|
||||
}
|
||||
|
||||
|
|
|
@ -64,7 +64,7 @@ def Torch_NnModuleOp : Torch_Op<"nn_module", [
|
|||
let regions = (region SizedRegion<1>:$region);
|
||||
let verifier = "return ::verify(*this);";
|
||||
|
||||
let assemblyFormat = "$region attr-dict `:` type($result)";
|
||||
let assemblyFormat = "$region attr-dict `:` qualified(type($result))";
|
||||
|
||||
let extraClassDeclaration = [{
|
||||
StringRef getClassName() { return getType().getClassName(); }
|
||||
|
@ -97,7 +97,7 @@ def Torch_SlotOp : Torch_Op<"slot", [
|
|||
let results = (outs);
|
||||
|
||||
let assemblyFormat = [{
|
||||
$name `,` $value attr-dict `:` type($value)
|
||||
$name `,` $value attr-dict `:` qualified(type($value))
|
||||
}];
|
||||
}
|
||||
|
||||
|
@ -272,7 +272,7 @@ def Torch_GlobalSlotInitOp : Torch_Op<"global_slot.init", [
|
|||
// TODO: Have a SingleBlockExplicitTerminator trait.
|
||||
let builders = [OpBuilder<(ins), [{ /*nothing to do */ }]>];
|
||||
|
||||
let assemblyFormat = "$initialValue attr-dict `:` type($initialValue)";
|
||||
let assemblyFormat = "$initialValue attr-dict `:` qualified(type($initialValue))";
|
||||
}
|
||||
|
||||
def Torch_GlobalSlotGetOp : Torch_Op<"global_slot.get", []> {
|
||||
|
@ -284,7 +284,7 @@ def Torch_GlobalSlotGetOp : Torch_Op<"global_slot.get", []> {
|
|||
let results = (outs AnyTorchType:$result);
|
||||
|
||||
let assemblyFormat = [{
|
||||
$slot attr-dict `:` type($result)
|
||||
$slot attr-dict `:` qualified(type($result))
|
||||
}];
|
||||
}
|
||||
|
||||
|
@ -298,7 +298,7 @@ def Torch_GlobalSlotSetOp : Torch_Op<"global_slot.set", []> {
|
|||
let results = (outs);
|
||||
|
||||
let assemblyFormat = [{
|
||||
$slot `=` $value attr-dict `:` type($value)
|
||||
$slot `=` $value attr-dict `:` qualified(type($value))
|
||||
}];
|
||||
}
|
||||
|
||||
|
@ -319,7 +319,7 @@ def Torch_PrimListUnpackOp: Torch_Op<"prim.ListUnpack",
|
|||
let results = (outs Variadic<AnyTorchType>:$results);
|
||||
|
||||
let assemblyFormat = [{
|
||||
$operand attr-dict `:` type($operand) `->` type($results)
|
||||
$operand attr-dict `:` qualified(type($operand)) `->` qualified(type($results))
|
||||
}];
|
||||
}
|
||||
|
||||
|
@ -343,7 +343,7 @@ def Torch_PrimTupleConstructOp: Torch_Op<"prim.TupleConstruct", [
|
|||
);
|
||||
|
||||
let assemblyFormat = [{
|
||||
$elements attr-dict `:` type($elements) `->` type($result)
|
||||
$elements attr-dict `:` qualified(type($elements)) `->` qualified(type($result))
|
||||
}];
|
||||
}
|
||||
|
||||
|
@ -385,7 +385,7 @@ def Torch_PrimDictConstructOp: Torch_Op<"prim.DictConstruct", [
|
|||
let verifier = "return ::verify(*this);";
|
||||
|
||||
let assemblyFormat = [{
|
||||
`keys` `(` ($keys^ `:` type($keys))? `)` `values` `(` ($values^ `:` type($values))? `)` attr-dict `->` type($result)
|
||||
`keys` `(` ($keys^ `:` qualified(type($keys)))? `)` `values` `(` ($values^ `:` qualified(type($values)))? `)` attr-dict `->` qualified(type($result))
|
||||
}];
|
||||
|
||||
let extraClassDeclaration = [{
|
||||
|
@ -401,7 +401,7 @@ def Torch_PrimGetAttrOp : Torch_Op<"prim.GetAttr", []> {
|
|||
let results = (outs AnyTorchType:$result);
|
||||
|
||||
let assemblyFormat = [{
|
||||
$receiver `[` $name `]` attr-dict `:` type($receiver) `->` type($result)
|
||||
$receiver `[` $name `]` attr-dict `:` qualified(type($receiver)) `->` qualified(type($result))
|
||||
}];
|
||||
}
|
||||
|
||||
|
@ -416,7 +416,7 @@ def Torch_PrimSetAttrOp : Torch_Op<"prim.SetAttr", []> {
|
|||
let results = (outs);
|
||||
|
||||
let assemblyFormat = [{
|
||||
$receiver `[` $name `]` `=` $value attr-dict `:` type($receiver) `,` type($value)
|
||||
$receiver `[` $name `]` `=` $value attr-dict `:` qualified(type($receiver)) `,` qualified(type($value))
|
||||
}];
|
||||
}
|
||||
|
||||
|
@ -431,7 +431,7 @@ def Torch_PrimCallMethodOp : Torch_Op<"prim.CallMethod", []> {
|
|||
let results = (outs AnyTorchType:$result);
|
||||
|
||||
let assemblyFormat = [{
|
||||
$receiver `[` $name `]` `(` $operands `)` attr-dict `:` type($receiver) `,` functional-type($operands, $result)
|
||||
$receiver `[` $name `]` `(` $operands `)` attr-dict `:` qualified(type($receiver)) `,` functional-type($operands, $result)
|
||||
}];
|
||||
}
|
||||
|
||||
|
@ -478,7 +478,7 @@ def Torch_PrimLoopConditionOp : Torch_Op<"prim.Loop.condition", [
|
|||
|
||||
let assemblyFormat = [{
|
||||
$shouldContinue `,`
|
||||
`iter` `(` ($iterArgs^ `:` type($iterArgs))? `)` attr-dict
|
||||
`iter` `(` ($iterArgs^ `:` qualified(type($iterArgs)))? `)` attr-dict
|
||||
}];
|
||||
}
|
||||
|
||||
|
@ -525,7 +525,7 @@ def Torch_PrimIfYieldOp : Torch_Op<"prim.If.yield", [
|
|||
let results = (outs);
|
||||
|
||||
let assemblyFormat = [{
|
||||
attr-dict ($results^ `:` type($results))?
|
||||
attr-dict ($results^ `:` qualified(type($results)))?
|
||||
}];
|
||||
}
|
||||
|
||||
|
@ -662,7 +662,7 @@ def Torch_DerefineOp : Torch_Op<"derefine", [
|
|||
let results = (outs AnyTorchType:$result);
|
||||
|
||||
let assemblyFormat = [{
|
||||
$operand attr-dict `:` type($operand) `to` type($result)
|
||||
$operand attr-dict `:` qualified(type($operand)) `to` qualified(type($result))
|
||||
}];
|
||||
}
|
||||
|
||||
|
@ -698,7 +698,7 @@ def Torch_LinearParamsCreateOp : Torch_Op<"linear_params.create", [
|
|||
let results = (outs Torch_LinearParamsType:$result);
|
||||
|
||||
let assemblyFormat = [{
|
||||
$weight (`,` $bias^)? attr-dict `:` type($weight) (`,` type($bias)^)?
|
||||
$weight (`,` $bias^)? attr-dict `:` qualified(type($weight)) (`,` qualified(type($bias))^)?
|
||||
}];
|
||||
}
|
||||
|
||||
|
@ -727,7 +727,7 @@ def Torch_PerTensorAffineCreateOp : Torch_Op<"per_tensor_affine.create", [
|
|||
|
||||
let assemblyFormat = [{
|
||||
$int_repr `,` $scale `,` $offset attr-dict
|
||||
`:` type($int_repr) `,` type($scale) `,` type($offset) `->` type($result)
|
||||
`:` qualified(type($int_repr)) `,` qualified(type($scale)) `,` qualified(type($offset)) `->` qualified(type($result))
|
||||
}];
|
||||
}
|
||||
|
||||
|
@ -754,7 +754,7 @@ def Torch_NonValueTensorLiteralOp : Torch_Op<"tensor.literal", [
|
|||
let results = (outs Torch_NonValueTensorType:$result);
|
||||
|
||||
let assemblyFormat = [{
|
||||
`(` $value `)` attr-dict `:` type($result)
|
||||
`(` $value `)` attr-dict `:` qualified(type($result))
|
||||
}];
|
||||
|
||||
let extraClassDeclaration = [{
|
||||
|
@ -786,7 +786,7 @@ def Torch_ValueTensorLiteralOp : Torch_Op<"vtensor.literal", [
|
|||
let results = (outs Torch_ValueTensorType:$result);
|
||||
|
||||
let assemblyFormat = [{
|
||||
`(` $value `)` attr-dict `:` type($result)
|
||||
`(` $value `)` attr-dict `:` qualified(type($result))
|
||||
}];
|
||||
|
||||
let hasFolder = 1;
|
||||
|
@ -817,7 +817,7 @@ def Torch_TensorStaticInfoCastOp : Torch_Op<"tensor_static_info_cast", [
|
|||
AnyTorchTensorType:$result
|
||||
);
|
||||
let assemblyFormat = [{
|
||||
$operand attr-dict `:` type($operand) `to` type($result)
|
||||
$operand attr-dict `:` qualified(type($operand)) `to` qualified(type($result))
|
||||
}];
|
||||
let hasCanonicalizer = 1;
|
||||
}
|
||||
|
@ -849,7 +849,7 @@ def Torch_CopyToNonValueTensorOp : Torch_Op<"copy.to_tensor", [
|
|||
Torch_NonValueTensorType:$result
|
||||
);
|
||||
let assemblyFormat = [{
|
||||
$operand attr-dict `:` type($result)
|
||||
$operand attr-dict `:` qualified(type($result))
|
||||
}];
|
||||
let verifier = "return ::verify(*this);";
|
||||
}
|
||||
|
@ -879,7 +879,7 @@ def Torch_CopyToValueTensorOp : Torch_Op<"copy.to_vtensor", [
|
|||
Torch_ValueTensorType:$result
|
||||
);
|
||||
let assemblyFormat = [{
|
||||
$operand attr-dict `:` type($result)
|
||||
$operand attr-dict `:` qualified(type($result))
|
||||
}];
|
||||
let verifier = "return ::verify(*this);";
|
||||
}
|
||||
|
@ -908,7 +908,7 @@ def Torch_OverwriteTensorOp : Torch_Op<"overwrite.tensor", [
|
|||
);
|
||||
let assemblyFormat = [{
|
||||
$value `overwrites` $overwritten attr-dict
|
||||
`:` type($value) `,` type($overwritten)
|
||||
`:` qualified(type($value)) `,` qualified(type($overwritten))
|
||||
}];
|
||||
}
|
||||
|
||||
|
|
|
@ -40,7 +40,7 @@ def TorchConversion_ToBuiltinTensorOp : TorchConversion_Op<"to_builtin_tensor",
|
|||
AnyTensor:$result
|
||||
);
|
||||
let assemblyFormat = [{
|
||||
$operand attr-dict `:` type($operand) `->` type($result)
|
||||
$operand attr-dict `:` qualified(type($operand)) `->` qualified(type($result))
|
||||
}];
|
||||
}
|
||||
|
||||
|
@ -58,7 +58,7 @@ def TorchConversion_FromBuiltinTensorOp : TorchConversion_Op<"from_builtin_tenso
|
|||
Torch_ValueTensorType:$result
|
||||
);
|
||||
let assemblyFormat = [{
|
||||
$operand attr-dict `:` type($operand) `->` type($result)
|
||||
$operand attr-dict `:` qualified(type($operand)) `->` qualified(type($result))
|
||||
}];
|
||||
}
|
||||
|
||||
|
|
|
@ -11,7 +11,7 @@
|
|||
|
||||
#include "../PassDetail.h"
|
||||
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
|
||||
#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
|
||||
#include "mlir/Dialect/Linalg/IR/Linalg.h"
|
||||
#include "mlir/Dialect/Math/IR/Math.h"
|
||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||
#include "mlir/Dialect/Traits.h"
|
||||
|
@ -134,12 +134,7 @@ static Value castIndexToInt(OpBuilder &b, Location loc, Value idx) {
|
|||
}
|
||||
|
||||
static Value getDimOp(OpBuilder &b, Location loc, Value v, int dim) {
|
||||
if (auto tensorType = v.getType().cast<RankedTensorType>()) {
|
||||
if (!tensorType.isDynamicDim(dim))
|
||||
return b.create<arith::ConstantOp>(
|
||||
loc, b.getIndexAttr(tensorType.getShape()[dim]));
|
||||
}
|
||||
return b.create<tensor::DimOp>(loc, v, dim);
|
||||
return b.createOrFold<tensor::DimOp>(loc, v, dim);
|
||||
}
|
||||
|
||||
static void checkDimEqualHelper(OpBuilder &b, Location loc, Value lhsDim,
|
||||
|
@ -2202,8 +2197,8 @@ public:
|
|||
SmallVector<ReassociationIndices> reassociation(1);
|
||||
for (auto i : llvm::seq<int64_t>(0, inputType.getRank()))
|
||||
reassociation[0].push_back(i);
|
||||
input = rewriter.create<linalg::TensorCollapseShapeOp>(
|
||||
argmaxOp->getLoc(), input, reassociation);
|
||||
input = rewriter.create<tensor::CollapseShapeOp>(argmaxOp->getLoc(),
|
||||
input, reassociation);
|
||||
// Becomes 0 for flattened tensor.
|
||||
dim = 0;
|
||||
// Recast to fix shape.
|
||||
|
@ -2780,7 +2775,7 @@ public:
|
|||
if (!(startDim >= -1 && startDim <= 0 && endDim >= -1 && endDim <= 0))
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "start_dim and end_dim must be in [-1, 0] when inputRank is 0");
|
||||
rewriter.replaceOpWithNewOp<linalg::TensorExpandShapeOp>(
|
||||
rewriter.replaceOpWithNewOp<tensor::ExpandShapeOp>(
|
||||
op, resultType, adaptor.self(), reassociation);
|
||||
return success();
|
||||
}
|
||||
|
@ -2797,7 +2792,7 @@ public:
|
|||
if (i < startDim || i >= endDim)
|
||||
j++;
|
||||
}
|
||||
Value collapsedTensor = rewriter.create<linalg::TensorCollapseShapeOp>(
|
||||
Value collapsedTensor = rewriter.create<tensor::CollapseShapeOp>(
|
||||
op->getLoc(), adaptor.self(), reassociation);
|
||||
rewriter.replaceOpWithNewOp<tensor::CastOp>(op, resultType,
|
||||
collapsedTensor);
|
||||
|
@ -3022,12 +3017,12 @@ public:
|
|||
Value result =
|
||||
isCollapse
|
||||
? rewriter
|
||||
.create<linalg::TensorCollapseShapeOp>(
|
||||
loc, adjustedResultType, castedInput, reassociation)
|
||||
.create<tensor::CollapseShapeOp>(loc, adjustedResultType,
|
||||
castedInput, reassociation)
|
||||
.result()
|
||||
: rewriter
|
||||
.create<linalg::TensorExpandShapeOp>(
|
||||
loc, adjustedResultType, castedInput, reassociation)
|
||||
.create<tensor::ExpandShapeOp>(loc, adjustedResultType,
|
||||
castedInput, reassociation)
|
||||
.result();
|
||||
rewriter.replaceOpWithNewOp<tensor::CastOp>(op, resultType, result);
|
||||
return success();
|
||||
|
@ -3062,7 +3057,7 @@ public:
|
|||
// being unit extent, it will be collapsed to a 0-D tensor.
|
||||
if (resultRank == 0) {
|
||||
SmallVector<ReassociationIndices> reassociation;
|
||||
rewriter.replaceOpWithNewOp<linalg::TensorCollapseShapeOp>(
|
||||
rewriter.replaceOpWithNewOp<tensor::CollapseShapeOp>(
|
||||
op, resultType, input, reassociation);
|
||||
return success();
|
||||
}
|
||||
|
@ -3113,7 +3108,7 @@ public:
|
|||
op, "expected output size mismatches with the result type rank");
|
||||
|
||||
if (isSqueezed) {
|
||||
rewriter.replaceOpWithNewOp<linalg::TensorCollapseShapeOp>(
|
||||
rewriter.replaceOpWithNewOp<tensor::CollapseShapeOp>(
|
||||
op, resultType, input, reassociation);
|
||||
|
||||
} else {
|
||||
|
@ -3189,8 +3184,8 @@ public:
|
|||
// Note: In case the operand tensor type is of unit rank and is statically
|
||||
// shaped with unit dimension, the `reassociationMap` will be empty and the
|
||||
// input will be collapsed to a 0-D tensor.
|
||||
rewriter.replaceOpWithNewOp<linalg::TensorCollapseShapeOp>(
|
||||
op, resultType, input, reassociationMap);
|
||||
rewriter.replaceOpWithNewOp<tensor::CollapseShapeOp>(op, resultType, input,
|
||||
reassociationMap);
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
@ -3238,7 +3233,7 @@ public:
|
|||
auto resultType = getTypeConverter()
|
||||
->convertType(op->getResult(0).getType())
|
||||
.cast<RankedTensorType>();
|
||||
rewriter.replaceOpWithNewOp<linalg::TensorExpandShapeOp>(
|
||||
rewriter.replaceOpWithNewOp<tensor::ExpandShapeOp>(
|
||||
op, resultType, adaptor.self(), reassociationMap);
|
||||
return success();
|
||||
}
|
||||
|
@ -3483,8 +3478,8 @@ public:
|
|||
if (i != dim)
|
||||
resultIdx++;
|
||||
}
|
||||
result = rewriter.create<linalg::TensorCollapseShapeOp>(loc, result,
|
||||
reassociation);
|
||||
result =
|
||||
rewriter.create<tensor::CollapseShapeOp>(loc, result, reassociation);
|
||||
}
|
||||
rewriter.replaceOpWithNewOp<tensor::CastOp>(op, resultType, result);
|
||||
|
||||
|
@ -3528,27 +3523,39 @@ public:
|
|||
offsets.resize(rank, rewriter.create<arith::ConstantIndexOp>(loc, 0));
|
||||
|
||||
for (int i = 0; i < rank; ++i)
|
||||
sizes.push_back(rewriter.create<tensor::DimOp>(loc, tensors[0], i));
|
||||
sizes.push_back(rewriter.createOrFold<tensor::DimOp>(loc, tensors[0], i));
|
||||
|
||||
// Calculate the size of the `dim` result dimension by adding the dim size
|
||||
// of each tensor together.
|
||||
Value resultDimSize = sizes[dim];
|
||||
Value dimIndex = rewriter.create<arith::IndexCastOp>(
|
||||
loc, rewriter.getIndexType(), adaptor.dim());
|
||||
|
||||
Value dimIndex = rewriter.createOrFold<arith::ConstantOp>(
|
||||
loc, rewriter.getIndexAttr(dim));
|
||||
for (auto tensor : makeArrayRef(tensors).drop_front()) {
|
||||
auto size = rewriter.create<tensor::DimOp>(loc, tensor, dimIndex);
|
||||
resultDimSize = rewriter.create<arith::AddIOp>(loc, resultDimSize, size);
|
||||
auto size = rewriter.createOrFold<tensor::DimOp>(loc, tensor, dimIndex);
|
||||
resultDimSize =
|
||||
rewriter.createOrFold<arith::AddIOp>(loc, resultDimSize, size);
|
||||
}
|
||||
sizes[dim] = resultDimSize;
|
||||
|
||||
auto toOpFoldResult = [](Value v) -> OpFoldResult {
|
||||
auto op = v.getDefiningOp<arith::ConstantIndexOp>();
|
||||
if (!op)
|
||||
return v;
|
||||
return op.getValue();
|
||||
};
|
||||
|
||||
Value result = rewriter.create<linalg::InitTensorOp>(
|
||||
loc, sizes, newResultType.getElementType());
|
||||
for (auto tensor : tensors) {
|
||||
sizes[dim] = rewriter.create<tensor::DimOp>(loc, tensor, dimIndex);
|
||||
result = rewriter.create<tensor::InsertSliceOp>(loc, tensor, result,
|
||||
offsets, sizes, strides);
|
||||
SmallVector<Value> sizes = getTensorSizes(rewriter, loc, tensor);
|
||||
result = rewriter.createOrFold<tensor::InsertSliceOp>(
|
||||
loc, tensor, result,
|
||||
llvm::to_vector(llvm::map_range(offsets, toOpFoldResult)),
|
||||
llvm::to_vector(llvm::map_range(sizes, toOpFoldResult)),
|
||||
llvm::to_vector(llvm::map_range(strides, toOpFoldResult)));
|
||||
offsets[dim] =
|
||||
rewriter.create<arith::AddIOp>(loc, offsets[dim], sizes[dim]);
|
||||
rewriter.createOrFold<arith::AddIOp>(loc, offsets[dim], sizes[dim]);
|
||||
}
|
||||
|
||||
rewriter.replaceOpWithNewOp<tensor::CastOp>(op, newResultType, result);
|
||||
|
|
|
@ -53,8 +53,8 @@ public:
|
|||
rewriter.inlineRegionBefore(srcRegion, dstRegion, dstRegion.begin());
|
||||
rewriter.eraseBlock(&dstRegion.back());
|
||||
};
|
||||
inlineIfCase(op.thenRegion(), scfIf.thenRegion());
|
||||
inlineIfCase(op.elseRegion(), scfIf.elseRegion());
|
||||
inlineIfCase(op.thenRegion(), scfIf.getThenRegion());
|
||||
inlineIfCase(op.elseRegion(), scfIf.getElseRegion());
|
||||
rewriter.replaceOp(op, scfIf.getResults());
|
||||
return success();
|
||||
}
|
||||
|
|
|
@ -11,6 +11,7 @@
|
|||
|
||||
#include "../PassDetail.h"
|
||||
#include "mlir/Dialect/StandardOps/IR/Ops.h"
|
||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||
#include "mlir/Dialect/Traits.h"
|
||||
#include "mlir/Transforms/DialectConversion.h"
|
||||
#include "torch-mlir/Dialect/Torch/IR/TorchDialect.h"
|
||||
|
@ -36,7 +37,7 @@ public:
|
|||
LogicalResult
|
||||
matchAndRewrite(AtenDimOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
auto rank = rewriter.create<RankOp>(op->getLoc(), adaptor.self());
|
||||
auto rank = rewriter.create<tensor::RankOp>(op->getLoc(), adaptor.self());
|
||||
rewriter.replaceOpWithNewOp<arith::IndexCastOp>(
|
||||
op, getTypeConverter()->convertType(op.getType()), rank);
|
||||
return success();
|
||||
|
@ -155,6 +156,7 @@ public:
|
|||
void getDependentDialects(DialectRegistry ®istry) const override {
|
||||
registry.insert<StandardOpsDialect>();
|
||||
registry.insert<arith::ArithmeticDialect>();
|
||||
registry.insert<tensor::TensorDialect>();
|
||||
TorchConversion::getBackendTypeConversionDependentDialects(registry);
|
||||
}
|
||||
|
||||
|
@ -162,7 +164,7 @@ public:
|
|||
MLIRContext *context = &getContext();
|
||||
ConversionTarget target(*context);
|
||||
target.addLegalDialect<Torch::TorchDialect, StandardOpsDialect,
|
||||
arith::ArithmeticDialect>();
|
||||
arith::ArithmeticDialect, tensor::TensorDialect>();
|
||||
|
||||
TypeConverter typeConverter;
|
||||
typeConverter.addConversion([](Type type) { return type; });
|
||||
|
|
|
@ -111,7 +111,7 @@ public:
|
|||
for (auto operand : llvm::enumerate(adaptor.getOperands())) {
|
||||
if (operand.value().getType().isa<Torch::NoneType>())
|
||||
continue;
|
||||
auto it = typeBoundMap.find({call.callee(), operand.index()});
|
||||
auto it = typeBoundMap.find({call.getCallee(), operand.index()});
|
||||
if (it != typeBoundMap.end()) {
|
||||
if (auto valueTensorType = it->second.dyn_cast<ValueTensorType>()) {
|
||||
newOperands.push_back(copyTensorToType(
|
||||
|
@ -126,7 +126,7 @@ public:
|
|||
newOperands.push_back(operand.value());
|
||||
}
|
||||
|
||||
CallOp newCall = rewriter.create<CallOp>(call.getLoc(), call.callee(),
|
||||
CallOp newCall = rewriter.create<CallOp>(call.getLoc(), call.getCallee(),
|
||||
convertedResults, newOperands);
|
||||
int newOpResultIdx = 0;
|
||||
SmallVector<Value> newResults;
|
||||
|
|
|
@ -351,7 +351,7 @@ static LogicalResult analyzeInstances(FuncOp func,
|
|||
static FailureOr<Monomorphization>
|
||||
createMonomorphizationForCall(CallOp op, BlockAndValueMapping &mapping,
|
||||
SymbolTable &symbolTable) {
|
||||
auto func = symbolTable.lookup<FuncOp>(op.callee());
|
||||
auto func = symbolTable.lookup<FuncOp>(op.getCallee());
|
||||
Monomorphization monomorphization;
|
||||
monomorphization.func = func;
|
||||
for (auto operand : llvm::enumerate(op->getOperands())) {
|
||||
|
|
|
@ -9,7 +9,7 @@
|
|||
|
||||
#include "PassDetail.h"
|
||||
|
||||
#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
|
||||
#include "mlir/Dialect/Linalg/IR/Linalg.h"
|
||||
#include "mlir/Dialect/Math/IR/Math.h"
|
||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||
#include "mlir/IR/OpDefinition.h"
|
||||
|
|
|
@ -333,17 +333,17 @@ def raw_emit_op(operator: JitOperator, f: TextIO, *, traits: List[str],
|
|||
|
||||
if operator.is_vararg:
|
||||
assembly_operands = "`(` $operands `)`"
|
||||
assembly_operand_types = "type($operands)"
|
||||
assembly_operand_types = "qualified(type($operands))"
|
||||
else:
|
||||
assembly_operands = " `,` ".join("$" + arg["name"]
|
||||
for arg in operator.arguments)
|
||||
assembly_operand_types = " `,` ".join(
|
||||
f"""type(${arg["name"]})""" for arg in operator.arguments)
|
||||
f"""qualified(type(${arg["name"]}))""" for arg in operator.arguments)
|
||||
if operator.is_varret:
|
||||
assembly_result_types = "type($results)"
|
||||
assembly_result_types = "qualified(type($results))"
|
||||
else:
|
||||
assembly_result_types = " `,` ".join(
|
||||
f"""type(${ret["name"] or generic_result_name(e)})"""
|
||||
f"""qualified(type(${ret["name"] or generic_result_name(e)}))"""
|
||||
for e, ret in enumerate(operator.returns))
|
||||
if assembly_operand_types and assembly_result_types:
|
||||
maybe_arrow = " `->` "
|
||||
|
|
|
@ -194,7 +194,6 @@ MlirType torch_mlir::getMlirTypeFromTorchType(MlirLocation loc,
|
|||
return torchMlirTorchNoneTypeGet(context);
|
||||
}
|
||||
case TypeKind::AnyType: {
|
||||
auto anyType = torchType->cast<c10::AnyType>();
|
||||
return torchMlirTorchAnyTypeGet(context);
|
||||
}
|
||||
case TypeKind::ClassType: {
|
||||
|
|
|
@ -5,7 +5,7 @@
|
|||
// CHECK-LABEL: func @torch.aten.flatten.using_ints$basic(
|
||||
// CHECK-SAME: %[[TENSOR:.*]]: !torch.vtensor<[3,3,2,2,3,3,5],f32>) -> !torch.vtensor<[3,3,?,3,5],f32> {
|
||||
// CHECK: %[[BUILTIN_TENSOR:.*]] = torch_c.to_builtin_tensor %[[TENSOR]] : !torch.vtensor<[3,3,2,2,3,3,5],f32> -> tensor<3x3x2x2x3x3x5xf32>
|
||||
// CHECK: %[[COLLAPSED:.*]] = linalg.tensor_collapse_shape %[[BUILTIN_TENSOR]] {{\[\[}}0], [1], [2, 3, 4], [5], [6]] : tensor<3x3x2x2x3x3x5xf32> into tensor<3x3x12x3x5xf32>
|
||||
// CHECK: %[[COLLAPSED:.*]] = tensor.collapse_shape %[[BUILTIN_TENSOR]] {{\[\[}}0], [1], [2, 3, 4], [5], [6]] : tensor<3x3x2x2x3x3x5xf32> into tensor<3x3x12x3x5xf32>
|
||||
// CHECK: %[[DYNAMIC:.*]] = tensor.cast %[[COLLAPSED]] : tensor<3x3x12x3x5xf32> to tensor<3x3x?x3x5xf32>
|
||||
// CHECK: %[[RESULT:.*]] = torch_c.from_builtin_tensor %[[DYNAMIC]] : tensor<3x3x?x3x5xf32> -> !torch.vtensor<[3,3,?,3,5],f32>
|
||||
// CHECK: return %[[RESULT]] : !torch.vtensor<[3,3,?,3,5],f32>
|
||||
|
@ -22,7 +22,7 @@ func @torch.aten.flatten.using_ints$basic(%arg0: !torch.vtensor<[3,3,2,2,3,3,5],
|
|||
// CHECK-LABEL: func @torch.aten.flatten.using_ints$basic_negative(
|
||||
// CHECK-SAME: %[[TENSOR:.*]]: !torch.vtensor<[3,3,2,2,3,3,5],f32>) -> !torch.vtensor<[3,3,?,3,5],f32> {
|
||||
// CHECK: %[[BUILTIN_TENSOR:.*]] = torch_c.to_builtin_tensor %[[TENSOR]] : !torch.vtensor<[3,3,2,2,3,3,5],f32> -> tensor<3x3x2x2x3x3x5xf32>
|
||||
// CHECK: %[[COLLAPSED:.*]] = linalg.tensor_collapse_shape %[[BUILTIN_TENSOR]] {{\[\[}}0], [1], [2, 3, 4], [5], [6]] : tensor<3x3x2x2x3x3x5xf32> into tensor<3x3x12x3x5xf32>
|
||||
// CHECK: %[[COLLAPSED:.*]] = tensor.collapse_shape %[[BUILTIN_TENSOR]] {{\[\[}}0], [1], [2, 3, 4], [5], [6]] : tensor<3x3x2x2x3x3x5xf32> into tensor<3x3x12x3x5xf32>
|
||||
// CHECK: %[[DYNAMIC:.*]] = tensor.cast %[[COLLAPSED]] : tensor<3x3x12x3x5xf32> to tensor<3x3x?x3x5xf32>
|
||||
// CHECK: %[[RESULT:.*]] = torch_c.from_builtin_tensor %[[DYNAMIC]] : tensor<3x3x?x3x5xf32> -> !torch.vtensor<[3,3,?,3,5],f32>
|
||||
// CHECK: return %[[RESULT]] : !torch.vtensor<[3,3,?,3,5],f32>
|
||||
|
@ -39,7 +39,7 @@ func @torch.aten.flatten.using_ints$basic_negative(%arg0: !torch.vtensor<[3,3,2,
|
|||
// CHECK-LABEL: func @torch.aten.flatten.using_ints$flatten_front(
|
||||
// CHECK-SAME: %[[TENSOR:.*]]: !torch.vtensor<[3,3,2,2],f32>) -> !torch.vtensor<[?,?],f32> {
|
||||
// CHECK: %[[BUILTIN_TENSOR:.*]] = torch_c.to_builtin_tensor %[[TENSOR]] : !torch.vtensor<[3,3,2,2],f32> -> tensor<3x3x2x2xf32>
|
||||
// CHECK: %[[COLLAPSED:.*]] = linalg.tensor_collapse_shape %[[BUILTIN_TENSOR]] {{\[\[}}0, 1, 2], [3]] : tensor<3x3x2x2xf32> into tensor<18x2xf32>
|
||||
// CHECK: %[[COLLAPSED:.*]] = tensor.collapse_shape %[[BUILTIN_TENSOR]] {{\[\[}}0, 1, 2], [3]] : tensor<3x3x2x2xf32> into tensor<18x2xf32>
|
||||
// CHECK: %[[DYNAMIC:.*]] = tensor.cast %[[COLLAPSED]] : tensor<18x2xf32> to tensor<?x?xf32>
|
||||
// CHECK: %[[RESULT:.*]] = torch_c.from_builtin_tensor %[[DYNAMIC]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32>
|
||||
// CHECK: return %[[RESULT]] : !torch.vtensor<[?,?],f32>
|
||||
|
@ -56,7 +56,7 @@ func @torch.aten.flatten.using_ints$flatten_front(%arg0: !torch.vtensor<[3,3,2,2
|
|||
// CHECK-LABEL: func @torch.aten.flatten.using_ints$flatten_back(
|
||||
// CHECK-SAME: %[[TENSOR:.*]]: !torch.vtensor<[3,3,2,2],f32>) -> !torch.vtensor<[?,12],f32> {
|
||||
// CHECK: %[[BUILTIN_TENSOR:.*]] = torch_c.to_builtin_tensor %[[TENSOR]] : !torch.vtensor<[3,3,2,2],f32> -> tensor<3x3x2x2xf32>
|
||||
// CHECK: %[[COLLAPSED:.*]] = linalg.tensor_collapse_shape %[[BUILTIN_TENSOR]] {{\[\[}}0], [1, 2, 3]] : tensor<3x3x2x2xf32> into tensor<3x12xf32>
|
||||
// CHECK: %[[COLLAPSED:.*]] = tensor.collapse_shape %[[BUILTIN_TENSOR]] {{\[\[}}0], [1, 2, 3]] : tensor<3x3x2x2xf32> into tensor<3x12xf32>
|
||||
// CHECK: %[[DYNAMIC:.*]] = tensor.cast %[[COLLAPSED]] : tensor<3x12xf32> to tensor<?x12xf32>
|
||||
// CHECK: %[[RESULT:.*]] = torch_c.from_builtin_tensor %[[DYNAMIC]] : tensor<?x12xf32> -> !torch.vtensor<[?,12],f32>
|
||||
// CHECK: return %[[RESULT]] : !torch.vtensor<[?,12],f32>
|
||||
|
@ -73,7 +73,7 @@ func @torch.aten.flatten.using_ints$flatten_back(%arg0: !torch.vtensor<[3,3,2,2]
|
|||
// CHECK-LABEL: func @torch.aten.flatten.using_ints$rank0(
|
||||
// CHECK-SAME: %[[TENSOR:.*]]: !torch.vtensor<[],f32>) -> !torch.vtensor<[1],f32> {
|
||||
// CHECK: %[[BUILTIN_TENSOR:.*]] = torch_c.to_builtin_tensor %[[TENSOR]] : !torch.vtensor<[],f32> -> tensor<f32>
|
||||
// CHECK: %[[COLLAPSED:.*]] = linalg.tensor_expand_shape %[[BUILTIN_TENSOR]] [] : tensor<f32> into tensor<1xf32>
|
||||
// CHECK: %[[COLLAPSED:.*]] = tensor.expand_shape %[[BUILTIN_TENSOR]] [] : tensor<f32> into tensor<1xf32>
|
||||
// CHECK: %[[RESULT:.*]] = torch_c.from_builtin_tensor %[[COLLAPSED]] : tensor<1xf32> -> !torch.vtensor<[1],f32>
|
||||
// CHECK: return %[[RESULT]] : !torch.vtensor<[1],f32>
|
||||
|
||||
|
|
|
@ -6,7 +6,7 @@
|
|||
// CHECK-LABEL: func @torch.aten.unsqueeze$basic(
|
||||
// CHECK-SAME: %[[ARG:.*]]: !torch.vtensor<[],f32>) -> !torch.vtensor<[1],f32> {
|
||||
// CHECK: %[[BUILTIN_TENSOR:.*]] = torch_c.to_builtin_tensor %[[ARG]] : !torch.vtensor<[],f32> -> tensor<f32>
|
||||
// CHECK: %[[EXPANDED:.*]] = linalg.tensor_expand_shape %[[BUILTIN_TENSOR]] [] : tensor<f32> into tensor<1xf32>
|
||||
// CHECK: %[[EXPANDED:.*]] = tensor.expand_shape %[[BUILTIN_TENSOR]] [] : tensor<f32> into tensor<1xf32>
|
||||
// CHECK: %[[RESULT:.*]] = torch_c.from_builtin_tensor %[[EXPANDED]] : tensor<1xf32> -> !torch.vtensor<[1],f32>
|
||||
// CHECK: return %[[RESULT]] : !torch.vtensor<[1],f32>
|
||||
func @torch.aten.unsqueeze$basic(%arg0: !torch.vtensor<[],f32>) -> !torch.vtensor<[1],f32> {
|
||||
|
@ -18,7 +18,7 @@ func @torch.aten.unsqueeze$basic(%arg0: !torch.vtensor<[],f32>) -> !torch.vtenso
|
|||
// CHECK-LABEL: func @torch.aten.unsqueeze$basic_negative(
|
||||
// CHECK-SAME: %[[ARG:.*]]: !torch.vtensor<[],f32>) -> !torch.vtensor<[1],f32> {
|
||||
// CHECK: %[[BUILTIN_TENSOR:.*]] = torch_c.to_builtin_tensor %[[ARG]] : !torch.vtensor<[],f32> -> tensor<f32>
|
||||
// CHECK: %[[EXPANDED:.*]] = linalg.tensor_expand_shape %[[BUILTIN_TENSOR]] [] : tensor<f32> into tensor<1xf32>
|
||||
// CHECK: %[[EXPANDED:.*]] = tensor.expand_shape %[[BUILTIN_TENSOR]] [] : tensor<f32> into tensor<1xf32>
|
||||
// CHECK: %[[RESULT:.*]] = torch_c.from_builtin_tensor %[[EXPANDED]] : tensor<1xf32> -> !torch.vtensor<[1],f32>
|
||||
// CHECK: return %[[RESULT]] : !torch.vtensor<[1],f32>
|
||||
func @torch.aten.unsqueeze$basic_negative(%arg0: !torch.vtensor<[],f32>) -> !torch.vtensor<[1],f32> {
|
||||
|
@ -30,7 +30,7 @@ func @torch.aten.unsqueeze$basic_negative(%arg0: !torch.vtensor<[],f32>) -> !tor
|
|||
// CHECK-LABEL: func @torch.aten.unsqueeze$higher_rank_front(
|
||||
// CHECK-SAME: %[[ARG:.*]]: !torch.vtensor<[2,3,4],f32>) -> !torch.vtensor<[1,2,3,4],f32> {
|
||||
// CHECK: %[[BUILTIN_TENSOR:.*]] = torch_c.to_builtin_tensor %[[ARG]] : !torch.vtensor<[2,3,4],f32> -> tensor<2x3x4xf32>
|
||||
// CHECK: %[[EXPANDED:.*]] = linalg.tensor_expand_shape %[[BUILTIN_TENSOR]] {{\[\[}}0, 1], [2], [3]] : tensor<2x3x4xf32> into tensor<1x2x3x4xf32>
|
||||
// CHECK: %[[EXPANDED:.*]] = tensor.expand_shape %[[BUILTIN_TENSOR]] {{\[\[}}0, 1], [2], [3]] : tensor<2x3x4xf32> into tensor<1x2x3x4xf32>
|
||||
// CHECK: %[[RESULT:.*]] = torch_c.from_builtin_tensor %[[EXPANDED]] : tensor<1x2x3x4xf32> -> !torch.vtensor<[1,2,3,4],f32>
|
||||
// CHECK: return %[[RESULT]] : !torch.vtensor<[1,2,3,4],f32>
|
||||
func @torch.aten.unsqueeze$higher_rank_front(%arg0: !torch.vtensor<[2,3,4],f32>) -> !torch.vtensor<[1,2,3,4],f32> {
|
||||
|
@ -42,7 +42,7 @@ func @torch.aten.unsqueeze$higher_rank_front(%arg0: !torch.vtensor<[2,3,4],f32>)
|
|||
// CHECK-LABEL: func @torch.aten.unsqueeze$higher_rank_back(
|
||||
// CHECK-SAME: %[[ARG:.*]]: !torch.vtensor<[2,3,4],f32>) -> !torch.vtensor<[2,3,4,1],f32> {
|
||||
// CHECK: %[[BUILTIN_TENSOR:.*]] = torch_c.to_builtin_tensor %[[ARG]] : !torch.vtensor<[2,3,4],f32> -> tensor<2x3x4xf32>
|
||||
// CHECK: %[[EXPANDED:.*]] = linalg.tensor_expand_shape %[[BUILTIN_TENSOR]] {{\[\[}}0], [1], [2, 3]] : tensor<2x3x4xf32> into tensor<2x3x4x1xf32>
|
||||
// CHECK: %[[EXPANDED:.*]] = tensor.expand_shape %[[BUILTIN_TENSOR]] {{\[\[}}0], [1], [2, 3]] : tensor<2x3x4xf32> into tensor<2x3x4x1xf32>
|
||||
// CHECK: %[[RESULT:.*]] = torch_c.from_builtin_tensor %[[EXPANDED]] : tensor<2x3x4x1xf32> -> !torch.vtensor<[2,3,4,1],f32>
|
||||
// CHECK: return %[[RESULT]] : !torch.vtensor<[2,3,4,1],f32>
|
||||
func @torch.aten.unsqueeze$higher_rank_back(%arg0: !torch.vtensor<[2,3,4],f32>) -> !torch.vtensor<[2,3,4,1],f32> {
|
||||
|
@ -54,7 +54,7 @@ func @torch.aten.unsqueeze$higher_rank_back(%arg0: !torch.vtensor<[2,3,4],f32>)
|
|||
// CHECK-LABEL: func @torch.aten.unsqueeze$higher_rank_middle(
|
||||
// CHECK-SAME: %[[ARG:.*]]: !torch.vtensor<[2,3,4],f32>) -> !torch.vtensor<[2,3,1,4],f32> {
|
||||
// CHECK: %[[BUILTIN_TENSOR:.*]] = torch_c.to_builtin_tensor %[[ARG]] : !torch.vtensor<[2,3,4],f32> -> tensor<2x3x4xf32>
|
||||
// CHECK: %[[EXPANDED:.*]] = linalg.tensor_expand_shape %[[BUILTIN_TENSOR]] {{\[\[}}0], [1], [2, 3]] : tensor<2x3x4xf32> into tensor<2x3x1x4xf32>
|
||||
// CHECK: %[[EXPANDED:.*]] = tensor.expand_shape %[[BUILTIN_TENSOR]] {{\[\[}}0], [1], [2, 3]] : tensor<2x3x4xf32> into tensor<2x3x1x4xf32>
|
||||
// CHECK: %[[RESULT:.*]] = torch_c.from_builtin_tensor %[[EXPANDED]] : tensor<2x3x1x4xf32> -> !torch.vtensor<[2,3,1,4],f32>
|
||||
// CHECK: return %[[RESULT]] : !torch.vtensor<[2,3,1,4],f32>
|
||||
func @torch.aten.unsqueeze$higher_rank_middle(%arg0: !torch.vtensor<[2,3,4],f32>) -> !torch.vtensor<[2,3,1,4],f32> {
|
||||
|
|
|
@ -4,7 +4,7 @@
|
|||
// CHECK-LABEL: func @torch.aten.dim(
|
||||
// CHECK-SAME: %[[ARG:.*]]: !torch.vtensor<*,f32>) -> !torch.int {
|
||||
// CHECK: %[[BUILTIN_TENSOR:.*]] = torch_c.to_builtin_tensor %[[ARG]] : !torch.vtensor<*,f32> -> tensor<*xf32>
|
||||
// CHECK: %[[RANK:.*]] = rank %[[BUILTIN_TENSOR]] : tensor<*xf32>
|
||||
// CHECK: %[[RANK:.*]] = tensor.rank %[[BUILTIN_TENSOR]] : tensor<*xf32>
|
||||
// CHECK: %[[RANK_I64:.*]] = arith.index_cast %[[RANK]] : index to i64
|
||||
// CHECK: %[[RANK_TORCH_INT:.*]] = torch_c.from_i64 %[[RANK_I64]]
|
||||
// CHECK: return %[[RANK_TORCH_INT]] : !torch.int
|
||||
|
|
Loading…
Reference in New Issue