mirror of https://github.com/llvm/torch-mlir

Add ODS for group_norm (pull/2120/head)

- Add ODS for native_group_norm/backward.
- Add shape inference for native_group_norm/backward.

Signed-off-by: rahul shrivastava <rahul.shrivastava@cerebras.net>

parent 86718cb203, commit 40a2c501a1
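For reference, here is a minimal PyTorch sketch (shapes, group count, and eps chosen arbitrarily for illustration) of the `aten::native_group_norm` signature this change registers: the op takes the flattened spatial size `HxW` and returns the normalized output together with per-(N, group) mean and rstd statistics, which is what the ODS and shape inference below encode.

import torch

# Toy shapes, for illustration only.
N, C, H, W, groups, eps = 2, 6, 4, 4, 3, 1e-5
x = torch.randn(N, C, H, W)
weight = torch.randn(C)
bias = torch.randn(C)

# aten::native_group_norm : (Tensor, Tensor?, Tensor?, int, int, int, int, float)
#                           -> (Tensor, Tensor, Tensor)
out, mean, rstd = torch.ops.aten.native_group_norm(
    x, weight, bias, N, C, H * W, groups, eps)

print(out.shape)   # torch.Size([2, 6, 4, 4]) -- same shape as the input
print(mean.shape)  # torch.Size([2, 3])       -- one mean per (N, group)
print(rstd.shape)  # torch.Size([2, 3])       -- one rstd per (N, group)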
@@ -4656,6 +4656,38 @@ def Torch_AtenBatchNormOp : Torch_Op<"aten.batch_norm", [
   }];
 }
 
+def Torch_AtenNativeGroupNormOp : Torch_Op<"aten.native_group_norm", [
+    AllowsTypeRefinement,
+    HasValueSemantics,
+    ReadOnly
+  ]> {
+  let summary = "Generated op for `aten::native_group_norm : (Tensor, Tensor?, Tensor?, int, int, int, int, float) -> (Tensor, Tensor, Tensor)`";
+  let arguments = (ins
+    AnyTorchTensorType:$input,
+    AnyTorchOptionalTensorType:$weight,
+    AnyTorchOptionalTensorType:$bias,
+    Torch_IntType:$N,
+    Torch_IntType:$C,
+    Torch_IntType:$HxW,
+    Torch_IntType:$group,
+    Torch_FloatType:$eps
+  );
+  let results = (outs
+    AnyTorchTensorType:$result0,
+    AnyTorchTensorType:$result1,
+    AnyTorchTensorType:$result2
+  );
+  let hasCustomAssemblyFormat = 1;
+  let extraClassDefinition = [{
+    ParseResult AtenNativeGroupNormOp::parse(OpAsmParser &parser, OperationState &result) {
+      return parseDefaultTorchOp(parser, result, 8, 3);
+    }
+    void AtenNativeGroupNormOp::print(OpAsmPrinter &printer) {
+      printDefaultTorchOp(printer, *this, 8, 3);
+    }
+  }];
+}
+
 def Torch_AtenLayerNormOp : Torch_Op<"aten.layer_norm", [
     AllowsTypeRefinement,
     HasValueSemantics,
@@ -11066,6 +11098,40 @@ def Torch_AtenNativeBatchNormBackwardOp : Torch_Op<"aten.native_batch_norm_backward", [
   }];
 }
 
+def Torch_AtenNativeGroupNormBackwardOp : Torch_Op<"aten.native_group_norm_backward", [
+    AllowsTypeRefinement,
+    HasValueSemantics,
+    ReadOnly
+  ]> {
+  let summary = "Generated op for `aten::native_group_norm_backward : (Tensor, Tensor, Tensor, Tensor, Tensor?, int, int, int, int, bool[]) -> (Tensor, Tensor, Tensor)`";
+  let arguments = (ins
+    AnyTorchTensorType:$grad_out,
+    AnyTorchTensorType:$input,
+    AnyTorchTensorType:$mean,
+    AnyTorchTensorType:$rstd,
+    AnyTorchOptionalTensorType:$weight,
+    Torch_IntType:$N,
+    Torch_IntType:$C,
+    Torch_IntType:$HxW,
+    Torch_IntType:$group,
+    AnyTorchListOfTorchBoolType:$output_mask
+  );
+  let results = (outs
+    AnyTorchTensorType:$result0,
+    AnyTorchTensorType:$result1,
+    AnyTorchTensorType:$result2
+  );
+  let hasCustomAssemblyFormat = 1;
+  let extraClassDefinition = [{
+    ParseResult AtenNativeGroupNormBackwardOp::parse(OpAsmParser &parser, OperationState &result) {
+      return parseDefaultTorchOp(parser, result, 10, 3);
+    }
+    void AtenNativeGroupNormBackwardOp::print(OpAsmPrinter &printer) {
+      printDefaultTorchOp(printer, *this, 10, 3);
+    }
+  }];
+}
+
 def Torch_AtenNativeDropoutBackwardOp : Torch_Op<"aten.native_dropout_backward", [
     AllowsTypeRefinement,
     HasValueSemantics,
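A companion PyTorch sketch for the backward op (again with arbitrary toy shapes): the ten operands of `aten::native_group_norm_backward` line up with the ODS arguments above, and `output_mask` selects which of the three gradients are computed.

import torch

N, C, H, W, groups, eps = 2, 6, 4, 4, 3, 1e-5
x = torch.randn(N, C, H, W)
weight = torch.randn(C)
out, mean, rstd = torch.ops.aten.native_group_norm(
    x, weight, None, N, C, H * W, groups, eps)

grad_out = torch.ones_like(out)
# output_mask picks which of (grad_input, grad_weight, grad_bias) to materialize.
grad_input, grad_weight, grad_bias = torch.ops.aten.native_group_norm_backward(
    grad_out, x, mean, rstd, weight, N, C, H * W, groups, [True, True, True])

print(grad_input.shape)   # torch.Size([2, 6, 4, 4]) -- same shape as the input
print(grad_weight.shape)  # torch.Size([6])          -- length-C vector
print(grad_bias.shape)    # torch.Size([6])          -- length-C vector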
@@ -63,5 +63,59 @@ std::vector<torch::lazy::Shape> compute_shape_copy(
   return {Shape(self.scalar_type(), self.sizes().vec())};
 }
 
+std::vector<torch::lazy::Shape> compute_shape_native_group_norm(
+    const at::Tensor& input,
+    const c10::optional<at::Tensor>& weight,
+    const c10::optional<at::Tensor>& bias,
+    int64_t N, int64_t C, int64_t HxW,
+    int64_t group, double eps) {
+
+  TORCH_CHECK(
+      input.sizes().size() >= 2,
+      "Input tensor must have at least batch and channel dimensions!");
+  std::vector<torch::lazy::Shape> shapes;
+  shapes.reserve(3);
+  shapes.emplace_back(input.scalar_type(), input.sizes().vec());
+
+  // A separate mean and var needs to be kept for each group per N.
+  shapes.emplace_back(
+      at::get_default_dtype_as_scalartype(),
+      std::vector<int64_t>{N, group});
+
+  shapes.emplace_back(
+      at::get_default_dtype_as_scalartype(),
+      std::vector<int64_t>{N, group});
+  return shapes;
+}
+
+std::vector<torch::lazy::Shape> compute_shape_native_group_norm_backward(
+    const at::Tensor& grad_out,
+    const at::Tensor& input,
+    const at::Tensor& mean,
+    const at::Tensor& rstd,
+    const c10::optional<at::Tensor>& weight,
+    int64_t N, int64_t C, int64_t HxW,
+    int64_t group, ::std::array<bool, 3> output_mask) {
+
+  TORCH_CHECK(
+      input.sizes().size() >= 2,
+      "Input tensor must have at least batch and channel dimensions!");
+  std::vector<torch::lazy::Shape> shapes;
+  shapes.reserve(3);
+  shapes.emplace_back(input.scalar_type(), input.sizes().vec());
+
+  int64_t num_features = input.size(1);
+
+  // `weight` and `bias` are vectors of length C (number of channels).
+  shapes.emplace_back(
+      at::get_default_dtype_as_scalartype(),
+      std::vector<int64_t>{num_features});
+  shapes.emplace_back(
+      at::get_default_dtype_as_scalartype(),
+      std::vector<int64_t>{num_features});
+
+  return shapes;
+}
+
 } // namespace lazy
 } // namespace torch
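As a sanity check, a rough Python paraphrase of the two shape functions above (assumption: plain lists stand in for torch::lazy::Shape and dtype handling is omitted): the forward op's result shapes are the input shape plus two [N, group] statistic shapes, while the backward op's are the input shape plus two length-C gradient vectors.

# Illustrative Python analogue of the C++ shape inference above; lists stand in
# for torch::lazy::Shape and dtypes are ignored.
def compute_shape_native_group_norm(input_shape, N, group):
    # Normalized output matches the input; mean/rstd hold one value per (N, group).
    return [list(input_shape), [N, group], [N, group]]

def compute_shape_native_group_norm_backward(input_shape):
    C = input_shape[1]  # grad_weight and grad_bias are length-C vectors
    return [list(input_shape), [C], [C]]

print(compute_shape_native_group_norm((2, 6, 4, 4), N=2, group=3))
# [[2, 6, 4, 4], [2, 3], [2, 3]]
print(compute_shape_native_group_norm_backward((2, 6, 4, 4)))
# [[2, 6, 4, 4], [6], [6]]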
@@ -373,6 +373,9 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
     emit(
         "aten::batch_norm : (Tensor, Tensor?, Tensor?, Tensor?, Tensor?, bool, float, float, bool) -> (Tensor)"
     )
+    emit(
+        "aten::native_group_norm : (Tensor, Tensor?, Tensor?, int, int, int, int, float) -> (Tensor, Tensor, Tensor)"
+    )
     emit(
         "aten::layer_norm : (Tensor, int[], Tensor?, Tensor?, float, bool) -> (Tensor)"
     )
@@ -662,6 +665,7 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
     emit("aten::native_layer_norm_backward : (Tensor, Tensor, int[], Tensor, Tensor, Tensor?, Tensor?, bool[]) -> (Tensor, Tensor, Tensor)")
     emit("aten::embedding_dense_backward : (Tensor, Tensor, int, int, bool) -> (Tensor)")
     emit("aten::native_batch_norm_backward : (Tensor, Tensor, Tensor?, Tensor?, Tensor?, Tensor?, Tensor?, bool, float, bool[]) -> (Tensor, Tensor, Tensor)")
+    emit("aten::native_group_norm_backward : (Tensor, Tensor, Tensor, Tensor, Tensor?, int, int, int, int, bool[]) -> (Tensor, Tensor, Tensor)")
     emit("aten::native_dropout_backward : (Tensor, Tensor, float) -> (Tensor)")
     emit("aten::leaky_relu_backward : (Tensor, Tensor, Scalar, bool) -> (Tensor)")