mirror of https://github.com/llvm/torch-mlir
Undo shape lib changes + update function signature of sum + zero (#1035)
This commit does three things: 1. Reverts some of the shape lib changes merged in https://github.com/llvm/torch-mlir/pull/844 2. Updates the signature of `aten.sum_dim_IntList` that was recently updated inpull/1036/head23bdb570cf
3. Replaces `aten.zero.functional` with `aten.zero`, updated in960758b0b7
parent
2d75654b2c
commit
11148e60d6
|
@ -2346,12 +2346,12 @@ def Torch_AtenUnsqueeze_Op : Torch_Op<"aten.unsqueeze_", [
|
|||
}];
|
||||
}
|
||||
|
||||
def Torch_AtenZeroFunctionalOp : Torch_Op<"aten.zero.functional", [
|
||||
def Torch_AtenZeroOp : Torch_Op<"aten.zero", [
|
||||
AllowsTypeRefinement,
|
||||
HasValueSemantics,
|
||||
ReadOnly
|
||||
]> {
|
||||
let summary = "Generated op for `aten::zero.functional : (Tensor) -> (Tensor)`";
|
||||
let summary = "Generated op for `aten::zero : (Tensor) -> (Tensor)`";
|
||||
let arguments = (ins
|
||||
AnyTorchTensorType:$self
|
||||
);
|
||||
|
@ -2360,16 +2360,17 @@ def Torch_AtenZeroFunctionalOp : Torch_Op<"aten.zero.functional", [
|
|||
);
|
||||
let hasCustomAssemblyFormat = 1;
|
||||
let extraClassDefinition = [{
|
||||
ParseResult AtenZeroFunctionalOp::parse(OpAsmParser &parser, OperationState &result) {
|
||||
ParseResult AtenZeroOp::parse(OpAsmParser &parser, OperationState &result) {
|
||||
return parseDefaultTorchOp(parser, result, 1, 1);
|
||||
}
|
||||
void AtenZeroFunctionalOp::print(OpAsmPrinter &printer) {
|
||||
void AtenZeroOp::print(OpAsmPrinter &printer) {
|
||||
printDefaultTorchOp(printer, *this, 1, 1);
|
||||
}
|
||||
}];
|
||||
}
|
||||
|
||||
def Torch_AtenZero_Op : Torch_Op<"aten.zero_", [
|
||||
IsTrailingUnderscoreInplaceVariant,
|
||||
AllowsTypeRefinement
|
||||
]> {
|
||||
let summary = "Generated op for `aten::zero_ : (Tensor) -> (Tensor)`";
|
||||
|
@ -5396,10 +5397,10 @@ def Torch_AtenSumDimIntListOp : Torch_Op<"aten.sum.dim_IntList", [
|
|||
HasValueSemantics,
|
||||
ReadOnly
|
||||
]> {
|
||||
let summary = "Generated op for `aten::sum.dim_IntList : (Tensor, int[], bool, int?) -> (Tensor)`";
|
||||
let summary = "Generated op for `aten::sum.dim_IntList : (Tensor, int[]?, bool, int?) -> (Tensor)`";
|
||||
let arguments = (ins
|
||||
AnyTorchTensorType:$self,
|
||||
AnyTorchListOfTorchIntType:$dim,
|
||||
AnyTorchOptionalListOfTorchIntType:$dim,
|
||||
Torch_BoolType:$keepdim,
|
||||
AnyTorchOptionalIntType:$dtype
|
||||
);
|
||||
|
|
|
@ -224,11 +224,11 @@ public:
|
|||
} // namespace
|
||||
|
||||
namespace {
|
||||
class DecomposeAtenZeroFunctionalOp
|
||||
: public OpRewritePattern<AtenZeroFunctionalOp> {
|
||||
class DecomposeAtenZeroOp
|
||||
: public OpRewritePattern<AtenZeroOp> {
|
||||
public:
|
||||
using OpRewritePattern::OpRewritePattern;
|
||||
LogicalResult matchAndRewrite(AtenZeroFunctionalOp op,
|
||||
LogicalResult matchAndRewrite(AtenZeroOp op,
|
||||
PatternRewriter &rewriter) const override {
|
||||
Value zero = rewriter.create<ConstantIntOp>(op.getLoc(),
|
||||
rewriter.getI64IntegerAttr(0));
|
||||
|
@ -2272,8 +2272,8 @@ class DecomposeComplexOpsPass
|
|||
target.addIllegalOp<ValsemVariantAtenBernoulliFloatOp>();
|
||||
patterns.add<DecomposeValsemVariantAtenBernoulliTensorOp>(context);
|
||||
target.addIllegalOp<ValsemVariantAtenBernoulliTensorOp>();
|
||||
patterns.add<DecomposeAtenZeroFunctionalOp>(context);
|
||||
target.addIllegalOp<AtenZeroFunctionalOp>();
|
||||
patterns.add<DecomposeAtenZeroOp>(context);
|
||||
target.addIllegalOp<AtenZeroOp>();
|
||||
patterns.add<DecomposeAtenRandLikeOp>(context);
|
||||
target.addIllegalOp<AtenRandLikeOp>();
|
||||
patterns.add<DecomposeAtenHardsigmoidOp>(context);
|
||||
|
|
|
@ -183,9 +183,6 @@ public:
|
|||
} else if (isa<AtenBernoulli_TensorOp>(op)) {
|
||||
newOp = rewriter.create<ValsemVariantAtenBernoulliTensorOp>(
|
||||
loc, op->getResultTypes(), op->getOperands());
|
||||
} else if (isa<AtenZero_Op>(op)) {
|
||||
newOp = rewriter.create<AtenZeroFunctionalOp>(
|
||||
loc, op->getResultTypes(), op->getOperands());
|
||||
} else if (isa<AtenFill_ScalarOp>(op)) {
|
||||
newOp = rewriter.create<ValsemVariantAtenFillScalarOp>(
|
||||
loc, op->getResultTypes(), op->getOperands());
|
||||
|
@ -273,7 +270,6 @@ class ReduceOpVariantsPass : public ReduceOpVariantsBase<ReduceOpVariantsPass> {
|
|||
target.addIllegalOp<AtenUniform_Op>();
|
||||
target.addIllegalOp<AtenBernoulli_FloatOp>();
|
||||
target.addIllegalOp<AtenBernoulli_TensorOp>();
|
||||
target.addIllegalOp<AtenZero_Op>();
|
||||
target.addIllegalOp<AtenFill_ScalarOp>();
|
||||
target.addIllegalOp<Aten_IndexPutImpl_Op>();
|
||||
target.addIllegalOp<AtenCopy_Op>();
|
||||
|
|
|
@ -649,7 +649,7 @@ ChangeResult TypeAnalyzer::visitOperation(
|
|||
AtenGatherOp, AtenExpandOp, AtenExpandAsOp, AtenBroadcastToOp,
|
||||
AtenRepeatOp, AtenConstantPadNdOp, AtenPadOp, AtenZero_Op,
|
||||
AtenIndexTensorOp, ValsemVariantAtenIndexPutImplOp, AtenIndexPutOp,
|
||||
ValsemVariantAtenCopyOp, AtenZeroFunctionalOp,
|
||||
ValsemVariantAtenCopyOp, AtenZeroOp,
|
||||
AtenIndexPutHackedTwinOp, AtenMaskedFillScalarOp, AtenFlipOp,
|
||||
PrimAbsScalarOp, AtenNumpyTOp, AtenTriuOp>(op)) {
|
||||
return incorporateKnowledge(op->getResult(0), operands[0]->getValue());
|
||||
|
|
File diff suppressed because it is too large
Load Diff
|
@ -519,9 +519,11 @@ def aten〇max〇dim(self: List[int], dim: int, keepdim: bool = False) -> Tuple[
|
|||
def aten〇mean〇dim(self: List[int], dim: List[int], keepdim: bool = False, dtype: Optional[int] = None) -> List[int]:
|
||||
return upstream_shape_functions.mean_dim(self, dim, keepdim, dtype)
|
||||
|
||||
def aten〇sum〇dim_IntList(self: List[int], dim: List[int], keepdim: bool = False, dtype: Optional[int] = None) -> List[int]:
|
||||
return upstream_shape_functions.mean_dim(self, dim, keepdim, dtype)
|
||||
|
||||
def aten〇sum〇dim_IntList(self: List[int], dim: Optional[List[int]], keepdim: bool = False, dtype: Optional[int] = None) -> List[int]:
|
||||
if dim is None:
|
||||
return upstream_shape_functions.mean_dim(self, [], keepdim, dtype)
|
||||
else:
|
||||
return upstream_shape_functions.mean_dim(self, dim, keepdim, dtype)
|
||||
|
||||
def aten〇permute(self: List[int], dims: List[int]) -> List[int]:
|
||||
return upstream_shape_functions.permute(self, dims)
|
||||
|
@ -717,13 +719,9 @@ def aten〇_to_copy(self: List[int], dtype: Optional[int] = None, layout: Option
|
|||
def aten〇masked_fill〇Scalar(self: List[int], mask: List[int], value: float) -> List[int]:
|
||||
return upstream_shape_functions.unary(self)
|
||||
|
||||
@not_present_in_registry
|
||||
def aten〇zero(self: List[int]) -> List[int]:
|
||||
return self
|
||||
|
||||
def aten〇zero〇functional(self: List[int]) -> List[int]:
|
||||
return self
|
||||
|
||||
@not_present_in_registry
|
||||
def aten〇fill〇Scalar(self: List[int], value: float) -> List[int]:
|
||||
return self
|
||||
|
|
|
@ -288,7 +288,7 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
|
|||
"aten::threshold : (Tensor, Scalar, Scalar) -> (Tensor)",
|
||||
"aten::square : (Tensor) -> (Tensor)",
|
||||
"aten::unsqueeze : (Tensor, int) -> (Tensor)",
|
||||
"aten::zero.functional : (Tensor) -> (Tensor)",
|
||||
"aten::zero : (Tensor) -> (Tensor)",
|
||||
]:
|
||||
emit_with_mutating_variants(key)
|
||||
# Elementwise tensor compute ops that don't have the standard mutating
|
||||
|
@ -441,7 +441,7 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
|
|||
emit("aten::size.int : (Tensor, int) -> (int)", has_folder=True)
|
||||
emit("aten::stack : (Tensor[], int) -> (Tensor)")
|
||||
emit("aten::sum : (Tensor, int?) -> (Tensor)")
|
||||
emit("aten::sum.dim_IntList : (Tensor, int[], bool, int?) -> (Tensor)")
|
||||
emit("aten::sum.dim_IntList : (Tensor, int[]?, bool, int?) -> (Tensor)")
|
||||
emit("aten::max : (Tensor) -> (Tensor)")
|
||||
emit("aten::max.dim : (Tensor, int, bool) -> (Tensor, Tensor)")
|
||||
emit("aten::to.dtype : (Tensor, int, bool, bool, int?) -> (Tensor)", has_folder=True)
|
||||
|
|
|
@ -829,13 +829,13 @@ func.func @torch.aten.dropout$train(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.
|
|||
}
|
||||
|
||||
// -----
|
||||
// CHECK-LABEL: func.func @torch.aten.zero.functional(
|
||||
// CHECK-LABEL: func.func @torch.aten.zero(
|
||||
// CHECK-SAME: %[[INP:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
|
||||
// CHECK: %[[ZERO:.*]] = torch.constant.int 0
|
||||
// CHECK: %[[OUT:.*]] = torch.valsem.aten.fill.Scalar %[[INP]], %[[ZERO]] : !torch.vtensor<[?,?],f32>, !torch.int -> !torch.vtensor<[?,?],f32>
|
||||
// CHECK: return %[[OUT]] : !torch.vtensor<[?,?],f32>
|
||||
func.func @torch.aten.zero.functional(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
|
||||
%0 = torch.aten.zero.functional %arg0 : !torch.vtensor<[?,?],f32> -> !torch.vtensor<[?,?],f32>
|
||||
func.func @torch.aten.zero(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
|
||||
%0 = torch.aten.zero %arg0 : !torch.vtensor<[?,?],f32> -> !torch.vtensor<[?,?],f32>
|
||||
return %0 : !torch.vtensor<[?,?],f32>
|
||||
}
|
||||
|
||||
|
|
Loading…
Reference in New Issue