mirror of https://github.com/llvm/torch-mlir
Undo shape lib changes + update function signature of sum + zero (#1035)
This commit does three things: 1. Reverts some of the shape lib changes merged in https://github.com/llvm/torch-mlir/pull/844 2. Updates the signature of `aten.sum_dim_IntList` that was recently updated inpull/1036/head23bdb570cf
3. Replaces `aten.zero.functional` with `aten.zero`, updated in960758b0b7
parent
2d75654b2c
commit
11148e60d6
|
@ -2346,12 +2346,12 @@ def Torch_AtenUnsqueeze_Op : Torch_Op<"aten.unsqueeze_", [
|
||||||
}];
|
}];
|
||||||
}
|
}
|
||||||
|
|
||||||
def Torch_AtenZeroFunctionalOp : Torch_Op<"aten.zero.functional", [
|
def Torch_AtenZeroOp : Torch_Op<"aten.zero", [
|
||||||
AllowsTypeRefinement,
|
AllowsTypeRefinement,
|
||||||
HasValueSemantics,
|
HasValueSemantics,
|
||||||
ReadOnly
|
ReadOnly
|
||||||
]> {
|
]> {
|
||||||
let summary = "Generated op for `aten::zero.functional : (Tensor) -> (Tensor)`";
|
let summary = "Generated op for `aten::zero : (Tensor) -> (Tensor)`";
|
||||||
let arguments = (ins
|
let arguments = (ins
|
||||||
AnyTorchTensorType:$self
|
AnyTorchTensorType:$self
|
||||||
);
|
);
|
||||||
|
@ -2360,16 +2360,17 @@ def Torch_AtenZeroFunctionalOp : Torch_Op<"aten.zero.functional", [
|
||||||
);
|
);
|
||||||
let hasCustomAssemblyFormat = 1;
|
let hasCustomAssemblyFormat = 1;
|
||||||
let extraClassDefinition = [{
|
let extraClassDefinition = [{
|
||||||
ParseResult AtenZeroFunctionalOp::parse(OpAsmParser &parser, OperationState &result) {
|
ParseResult AtenZeroOp::parse(OpAsmParser &parser, OperationState &result) {
|
||||||
return parseDefaultTorchOp(parser, result, 1, 1);
|
return parseDefaultTorchOp(parser, result, 1, 1);
|
||||||
}
|
}
|
||||||
void AtenZeroFunctionalOp::print(OpAsmPrinter &printer) {
|
void AtenZeroOp::print(OpAsmPrinter &printer) {
|
||||||
printDefaultTorchOp(printer, *this, 1, 1);
|
printDefaultTorchOp(printer, *this, 1, 1);
|
||||||
}
|
}
|
||||||
}];
|
}];
|
||||||
}
|
}
|
||||||
|
|
||||||
def Torch_AtenZero_Op : Torch_Op<"aten.zero_", [
|
def Torch_AtenZero_Op : Torch_Op<"aten.zero_", [
|
||||||
|
IsTrailingUnderscoreInplaceVariant,
|
||||||
AllowsTypeRefinement
|
AllowsTypeRefinement
|
||||||
]> {
|
]> {
|
||||||
let summary = "Generated op for `aten::zero_ : (Tensor) -> (Tensor)`";
|
let summary = "Generated op for `aten::zero_ : (Tensor) -> (Tensor)`";
|
||||||
|
@ -5396,10 +5397,10 @@ def Torch_AtenSumDimIntListOp : Torch_Op<"aten.sum.dim_IntList", [
|
||||||
HasValueSemantics,
|
HasValueSemantics,
|
||||||
ReadOnly
|
ReadOnly
|
||||||
]> {
|
]> {
|
||||||
let summary = "Generated op for `aten::sum.dim_IntList : (Tensor, int[], bool, int?) -> (Tensor)`";
|
let summary = "Generated op for `aten::sum.dim_IntList : (Tensor, int[]?, bool, int?) -> (Tensor)`";
|
||||||
let arguments = (ins
|
let arguments = (ins
|
||||||
AnyTorchTensorType:$self,
|
AnyTorchTensorType:$self,
|
||||||
AnyTorchListOfTorchIntType:$dim,
|
AnyTorchOptionalListOfTorchIntType:$dim,
|
||||||
Torch_BoolType:$keepdim,
|
Torch_BoolType:$keepdim,
|
||||||
AnyTorchOptionalIntType:$dtype
|
AnyTorchOptionalIntType:$dtype
|
||||||
);
|
);
|
||||||
|
|
|
@ -224,11 +224,11 @@ public:
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
class DecomposeAtenZeroFunctionalOp
|
class DecomposeAtenZeroOp
|
||||||
: public OpRewritePattern<AtenZeroFunctionalOp> {
|
: public OpRewritePattern<AtenZeroOp> {
|
||||||
public:
|
public:
|
||||||
using OpRewritePattern::OpRewritePattern;
|
using OpRewritePattern::OpRewritePattern;
|
||||||
LogicalResult matchAndRewrite(AtenZeroFunctionalOp op,
|
LogicalResult matchAndRewrite(AtenZeroOp op,
|
||||||
PatternRewriter &rewriter) const override {
|
PatternRewriter &rewriter) const override {
|
||||||
Value zero = rewriter.create<ConstantIntOp>(op.getLoc(),
|
Value zero = rewriter.create<ConstantIntOp>(op.getLoc(),
|
||||||
rewriter.getI64IntegerAttr(0));
|
rewriter.getI64IntegerAttr(0));
|
||||||
|
@ -2272,8 +2272,8 @@ class DecomposeComplexOpsPass
|
||||||
target.addIllegalOp<ValsemVariantAtenBernoulliFloatOp>();
|
target.addIllegalOp<ValsemVariantAtenBernoulliFloatOp>();
|
||||||
patterns.add<DecomposeValsemVariantAtenBernoulliTensorOp>(context);
|
patterns.add<DecomposeValsemVariantAtenBernoulliTensorOp>(context);
|
||||||
target.addIllegalOp<ValsemVariantAtenBernoulliTensorOp>();
|
target.addIllegalOp<ValsemVariantAtenBernoulliTensorOp>();
|
||||||
patterns.add<DecomposeAtenZeroFunctionalOp>(context);
|
patterns.add<DecomposeAtenZeroOp>(context);
|
||||||
target.addIllegalOp<AtenZeroFunctionalOp>();
|
target.addIllegalOp<AtenZeroOp>();
|
||||||
patterns.add<DecomposeAtenRandLikeOp>(context);
|
patterns.add<DecomposeAtenRandLikeOp>(context);
|
||||||
target.addIllegalOp<AtenRandLikeOp>();
|
target.addIllegalOp<AtenRandLikeOp>();
|
||||||
patterns.add<DecomposeAtenHardsigmoidOp>(context);
|
patterns.add<DecomposeAtenHardsigmoidOp>(context);
|
||||||
|
|
|
@ -183,9 +183,6 @@ public:
|
||||||
} else if (isa<AtenBernoulli_TensorOp>(op)) {
|
} else if (isa<AtenBernoulli_TensorOp>(op)) {
|
||||||
newOp = rewriter.create<ValsemVariantAtenBernoulliTensorOp>(
|
newOp = rewriter.create<ValsemVariantAtenBernoulliTensorOp>(
|
||||||
loc, op->getResultTypes(), op->getOperands());
|
loc, op->getResultTypes(), op->getOperands());
|
||||||
} else if (isa<AtenZero_Op>(op)) {
|
|
||||||
newOp = rewriter.create<AtenZeroFunctionalOp>(
|
|
||||||
loc, op->getResultTypes(), op->getOperands());
|
|
||||||
} else if (isa<AtenFill_ScalarOp>(op)) {
|
} else if (isa<AtenFill_ScalarOp>(op)) {
|
||||||
newOp = rewriter.create<ValsemVariantAtenFillScalarOp>(
|
newOp = rewriter.create<ValsemVariantAtenFillScalarOp>(
|
||||||
loc, op->getResultTypes(), op->getOperands());
|
loc, op->getResultTypes(), op->getOperands());
|
||||||
|
@ -273,7 +270,6 @@ class ReduceOpVariantsPass : public ReduceOpVariantsBase<ReduceOpVariantsPass> {
|
||||||
target.addIllegalOp<AtenUniform_Op>();
|
target.addIllegalOp<AtenUniform_Op>();
|
||||||
target.addIllegalOp<AtenBernoulli_FloatOp>();
|
target.addIllegalOp<AtenBernoulli_FloatOp>();
|
||||||
target.addIllegalOp<AtenBernoulli_TensorOp>();
|
target.addIllegalOp<AtenBernoulli_TensorOp>();
|
||||||
target.addIllegalOp<AtenZero_Op>();
|
|
||||||
target.addIllegalOp<AtenFill_ScalarOp>();
|
target.addIllegalOp<AtenFill_ScalarOp>();
|
||||||
target.addIllegalOp<Aten_IndexPutImpl_Op>();
|
target.addIllegalOp<Aten_IndexPutImpl_Op>();
|
||||||
target.addIllegalOp<AtenCopy_Op>();
|
target.addIllegalOp<AtenCopy_Op>();
|
||||||
|
|
|
@ -649,7 +649,7 @@ ChangeResult TypeAnalyzer::visitOperation(
|
||||||
AtenGatherOp, AtenExpandOp, AtenExpandAsOp, AtenBroadcastToOp,
|
AtenGatherOp, AtenExpandOp, AtenExpandAsOp, AtenBroadcastToOp,
|
||||||
AtenRepeatOp, AtenConstantPadNdOp, AtenPadOp, AtenZero_Op,
|
AtenRepeatOp, AtenConstantPadNdOp, AtenPadOp, AtenZero_Op,
|
||||||
AtenIndexTensorOp, ValsemVariantAtenIndexPutImplOp, AtenIndexPutOp,
|
AtenIndexTensorOp, ValsemVariantAtenIndexPutImplOp, AtenIndexPutOp,
|
||||||
ValsemVariantAtenCopyOp, AtenZeroFunctionalOp,
|
ValsemVariantAtenCopyOp, AtenZeroOp,
|
||||||
AtenIndexPutHackedTwinOp, AtenMaskedFillScalarOp, AtenFlipOp,
|
AtenIndexPutHackedTwinOp, AtenMaskedFillScalarOp, AtenFlipOp,
|
||||||
PrimAbsScalarOp, AtenNumpyTOp, AtenTriuOp>(op)) {
|
PrimAbsScalarOp, AtenNumpyTOp, AtenTriuOp>(op)) {
|
||||||
return incorporateKnowledge(op->getResult(0), operands[0]->getValue());
|
return incorporateKnowledge(op->getResult(0), operands[0]->getValue());
|
||||||
|
|
File diff suppressed because it is too large
Load Diff
|
@ -519,9 +519,11 @@ def aten〇max〇dim(self: List[int], dim: int, keepdim: bool = False) -> Tuple[
|
||||||
def aten〇mean〇dim(self: List[int], dim: List[int], keepdim: bool = False, dtype: Optional[int] = None) -> List[int]:
|
def aten〇mean〇dim(self: List[int], dim: List[int], keepdim: bool = False, dtype: Optional[int] = None) -> List[int]:
|
||||||
return upstream_shape_functions.mean_dim(self, dim, keepdim, dtype)
|
return upstream_shape_functions.mean_dim(self, dim, keepdim, dtype)
|
||||||
|
|
||||||
def aten〇sum〇dim_IntList(self: List[int], dim: List[int], keepdim: bool = False, dtype: Optional[int] = None) -> List[int]:
|
def aten〇sum〇dim_IntList(self: List[int], dim: Optional[List[int]], keepdim: bool = False, dtype: Optional[int] = None) -> List[int]:
|
||||||
return upstream_shape_functions.mean_dim(self, dim, keepdim, dtype)
|
if dim is None:
|
||||||
|
return upstream_shape_functions.mean_dim(self, [], keepdim, dtype)
|
||||||
|
else:
|
||||||
|
return upstream_shape_functions.mean_dim(self, dim, keepdim, dtype)
|
||||||
|
|
||||||
def aten〇permute(self: List[int], dims: List[int]) -> List[int]:
|
def aten〇permute(self: List[int], dims: List[int]) -> List[int]:
|
||||||
return upstream_shape_functions.permute(self, dims)
|
return upstream_shape_functions.permute(self, dims)
|
||||||
|
@ -717,13 +719,9 @@ def aten〇_to_copy(self: List[int], dtype: Optional[int] = None, layout: Option
|
||||||
def aten〇masked_fill〇Scalar(self: List[int], mask: List[int], value: float) -> List[int]:
|
def aten〇masked_fill〇Scalar(self: List[int], mask: List[int], value: float) -> List[int]:
|
||||||
return upstream_shape_functions.unary(self)
|
return upstream_shape_functions.unary(self)
|
||||||
|
|
||||||
@not_present_in_registry
|
|
||||||
def aten〇zero(self: List[int]) -> List[int]:
|
def aten〇zero(self: List[int]) -> List[int]:
|
||||||
return self
|
return self
|
||||||
|
|
||||||
def aten〇zero〇functional(self: List[int]) -> List[int]:
|
|
||||||
return self
|
|
||||||
|
|
||||||
@not_present_in_registry
|
@not_present_in_registry
|
||||||
def aten〇fill〇Scalar(self: List[int], value: float) -> List[int]:
|
def aten〇fill〇Scalar(self: List[int], value: float) -> List[int]:
|
||||||
return self
|
return self
|
||||||
|
|
|
@ -288,7 +288,7 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
|
||||||
"aten::threshold : (Tensor, Scalar, Scalar) -> (Tensor)",
|
"aten::threshold : (Tensor, Scalar, Scalar) -> (Tensor)",
|
||||||
"aten::square : (Tensor) -> (Tensor)",
|
"aten::square : (Tensor) -> (Tensor)",
|
||||||
"aten::unsqueeze : (Tensor, int) -> (Tensor)",
|
"aten::unsqueeze : (Tensor, int) -> (Tensor)",
|
||||||
"aten::zero.functional : (Tensor) -> (Tensor)",
|
"aten::zero : (Tensor) -> (Tensor)",
|
||||||
]:
|
]:
|
||||||
emit_with_mutating_variants(key)
|
emit_with_mutating_variants(key)
|
||||||
# Elementwise tensor compute ops that don't have the standard mutating
|
# Elementwise tensor compute ops that don't have the standard mutating
|
||||||
|
@ -441,7 +441,7 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
|
||||||
emit("aten::size.int : (Tensor, int) -> (int)", has_folder=True)
|
emit("aten::size.int : (Tensor, int) -> (int)", has_folder=True)
|
||||||
emit("aten::stack : (Tensor[], int) -> (Tensor)")
|
emit("aten::stack : (Tensor[], int) -> (Tensor)")
|
||||||
emit("aten::sum : (Tensor, int?) -> (Tensor)")
|
emit("aten::sum : (Tensor, int?) -> (Tensor)")
|
||||||
emit("aten::sum.dim_IntList : (Tensor, int[], bool, int?) -> (Tensor)")
|
emit("aten::sum.dim_IntList : (Tensor, int[]?, bool, int?) -> (Tensor)")
|
||||||
emit("aten::max : (Tensor) -> (Tensor)")
|
emit("aten::max : (Tensor) -> (Tensor)")
|
||||||
emit("aten::max.dim : (Tensor, int, bool) -> (Tensor, Tensor)")
|
emit("aten::max.dim : (Tensor, int, bool) -> (Tensor, Tensor)")
|
||||||
emit("aten::to.dtype : (Tensor, int, bool, bool, int?) -> (Tensor)", has_folder=True)
|
emit("aten::to.dtype : (Tensor, int, bool, bool, int?) -> (Tensor)", has_folder=True)
|
||||||
|
|
|
@ -829,13 +829,13 @@ func.func @torch.aten.dropout$train(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.
|
||||||
}
|
}
|
||||||
|
|
||||||
// -----
|
// -----
|
||||||
// CHECK-LABEL: func.func @torch.aten.zero.functional(
|
// CHECK-LABEL: func.func @torch.aten.zero(
|
||||||
// CHECK-SAME: %[[INP:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
|
// CHECK-SAME: %[[INP:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
|
||||||
// CHECK: %[[ZERO:.*]] = torch.constant.int 0
|
// CHECK: %[[ZERO:.*]] = torch.constant.int 0
|
||||||
// CHECK: %[[OUT:.*]] = torch.valsem.aten.fill.Scalar %[[INP]], %[[ZERO]] : !torch.vtensor<[?,?],f32>, !torch.int -> !torch.vtensor<[?,?],f32>
|
// CHECK: %[[OUT:.*]] = torch.valsem.aten.fill.Scalar %[[INP]], %[[ZERO]] : !torch.vtensor<[?,?],f32>, !torch.int -> !torch.vtensor<[?,?],f32>
|
||||||
// CHECK: return %[[OUT]] : !torch.vtensor<[?,?],f32>
|
// CHECK: return %[[OUT]] : !torch.vtensor<[?,?],f32>
|
||||||
func.func @torch.aten.zero.functional(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
|
func.func @torch.aten.zero(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
|
||||||
%0 = torch.aten.zero.functional %arg0 : !torch.vtensor<[?,?],f32> -> !torch.vtensor<[?,?],f32>
|
%0 = torch.aten.zero %arg0 : !torch.vtensor<[?,?],f32> -> !torch.vtensor<[?,?],f32>
|
||||||
return %0 : !torch.vtensor<[?,?],f32>
|
return %0 : !torch.vtensor<[?,?],f32>
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue