diff --git a/src/main/java/com/caucho/hessian/io/JavaSerializer.java b/src/main/java/com/caucho/hessian/io/JavaSerializer.java
index 51035b3..091cea2 100644
--- a/src/main/java/com/caucho/hessian/io/JavaSerializer.java
+++ b/src/main/java/com/caucho/hessian/io/JavaSerializer.java
@@ -146,14 +146,16 @@ public void writeObject(Object obj, AbstractHessianOutput out)
try {
if (_writeReplace != null) {
Object repl = _writeReplace.invoke(obj, new Object[0]);
+ // for those writeReplaces that might return obj itself, no need to replace repl with obj
+ if (repl != obj) {
- out.removeRef(obj);
+ out.removeRef(obj);
- out.writeObject(repl);
+ out.writeObject(repl);
- out.replaceRef(repl, obj);
-
- return;
+ out.replaceRef(repl, obj);
+ return;
+ }
}
} catch (Exception e) {
log.log(Level.FINE, e.toString(), e);
diff --git a/src/test/java/com/caucho/hessian/io/WriteReplaceTest.java b/src/test/java/com/caucho/hessian/io/WriteReplaceTest.java
new file mode 100644
index 0000000..b478f8f
--- /dev/null
+++ b/src/test/java/com/caucho/hessian/io/WriteReplaceTest.java
@@ -0,0 +1,140 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package com.caucho.hessian.io;
+
+import org.junit.Assert;
+import org.junit.BeforeClass;
+import org.junit.Test;
+
+import java.io.ByteArrayInputStream;
+import java.io.ByteArrayOutputStream;
+import java.io.IOException;
+import java.io.Serializable;
+
+/**
+ *
+ * @author junyuan
+ * @version WriteReplaceTest.java, v 0.1 2024-03-20 10:34 junyuan Exp $
+ */
+public class WriteReplaceTest {
+ private static SerializerFactory factory;
+ private static ByteArrayOutputStream os;
+
+ @BeforeClass
+ public static void setUp() {
+ factory = new SerializerFactory();
+ os = new ByteArrayOutputStream();
+ }
+
+ @Test
+ public void TestWriteReplace() throws IOException {
+ TestObject origin = new TestObject();
+ origin.setName("testWR");
+
+ os.reset();
+ Hessian2Output output = new Hessian2Output(os);
+
+ output.setSerializerFactory(factory);
+ try {
+ output.writeObject(origin);
+ } catch (Exception e) {
+ Assert.fail("should be no exception");
+ }
+ output.flush();
+
+ ByteArrayInputStream is = new ByteArrayInputStream(os.toByteArray());
+ Hessian2Input input = new Hessian2Input(is);
+ input.setSerializerFactory(factory);
+ TestObject actual = (TestObject) input.readObject();
+ Assert.assertEquals(actual.name, origin.name);
+ }
+
+ @Test
+ public void TestWrappedWriteReplace() throws IOException {
+ WrappedTestObject origin = new WrappedTestObject();
+ TestObject testObject = new TestObject();
+ testObject.setName("testWR");
+ origin.setTestObject(testObject);
+
+ os.reset();
+ Hessian2Output output = new Hessian2Output(os);
+
+ output.setSerializerFactory(factory);
+ try {
+ output.writeObject(origin);
+ } catch (Exception e) {
+ Assert.fail("should be no exception");
+ }
+ output.flush();
+
+ ByteArrayInputStream is = new ByteArrayInputStream(os.toByteArray());
+ Hessian2Input input = new Hessian2Input(is);
+ input.setSerializerFactory(factory);
+ WrappedTestObject actual = (WrappedTestObject) input.readObject();
+ Assert.assertEquals(actual.testObject.name, origin.testObject.name);
+ }
+
+ private static class WrappedTestObject implements Serializable {
+ private TestObject testObject;
+
+ /**
+ * Getter method for property testObject.
+ *
+ * @return property value of testObject
+ */
+ public TestObject getTestObject() {
+ return testObject;
+ }
+
+ /**
+ * Setter method for property testObject.
+ *
+ * @param testObject value to be assigned to property testObject
+ */
+ public void setTestObject(TestObject testObject) {
+ this.testObject = testObject;
+ }
+ }
+
+ private static class TestObject implements Serializable {
+ private static final long serialVersionUID = -452701306050912437L;
+
+ String name;
+
+ Object writeReplace() {
+ return this;
+ }
+
+ /**
+ * Getter method for property name.
+ *
+ * @return property value of name
+ */
+ public String getName() {
+ return name;
+ }
+
+ /**
+ * Setter method for property name.
+ *
+ * @param name value to be assigned to property name
+ */
+ public void setName(String name) {
+ this.name = name;
+ }
+ }
+}